Skip to content

Commit d4d24ec

Browse files
authored
Adding mixed quantization support
Differential Revision: D81519735
Pull Request resolved: #14134
1 parent eca4fc6 commit d4d24ec

File tree

3 files changed

+142
-98
lines changed

3 files changed

+142
-98
lines changed

backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 18 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -471,7 +471,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
471471
pattern.partition_types(),
472472
)
473473
for fused_partition in fused_partitions:
474-
anchors = pattern.get_anchors(graph_module, fused_partition)
474+
anchors, op_node = pattern.get_anchors(graph_module, fused_partition)
475475
if not anchors or anchors.empty:
476476
continue
477477
if any(self.is_fused(p.nodes) for p in fused_partition):
@@ -512,13 +512,10 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
512512
bias_inputs = [node.args[0] for node in dequants_biases]
513513
other_inputs = [node.args[idx] for node, idx in anchors.others]
514514

515-
# The node is the first index of the list and first of the tuple
516-
anchor_output_node = anchors.output[0][0]
515+
assert op_node is not None, "op_node is None"
516+
quant_node = list(op_node.users.keys())[0]
517517

518-
assert len(anchor_output_node.users) == 1
519-
quant_node = list(anchor_output_node.users.keys())[0]
520-
521-
with graph_module.graph.inserting_after(anchor_output_node):
518+
with graph_module.graph.inserting_after(op_node):
522519
args = tuple(
523520
inputs_inputs + weights_inputs + other_inputs + bias_inputs
524521
)
@@ -532,7 +529,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
532529
)
533530
elif isinstance(pattern, CatPattern):
534531
args, kwargs = get_args_and_kwargs_cat(
535-
inputs_inputs, other_inputs, anchor_output_node
532+
inputs_inputs, other_inputs, op_node
536533
)
537534
elif isinstance(pattern, ConvReluPatterns):
538535
# For ConvReLU, we are fusing Conv+ReLU
@@ -563,7 +560,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
563560
dequants_weights,
564561
bias_inputs,
565562
quant_node,
566-
anchor_output_node,
563+
op_node,
567564
)
568565
elif isinstance(pattern, LinearPattern):
569566
args, kwargs = get_args_and_kwargs_linear(
@@ -618,20 +615,28 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
618615
inputs_inputs,
619616
dequants_inputs,
620617
quant_node,
621-
anchor_output_node,
618+
op_node,
622619
)
620+
623621
fused = graph_module.graph.call_function(
624622
pattern.replacement_op(),
625623
args,
626624
kwargs,
627625
)
628-
fused.meta = quant_node.meta
629-
quant_node.replace_all_uses_with(fused)
626+
627+
if len(anchors.output) > 0:
628+
fused.meta = quant_node.meta
629+
quant_node.replace_all_uses_with(fused)
630+
else:
631+
fused.meta = op_node.meta
632+
op_node.replace_all_uses_with(fused)
633+
if op_node.op == "output":
634+
_ = graph_module.graph.output((fused,))
630635

631636
legalize_graph(graph_module)
632637
graph_module.graph.eliminate_dead_code()
633-
# pyre-fixme[7]: Incompatible return type
634638
graph_module.recompile()
639+
return PassResult(graph_module, True)
635640

636641
@classmethod
637642
# pyre-ignore[2]: Parameter `nodes` has no type specified

0 commit comments

Comments
 (0)