Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 18 additions & 13 deletions backends/cadence/aot/quantizer/fusion_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
pattern.partition_types(),
)
for fused_partition in fused_partitions:
anchors = pattern.get_anchors(graph_module, fused_partition)
anchors, op_node = pattern.get_anchors(graph_module, fused_partition)
if not anchors or anchors.empty:
continue
if any(self.is_fused(p.nodes) for p in fused_partition):
Expand Down Expand Up @@ -512,13 +512,10 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
bias_inputs = [node.args[0] for node in dequants_biases]
other_inputs = [node.args[idx] for node, idx in anchors.others]

# The node is the first index of the list and first of the tuple
anchor_output_node = anchors.output[0][0]
assert op_node is not None, "op_node is None"
quant_node = list(op_node.users.keys())[0]

assert len(anchor_output_node.users) == 1
quant_node = list(anchor_output_node.users.keys())[0]

with graph_module.graph.inserting_after(anchor_output_node):
with graph_module.graph.inserting_after(op_node):
args = tuple(
inputs_inputs + weights_inputs + other_inputs + bias_inputs
)
Expand All @@ -532,7 +529,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
)
elif isinstance(pattern, CatPattern):
args, kwargs = get_args_and_kwargs_cat(
inputs_inputs, other_inputs, anchor_output_node
inputs_inputs, other_inputs, op_node
)
elif isinstance(pattern, ConvReluPatterns):
# For ConvReLU, we are fusing Conv+ReLU
Expand Down Expand Up @@ -563,7 +560,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
dequants_weights,
bias_inputs,
quant_node,
anchor_output_node,
op_node,
)
elif isinstance(pattern, LinearPattern):
args, kwargs = get_args_and_kwargs_linear(
Expand Down Expand Up @@ -618,20 +615,28 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
inputs_inputs,
dequants_inputs,
quant_node,
anchor_output_node,
op_node,
)

fused = graph_module.graph.call_function(
pattern.replacement_op(),
args,
kwargs,
)
fused.meta = quant_node.meta
quant_node.replace_all_uses_with(fused)

if len(anchors.output) > 0:
fused.meta = quant_node.meta
quant_node.replace_all_uses_with(fused)
else:
fused.meta = op_node.meta
op_node.replace_all_uses_with(fused)
if op_node.op == "output":
_ = graph_module.graph.output((fused,))

legalize_graph(graph_module)
graph_module.graph.eliminate_dead_code()
# pyre-fixme[7]: Incompatible return type
graph_module.recompile()
return PassResult(graph_module, True)

@classmethod
# pyre-ignore[2]: Parameter `nodes` has no type specified
Expand Down
Loading