Skip to content

Commit 23e1ca5

Browse files
authored
Changing logic to deal with graphs with derived quantization spec (#16357)
Summary: We want to add a test for `default_addmm_A8W8` to fully finish testing `CadenceDefaultQuantizer`. However there are a couple changes we need to make to the testing function. ## Change 1: We allow passing `None` in the vec of `QuantizationSpec` This is because the addmm op has 3 inputs: `bias`, `mat1`, `mat2`. The bias uses a `DerivedQuantizationSpec`, which is dynamically constructed with references to the actual graph nodes (`mat1` and `mat2`). We can't construct an identical `DerivedQuantizationSpec` in the test because we'd need to reference the exact same node objects that the quantizer creates internally. Since we can't compare it directly, we use `None` to skip validation for that input. If `mat1` and `mat2` are quantized correctly, the derived bias spec will be correct too. https://www.internalfb.com/code/fbsource/[2cfdb40fd8b628da2f46366115516408cfb9f50f]/xplat/executorch/backends/cadence/aot/quantizer/patterns.py?lines=91-103 ## Change 2: We changed how we iterate through `input_qspec_map` `input_qspec_map` is a dictionary mapping input nodes to their `qspecs`. The iteration order depends on insertion order, which follows how the quantizer processes `PartitionAnchors`. Each `QuantizationPattern` implements a `get_anchors()` method that returns a `PartitionAnchors` describing which arguments are inputs, weights, biases and nodes. This is relevant because for `addmm`, the `PartitionAnchors` lists them as `inputs=[(node, 1)], weights=[(node, 2)], biases=[(node, 0, ...)]. ` So the map might iterate in order `mat1, mat2, bias` (args indices 1, 2, 0) rather than `bias, mat1, mat2` (args indices 0, 1, 2). This means that our previous way of iterating wouldn't work. Thus, we now use the following way to iterate: ``` for input_node, input_qspec in annotation.input_qspec_map.items(): // Find the index of this input node in the op's args arg_index = None for i, arg in enumerate(op_node.args): if arg is input_node: arg_index = i break self.assertIsNotNone( arg_index, f"Input node {input_node} not found in op_node.args", ) # Skip comparison if expected qspec is None (e.g., for DerivedQuantizationSpec) if expected_input_qspecs[arg_index] is not None: self.assertEqual( input_qspec, expected_input_qspecs[arg_index], f"Input qspec mismatch at arg index {arg_index}", ) ``` The new code looks up which argument index each input_node corresponds to by searching through `op_node.args`, rather than assuming the enumeration index i matches the argument position. Differential Revision: D88955761
1 parent 7815c38 commit 23e1ca5

File tree

0 file changed

+0
-0
lines changed

    0 file changed

    +0
    -0
    lines changed

    0 commit comments

    Comments
     (0)