Commit 23e1ca5
Changing logic to deal with graphs with derived quantization spec (#16357)
Summary:
We want to add a test for `default_addmm_A8W8` to finish testing
`CadenceDefaultQuantizer`. However, there are a couple of changes we need
to make to the testing function first.
## Change 1: We allow passing `None` in the list of expected `QuantizationSpec`s
This is because the addmm op has 3 inputs: `bias`, `mat1`, `mat2`. The
bias uses a `DerivedQuantizationSpec`, which is dynamically constructed
with references to the actual graph nodes (`mat1` and `mat2`). We can't
construct an identical `DerivedQuantizationSpec` in the test because
we'd need to reference the exact same node objects that the quantizer
creates internally. Since we can't compare it directly, we use `None` to
skip validation for that input. If `mat1` and `mat2` are quantized
correctly, the derived bias spec will be correct too.
https://www.internalfb.com/code/fbsource/[2cfdb40fd8b628da2f46366115516408cfb9f50f]/xplat/executorch/backends/cadence/aot/quantizer/patterns.py?lines=91-103
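To illustrate why a test cannot rebuild an equal derived spec, here is a minimal sketch using stand-in classes. `FakeNode` and `FakeDerivedSpec` are hypothetical and are not the real graph node or `DerivedQuantizationSpec` types; the sketch only assumes that graph nodes compare by identity and that the derived spec holds references to them.

```python
from dataclasses import dataclass
from typing import List


class FakeNode:
    """Hypothetical stand-in for a graph node; compares by identity (default object equality)."""

    def __init__(self, name: str) -> None:
        self.name = name


@dataclass
class FakeDerivedSpec:
    """Hypothetical stand-in for a spec that references the nodes it is derived from."""

    derived_from: List[FakeNode]


# The quantizer builds its spec from the nodes in the real graph.
graph_mat1, graph_mat2 = FakeNode("mat1"), FakeNode("mat2")
quantizer_spec = FakeDerivedSpec(derived_from=[graph_mat1, graph_mat2])

# A test that creates its own nodes gets a spec that is not equal,
# even though the node names match.
test_spec = FakeDerivedSpec(derived_from=[FakeNode("mat1"), FakeNode("mat2")])
specs_equal_fresh_nodes = quantizer_spec == test_spec

# Equality only holds when the very same node objects are referenced.
shared_spec = FakeDerivedSpec(derived_from=[graph_mat1, graph_mat2])
specs_equal_shared_nodes = quantizer_spec == shared_spec
```

With fresh nodes the comparison is false, which is why the test passes `None` for the bias slot instead of an expected spec.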
## Change 2: We changed how we iterate through `input_qspec_map`
`input_qspec_map` is a dictionary mapping input nodes to their `qspecs`.
The iteration order depends on insertion order, which follows how the
quantizer processes `PartitionAnchors`.
Each `QuantizationPattern` implements a `get_anchors()` method that
returns a `PartitionAnchors` describing which arguments are inputs,
weights, biases and nodes. This is relevant because for `addmm`, the
`PartitionAnchors` lists them as `inputs=[(node, 1)]`, `weights=[(node, 2)]`,
`biases=[(node, 0, ...)]`. So the map might iterate in order `mat1,
mat2, bias` (args indices 1, 2, 0) rather than `bias, mat1, mat2` (args
indices 0, 1, 2).
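The ordering issue can be reproduced with plain dictionaries. This is only a sketch: the string names stand in for graph nodes, and the group order (inputs, then weights, then biases) is an assumption about how the quantizer fills the map, not something taken from its source.

```python
# Argument order of the addmm op: bias, mat1, mat2 (indices 0, 1, 2).
args = ("bias", "mat1", "mat2")

# Hypothetical anchor groups holding the arg index of each role.
anchors = {"inputs": [1], "weights": [2], "biases": [0]}

# Assumed processing order: inputs, then weights, then biases.
input_qspec_map = {}
for group in ("inputs", "weights", "biases"):
    for arg_index in anchors[group]:
        input_qspec_map[args[arg_index]] = f"qspec_for_{args[arg_index]}"

# Python dicts preserve insertion order, so iteration follows the
# anchor order rather than the argument order.
map_order = list(input_qspec_map)

# Pairing the i-th map entry with the i-th argument lines up the wrong nodes.
positional_pairs = [(name, args[i]) for i, name in enumerate(map_order)]
```

Under these assumptions `map_order` is `mat1, mat2, bias`, so comparing by enumeration index would check each qspec against the wrong expected entry.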
This means that our previous way of iterating wouldn't work. Thus, we
now use the following way to iterate:
```
for input_node, input_qspec in annotation.input_qspec_map.items():
    # Find the index of this input node in the op's args
    arg_index = None
    for i, arg in enumerate(op_node.args):
        if arg is input_node:
            arg_index = i
            break
    self.assertIsNotNone(
        arg_index,
        f"Input node {input_node} not found in op_node.args",
    )
    # Skip comparison if expected qspec is None (e.g., for DerivedQuantizationSpec)
    if expected_input_qspecs[arg_index] is not None:
        self.assertEqual(
            input_qspec,
            expected_input_qspecs[arg_index],
            f"Input qspec mismatch at arg index {arg_index}",
        )
```
The new code looks up which argument index each `input_node` corresponds
to by searching through `op_node.args`, rather than assuming the
enumeration index `i` matches the argument position.
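For experimenting outside the test class, the lookup and the `None` skip can be sketched as standalone functions. `FakeNode`, `find_arg_index`, and `check_input_qspecs` are hypothetical names introduced for illustration, and the string qspecs are placeholders for real `QuantizationSpec` objects.

```python
from typing import Any, Dict, List, Optional, Sequence


class FakeNode:
    """Hypothetical stand-in for a graph node (hashable, compared by identity)."""

    def __init__(self, name: str) -> None:
        self.name = name


def find_arg_index(args: Sequence[Any], node: Any) -> Optional[int]:
    """Return the position of `node` in `args` by identity, or None if absent."""
    for i, arg in enumerate(args):
        if arg is node:
            return i
    return None


def check_input_qspecs(
    args: Sequence[Any],
    input_qspec_map: Dict[Any, Any],
    expected: List[Optional[Any]],
) -> List[int]:
    """Compare each qspec to the expected entry at its arg index; None skips the check."""
    checked = []
    for node, qspec in input_qspec_map.items():
        arg_index = find_arg_index(args, node)
        assert arg_index is not None, f"Input node {node} not found in args"
        if expected[arg_index] is None:
            continue  # e.g. a derived spec that cannot be rebuilt in the test
        assert qspec == expected[arg_index], f"Mismatch at arg index {arg_index}"
        checked.append(arg_index)
    return checked


bias, mat1, mat2 = FakeNode("bias"), FakeNode("mat1"), FakeNode("mat2")
op_args = (bias, mat1, mat2)
# Map filled in anchor order (mat1, mat2, bias), not argument order.
qspec_map = {mat1: "act_qspec", mat2: "weight_qspec", bias: object()}
checked_indices = check_input_qspecs(op_args, qspec_map, [None, "act_qspec", "weight_qspec"])
```

Here only arg indices 1 and 2 are compared; the bias at index 0 is skipped because its expected entry is `None`.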
Differential Revision: D889557611
Parent: 7815c38