
Commit dbe398d

Andrew Grebenisan authored and facebook-github-bot committed

Update FuseMulScalarIntoDequantPass, FuseMulTensorIntoQuantPass, and FuseMulTensorIntoDequantPass to use new pass interface

Summary: Updates FuseMulScalarIntoDequantPass, FuseMulTensorIntoDequantPass, and FuseMulTensorIntoQuantPass to use the new pass interface.

Differential Revision: D87887841
1 parent 4160430 commit dbe398d
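This commit migrates three fusion passes from the ExportPass attempt_fusion/call pattern to RemoveOrReplacePassInterface: each pass now declares the edge ops it targets and implements a per-node maybe_remove_or_replace hook. The base class is not part of this diff; the following is a minimal sketch of its assumed contract, inferred from how the subclasses below use it (everything beyond the two members shown is hypothetical):

from abc import ABC, abstractmethod

import torch


class RemoveOrReplacePassInterface(ABC):
    """Assumed contract: subclasses name target ops and a per-node hook."""

    @property
    @abstractmethod
    def targets(self) -> list:
        """Edge ops whose call_function nodes are offered to the hook."""

    @abstractmethod
    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
        """Try to rewrite the graph around `node`; return True iff it changed."""

Under this assumption, the base class walks the matching nodes, invokes the hook, and reports whether anything changed via PassResult.modified, which is what the updated tests assert.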

File tree

2 files changed: +83 -72 lines

backends/cadence/aot/fuse_ops.py

Lines changed: 53 additions & 60 deletions
@@ -818,30 +818,32 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:


 @register_cadence_pass(CadencePassAttribute(opt_level=1))
-class FuseMulScalarIntoDequantPass(ExportPass):
+class FuseMulScalarIntoDequantPass(RemoveOrReplacePassInterface):
     """
     Looks for the pattern where aten.mul.Scalar is multiplying the
     outputs of dequantize. If found, updates the dequant scale
     to reflect the multiplication and removes the mul node.
     """

-    def attempt_fusion(
-        self, graph_module: torch.fx.GraphModule, node: torch.fx.Node
-    ) -> None:
-        if node.target not in {
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [
             exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
             exir_ops.edge.cadence.dequantize_per_tensor.default,
-        }:
-            return
+        ]
+
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
+        # Ensure that the single user of dequant is aten.mul.Scalar
+        if len(node.users) != 1:
+            return False

-        # ensure that the single user of dequant is aten.mul.Scalar
         user = list(node.users.keys())[0]
-        if len(node.users) != 1 or user.target != exir_ops.edge.aten.mul.Scalar:
-            return
+        if user.target != exir_ops.edge.aten.mul.Scalar:
+            return False

-        # ensure that the other arg to mul is a node (i.e. not a constant)
+        # Ensure that the other arg to mul is not a node (i.e. it's a constant)
         if len(user.args) > 1 and isinstance(user.args[1], torch.fx.Node):
-            return
+            return False

         new_deq_args = list(node.args)
         assert isinstance(node.args[1], Number)
@@ -853,36 +855,36 @@ def attempt_fusion
             f"Fused {node} and {user} into {node}. Updated scale from {node.args[1]} to {new_deq_args[1]}"
         )

+        # Replace all uses of mul with the dequant node
         user.replace_all_uses_with(node)
+        # Update the dequant node's args with the new scale
         node.args = tuple(new_deq_args)

-        graph_module.graph.erase_node(user)
-
-        graph_module.recompile()
+        # Erase the mul node
+        node.graph.erase_node(user)

-    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
-        for node in graph_module.graph.nodes:
-            self.attempt_fusion(graph_module, node)
-        result = super().call(graph_module)
-        return result
+        return True


 @register_cadence_pass(CadencePassAttribute(opt_level=1))
-class FuseMulTensorIntoQuantPass(ExportPass):
+class FuseMulTensorIntoQuantPass(RemoveOrReplacePassInterface):
     """
     Looks for the pattern where aten.mul.Tensor is followed by quant node.
     If found, updates the quant scale to reflect the multiplication and
     removes the mul node.
     """

-    def attempt_fusion(
-        self, graph_module: torch.fx.GraphModule, mul_node: torch.fx.Node
-    ) -> None:
-        if len(mul_node.args) != 2 or len(mul_node.users) != 1:
-            return
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [exir_ops.edge.aten.mul.Tensor]

-        first_arg = cast(torch.fx.Node, mul_node.args[0])
-        second_arg = cast(torch.fx.Node, mul_node.args[1])
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
+        # Check that mul has exactly 2 args and 1 user
+        if len(node.args) != 2 or len(node.users) != 1:
+            return False
+
+        first_arg = cast(torch.fx.Node, node.args[0])
+        second_arg = cast(torch.fx.Node, node.args[1])

         input_node = first_arg
         full_node = second_arg
@@ -895,20 +897,20 @@ def attempt_fusion
             input_node = second_arg
         else:
             # Full node is not found, skip.
-            return
+            return False

         # Ensure that the mul op does not do any broadcasting.
-        if input_node.meta["val"].shape != mul_node.meta["val"].shape:
-            return
+        if input_node.meta["val"].shape != node.meta["val"].shape:
+            return False

-        mul_user = list(mul_node.users.keys())[0]
+        mul_user = list(node.users.keys())[0]

         # Ensure only the expected quant ops are using the current mul op.
         if mul_user.target not in {
             exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
             exir_ops.edge.cadence.quantize_per_tensor.default,
         }:
-            return
+            return False

         quant_node = mul_user

@@ -927,39 +929,32 @@ def attempt_fusion
         new_scale = float(old_scale) / float(mul_scalar)

         logging.debug(
-            f"Fused {mul_node} and {full_node} into {quant_node}. Updated scale from {quant_node.args[1]} to {new_scale}"
+            f"Fused {node} and {full_node} into {quant_node}. Updated scale from {quant_node.args[1]} to {new_scale}"
         )

         # Update quant node input and scale.
         old_quant_input = cast(torch.fx.Node, quant_node.args[0])
-        new_quant_input = cast(torch.fx.Node, mul_node.args[0])
+        new_quant_input = input_node
         quant_node.replace_input_with(old_quant_input, new_quant_input)
         quant_node.update_arg(1, new_scale)

-    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
-        for node in graph_module.graph.find_nodes(
-            op="call_function", target=exir_ops.edge.aten.mul.Tensor
-        ):
-            self.attempt_fusion(graph_module, node)
-        graph_module.graph.eliminate_dead_code()
-        return super().call(graph_module)
+        return True


 @register_cadence_pass(CadencePassAttribute(opt_level=1))
-class FuseMulTensorIntoDequantPass(ExportPass):
+class FuseMulTensorIntoDequantPass(RemoveOrReplacePassInterface):
     """
     Looks for the pattern where aten.mul is multiplying the outputs of dequantize
     and aten.full, or vice versa. If found, updates the dequant scale to reflect
     the multiplication and removes the full and mul nodes.
     """

-    def attempt_fusion(
-        self, graph_module: torch.fx.GraphModule, node: torch.fx.Node
-    ) -> None:
-        if node.target != exir_ops.edge.aten.mul.Tensor:
-            return
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [exir_ops.edge.aten.mul.Tensor]

-        # ensure that one of the args to mul is dequantize and the other is aten.full
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
+        # Ensure that one of the args to mul is dequantize and the other is aten.full
         dequant_nodes = [
             arg
             for arg in node.args
@@ -979,14 +974,14 @@ def attempt_fusion
         ]

         if len(dequant_nodes) != 1 or len(multiplier_nodes) != 1:
-            return
+            return False

         deq_node = dequant_nodes[0]
         mplier_node = multiplier_nodes[0]

-        # ensure that dequant and full don't have any other users
+        # Ensure that dequant and full don't have any other users
         if len(deq_node.users) > 1 or len(mplier_node.users) > 1:
-            return
+            return False

         new_deq_args = list(deq_node.args)
         assert isinstance(deq_node.args[1], Number)
@@ -998,18 +993,16 @@ def attempt_fusion
             f"Fused {node} and {mplier_node} into {deq_node}. Updated scale from {deq_node.args[1]} to {new_deq_args[1]}"
         )

+        # Replace all uses of the mul node with the dequant node
         node.replace_all_uses_with(deq_node)
+        # Update the dequant node's args with the new scale
         deq_node.args = tuple(new_deq_args)

-        graph_module.graph.erase_node(node)
-        graph_module.graph.erase_node(mplier_node)
-        graph_module.recompile()
+        # Erase the mul and full nodes
+        node.graph.erase_node(node)
+        node.graph.erase_node(mplier_node)

-    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
-        for node in graph_module.graph.nodes:
-            self.attempt_fusion(graph_module, node)
-        result = super().call(graph_module)
-        return result
+        return True


 @register_cadence_pass(CadencePassAttribute(opt_level=1))
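All three fusions rest on the same scale algebra: decomposed per-tensor dequantize computes (x - zero_point) * scale and quantize computes round(y / scale) + zero_point (before clamping), so a constant multiplier c folds into the dequant scale as scale * c, or into the quant scale as scale / c. A standalone sanity check of that identity in plain PyTorch (illustrative values, not code from this commit):

import torch

scale, zero_point, c = 0.05, 2.0, 3.0

# Dequantize computes (x - zero_point) * scale, so a trailing mul by c
# is absorbed by scaling the dequant scale (the two dequant fusion passes).
x = torch.randint(0, 255, (4, 8), dtype=torch.uint8).float()
assert torch.allclose((x - zero_point) * scale * c, (x - zero_point) * (scale * c))

# Quantize computes round(y / scale) + zero_point, so a leading mul by c
# is absorbed by dividing the quant scale (FuseMulTensorIntoQuantPass).
y = torch.randn(4, 8)
assert torch.allclose(y * c / scale, y / (scale / c))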

backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 30 additions & 12 deletions
@@ -599,7 +599,8 @@ def test_fuse_mul_into_dequant(self) -> None:
         FULL_VALUE: Final[float] = 3

         builder = GraphBuilder()
-        x = builder.placeholder("x", torch.randn(*INPUT_SHAPE, dtype=torch.float32))
+        x_input = torch.randint(low=0, high=255, size=INPUT_SHAPE, dtype=torch.uint8)
+        x = builder.placeholder("x", x_input)
         dequant = builder.call_operator(
             op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
             args=(x, DEQUANT_SCALE, 0, 0, 255, torch.uint8),
@@ -614,8 +615,15 @@ def test_fuse_mul_into_dequant(self) -> None:
         )
         builder.output([mul])
         original_graph = builder.get_graph_module()
+        gm_before = copy.deepcopy(original_graph)
+
         p = FuseMulTensorIntoDequantPass()
-        converted_graph = cast(PassResult, p(original_graph)).graph_module
+        result = cast(PassResult, p(original_graph))
+        self.assertTrue(result.modified)
+        converted_graph = result.graph_module
+
+        # Validate numerical accuracy
+        validate(gm_before, converted_graph, (x_input,), "FuseMulTensorIntoDequantPass")

         # verify that the mul and full ops were removed
         self.check_op_counts(
@@ -642,7 +650,8 @@ def test_fuse_mul_scalar_into_dequant(self) -> None:
         mul_value = 0.3

         builder = GraphBuilder()
-        x = builder.placeholder("x", torch.randn(2, 3, 4, dtype=torch.float32))
+        x_input = torch.randn(2, 3, 4, dtype=torch.float32)
+        x = builder.placeholder("x", x_input)
         quant = builder.call_operator(
             op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
             args=(x, 1, 0, -128, 127, torch.int8),
@@ -657,8 +666,15 @@ def test_fuse_mul_scalar_into_dequant(self) -> None:
         )
         builder.output([mul_scalar])
         original_graph = builder.get_graph_module()
+        gm_before = copy.deepcopy(original_graph)
+
         p = FuseMulScalarIntoDequantPass()
-        converted_graph = cast(PassResult, p(original_graph)).graph_module
+        result = cast(PassResult, p(original_graph))
+        self.assertTrue(result.modified)
+        converted_graph = result.graph_module
+
+        # Validate numerical accuracy
+        validate(gm_before, converted_graph, (x_input,), "FuseMulScalarIntoDequantPass")

         # verify that the mul and full ops were removed
         self.check_op_counts(
@@ -684,7 +700,8 @@ def test_fuse_mul_into_quant(self) -> None:
         mul_value = 10

         builder = GraphBuilder()
-        x = builder.placeholder("x", torch.randn(4, 32, dtype=torch.float32))
+        x_input = torch.randn(4, 32, dtype=torch.float32)
+        x = builder.placeholder("x", x_input)
         full = builder.call_operator(
             op=exir_ops.edge.aten.full.default,
             args=([1], mul_value),
@@ -699,8 +716,15 @@ def test_fuse_mul_into_quant(self) -> None:
         )
         builder.output([quant])
         original_graph = builder.get_graph_module()
+        gm_before = copy.deepcopy(original_graph)
+
         p = FuseMulTensorIntoQuantPass()
-        converted_graph = cast(PassResult, p(original_graph)).graph_module
+        result = cast(PassResult, p(original_graph))
+        self.assertTrue(result.modified)
+        converted_graph = result.graph_module
+
+        # Validate numerical accuracy
+        validate(gm_before, converted_graph, (x_input,), "FuseMulTensorIntoQuantPass")

         # verify that the mul and full ops were removed
         self.check_op_counts(
@@ -720,12 +744,6 @@ def test_fuse_mul_into_quant(self) -> None:
                 new_quant_scale = node.args[1]
         self.assertEqual(new_quant_scale, quant_scale / mul_value)

-        # verify the math is correct
-        inp = torch.randn(4, 32, dtype=torch.float32)
-        original_out = original_graph(inp)[0]
-        new_out = converted_graph(inp)[0]
-        assert torch.equal(original_out, new_out)
-
     def test_fuse_then_transpose_pass(self) -> None:
         # Create a graph with full -> transpose -> permute -> view.
         builder = GraphBuilder()