
Commit 5d72a52

Merge branch 'pytorch:main' into main
2 parents: f6ea28d + 4939b45

File tree

10 files changed: +434, -94 lines

backends/cadence/aot/fuse_ops.py

Lines changed: 75 additions & 81 deletions
@@ -819,68 +819,76 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
-class FuseMulScalarIntoDequantPass(ExportPass):
+class FuseMulScalarIntoDequantPass(RemoveOrReplacePassInterface):
     """
     Looks for the pattern where aten.mul.Scalar is multiplying the
     outputs of dequantize. If found, updates the dequant scale
     to reflect the multiplication and removes the mul node.
     """
 
-    def attempt_fusion(
-        self, graph_module: torch.fx.GraphModule, node: torch.fx.Node
-    ) -> None:
-        if node.target not in {
-            exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
-            exir_ops.edge.cadence.dequantize_per_tensor.default,
-        }:
-            return
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [exir_ops.edge.aten.mul.Scalar]
 
-        # ensure that the single user of dequant is aten.mul.Scalar
-        user = list(node.users.keys())[0]
-        if len(node.users) != 1 or user.target != exir_ops.edge.aten.mul.Scalar:
-            return
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
+        # Ensure that the single user of dequant is aten.mul.Scalar
+        mul_node = node
+        if len(node.all_input_nodes) != 1 or len(node.all_input_nodes[0].users) != 1:
+            return False
 
-        # ensure that the other arg to mul is a node (i.e. not a constant)
-        if len(user.args) > 1 and isinstance(user.args[1], torch.fx.Node):
-            return
+        dequant_node = mul_node.all_input_nodes[0]
 
-        new_deq_args = list(node.args)
-        assert isinstance(node.args[1], Number)
-        assert isinstance(user.args[1], Number)
+        new_deq_args = list(dequant_node.args)
+        assert isinstance(dequant_node.args[1], Number)
+        assert isinstance(mul_node.args[1], Number)
         # pyre-ignore[58]: Unsupported operand *
-        new_deq_args[1] = node.args[1] * user.args[1]
+        new_deq_args[1] = dequant_node.args[1] * mul_node.args[1]
 
-        logging.debug(
-            f"Fused {node} and {user} into {node}. Updated scale from {node.args[1]} to {new_deq_args[1]}"
-        )
+        # Replace all uses of mul with the dequant node
+        mul_node.replace_all_uses_with(dequant_node)
+        # Update the dequant node's args with the new scale
+        dequant_node.args = tuple(new_deq_args)
 
-        user.replace_all_uses_with(node)
-        node.args = tuple(new_deq_args)
+        # Erase the mul node
+        mul_node.graph.erase_node(mul_node)
 
-        graph_module.graph.erase_node(user)
-
-        graph_module.recompile()
-
-    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
-        for node in graph_module.graph.nodes:
-            self.attempt_fusion(graph_module, node)
-        result = super().call(graph_module)
-        return result
+        logging.debug(
+            f"Fused {dequant_node} and {mul_node} into {dequant_node}. Updated scale from {dequant_node.args[1]} to {new_deq_args[1]}"
+        )
+        return True
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
-class FuseMulTensorIntoQuantPass(ExportPass):
+class FuseMulTensorIntoQuantPass(RemoveOrReplacePassInterface):
     """
     Looks for the pattern where aten.mul.Tensor is followed by quant node.
     If found, updates the quant scale to reflect the multiplication and
     removes the mul node.
     """
 
-    def attempt_fusion(
-        self, graph_module: torch.fx.GraphModule, mul_node: torch.fx.Node
-    ) -> None:
-        if len(mul_node.args) != 2 or len(mul_node.users) != 1:
-            return
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [exir_ops.edge.aten.mul.Tensor]
+
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
+        mul_node = node
+        if len(mul_node.users) != 1:
+            return False
+
+        user = next(iter(mul_node.users))
+        if len(user.all_input_nodes) != 1:
+            return False
+
+        if user.target not in [
+            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            exir_ops.edge.cadence.quantize_per_tensor.default,
+        ]:
+            return False
+
+        # Alias for readability.
+        quant_node = user
 
         first_arg = cast(torch.fx.Node, mul_node.args[0])
         second_arg = cast(torch.fx.Node, mul_node.args[1])
@@ -896,22 +904,11 @@ def attempt_fusion(
             input_node = second_arg
         else:
             # Full node is not found, skip.
-            return
+            return False
 
         # Ensure that the mul op does not do any broadcasting.
-        if input_node.meta["val"].shape != mul_node.meta["val"].shape:
-            return
-
-        mul_user = list(mul_node.users.keys())[0]
-
-        # Ensure only the expected quant ops are using the current mul op.
-        if mul_user.target not in {
-            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
-            exir_ops.edge.cadence.quantize_per_tensor.default,
-        }:
-            return
-
-        quant_node = mul_user
+        if input_node.meta["val"].shape != node.meta["val"].shape:
+            return False
 
         # Calculate the new scale value.
         old_scale = quant_node.args[1]
@@ -925,42 +922,41 @@ def attempt_fusion
         new_scale = old_scale / mul_scalar
         q = zp + x / new_scale
         """
+
+        # Cannot fuse if either value is zero:
+        # - mul_scalar == 0 would cause division by zero computing new_scale
+        # - old_scale == 0 would result in new_scale = 0, causing division by zero during quantization
+        if mul_scalar == 0 or old_scale == 0:
+            return False
         new_scale = float(old_scale) / float(mul_scalar)
 
         logging.debug(
-            f"Fused {mul_node} and {full_node} into {quant_node}. Updated scale from {quant_node.args[1]} to {new_scale}"
+            f"Fused {node} and {full_node} into {quant_node}. Updated scale from {quant_node.args[1]} to {new_scale}"
         )
 
         # Update quant node input and scale.
         old_quant_input = cast(torch.fx.Node, quant_node.args[0])
-        new_quant_input = cast(torch.fx.Node, mul_node.args[0])
+        new_quant_input = input_node
         quant_node.replace_input_with(old_quant_input, new_quant_input)
         quant_node.update_arg(1, new_scale)
 
-    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
-        for node in graph_module.graph.find_nodes(
-            op="call_function", target=exir_ops.edge.aten.mul.Tensor
-        ):
-            self.attempt_fusion(graph_module, node)
-        graph_module.graph.eliminate_dead_code()
-        return super().call(graph_module)
+        return True
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
-class FuseMulTensorIntoDequantPass(ExportPass):
+class FuseMulTensorIntoDequantPass(RemoveOrReplacePassInterface):
     """
     Looks for the pattern where aten.mul is multiplying the outputs of dequantize
    and aten.full, or vice versa. If found, updates the dequant scale to reflect
     the multiplication and removes the full and mul nodes.
     """
 
-    def attempt_fusion(
-        self, graph_module: torch.fx.GraphModule, node: torch.fx.Node
-    ) -> None:
-        if node.target != exir_ops.edge.aten.mul.Tensor:
-            return
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [exir_ops.edge.aten.mul.Tensor]
 
-        # ensure that one of the args to mul is dequantize and the other is aten.full
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
+        # Ensure that one of the args to mul is dequantize and the other is aten.full
         dequant_nodes = [
            arg
            for arg in node.args
@@ -980,14 +976,14 @@ def attempt_fusion(
         ]
 
         if len(dequant_nodes) != 1 or len(multiplier_nodes) != 1:
-            return
+            return False
 
         deq_node = dequant_nodes[0]
         mplier_node = multiplier_nodes[0]
 
-        # ensure that dequant and full don't have any other users
+        # Ensure that dequant and full don't have any other users
         if len(deq_node.users) > 1 or len(mplier_node.users) > 1:
-            return
+            return False
 
         new_deq_args = list(deq_node.args)
         assert isinstance(deq_node.args[1], Number)
@@ -999,18 +995,16 @@ def attempt_fusion(
             f"Fused {node} and {mplier_node} into {deq_node}. Updated scale from {deq_node.args[1]} to {new_deq_args[1]}"
         )
 
+        # Replace all uses of the mul node with the dequant node
         node.replace_all_uses_with(deq_node)
+        # Update the dequant node's args with the new scale
         deq_node.args = tuple(new_deq_args)
 
-        graph_module.graph.erase_node(node)
-        graph_module.graph.erase_node(mplier_node)
-        graph_module.recompile()
+        # Erase the mul and full nodes
+        node.graph.erase_node(node)
+        node.graph.erase_node(mplier_node)
 
-    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
-        for node in graph_module.graph.nodes:
-            self.attempt_fusion(graph_module, node)
-        result = super().call(graph_module)
-        return result
+        return True
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
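
All three passes rely on the same identity: a constant multiply can be folded into the affine quantization scale. Below is a standalone sketch checking that algebra with simplified stand-in quantize/dequantize helpers; these helpers are not the ExecuTorch ops, and the power-of-two constants are chosen so the float arithmetic is exact.

import torch


def quantize(x: torch.Tensor, scale: float, zp: int = 0) -> torch.Tensor:
    # Simplified affine quantization: q = clamp(round(x / scale) + zp).
    return torch.clamp(torch.round(x / scale) + zp, -128, 127).to(torch.int8)


def dequantize(q: torch.Tensor, scale: float, zp: int = 0) -> torch.Tensor:
    # Simplified affine dequantization: x = (q - zp) * scale.
    return (q.to(torch.float32) - zp) * scale


x = torch.randn(4, 32)
k, scale = 4.0, 0.5  # powers of two keep the float math bit-exact

# FuseMulTensorIntoQuantPass: quant(x * k, scale) == quant(x, scale / k)
assert torch.equal(quantize(x * k, scale), quantize(x, scale / k))

# FuseMulTensorIntoDequantPass / FuseMulScalarIntoDequantPass:
# dequant(q, scale) * k == dequant(q, scale * k)
q = quantize(x, scale)
assert torch.equal(dequantize(q, scale) * k, dequantize(q, scale * k))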

backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 36 additions & 12 deletions
@@ -602,7 +602,8 @@ def test_fuse_mul_into_dequant(self) -> None:
         FULL_VALUE: Final[float] = 3
 
         builder = GraphBuilder()
-        x = builder.placeholder("x", torch.randn(*INPUT_SHAPE, dtype=torch.float32))
+        x_input = torch.randint(low=0, high=255, size=INPUT_SHAPE, dtype=torch.uint8)
+        x = builder.placeholder("x", x_input)
         dequant = builder.call_operator(
             op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
             args=(x, DEQUANT_SCALE, 0, 0, 255, torch.uint8),
@@ -617,8 +618,17 @@ def test_fuse_mul_into_dequant(self) -> None:
         )
         builder.output([mul])
         original_graph = builder.get_graph_module()
+        gm_before = copy.deepcopy(original_graph)
+
         p = FuseMulTensorIntoDequantPass()
-        converted_graph = cast(PassResult, p(original_graph)).graph_module
+        result = cast(PassResult, p(original_graph))
+        self.assertTrue(result.modified)
+        converted_graph = result.graph_module
+
+        # Validate numerical accuracy
+        validate_numerics(
+            gm_before, converted_graph, (x_input,), "FuseMulTensorIntoDequantPass"
+        )
 
         # verify that the mul and full ops were removed
         self.check_op_counts(
@@ -645,7 +655,8 @@ def test_fuse_mul_scalar_into_dequant(self) -> None:
         mul_value = 0.3
 
         builder = GraphBuilder()
-        x = builder.placeholder("x", torch.randn(2, 3, 4, dtype=torch.float32))
+        x_input = torch.randn(2, 3, 4, dtype=torch.float32)
+        x = builder.placeholder("x", x_input)
         quant = builder.call_operator(
             op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
             args=(x, 1, 0, -128, 127, torch.int8),
@@ -660,8 +671,17 @@ def test_fuse_mul_scalar_into_dequant(self) -> None:
         )
         builder.output([mul_scalar])
         original_graph = builder.get_graph_module()
+        gm_before = copy.deepcopy(original_graph)
+
         p = FuseMulScalarIntoDequantPass()
-        converted_graph = cast(PassResult, p(original_graph)).graph_module
+        result = cast(PassResult, p(original_graph))
+        self.assertTrue(result.modified)
+        converted_graph = result.graph_module
+
+        # Validate numerical accuracy
+        validate_numerics(
+            gm_before, converted_graph, (x_input,), "FuseMulScalarIntoDequantPass"
+        )
 
         # verify that the mul and full ops were removed
         self.check_op_counts(
@@ -687,7 +707,8 @@ def test_fuse_mul_into_quant(self) -> None:
         mul_value = 10
 
         builder = GraphBuilder()
-        x = builder.placeholder("x", torch.randn(4, 32, dtype=torch.float32))
+        x_input = torch.randn(4, 32, dtype=torch.float32)
+        x = builder.placeholder("x", x_input)
         full = builder.call_operator(
             op=exir_ops.edge.aten.full.default,
             args=([1], mul_value),
@@ -702,8 +723,17 @@ def test_fuse_mul_into_quant(self) -> None:
         )
         builder.output([quant])
         original_graph = builder.get_graph_module()
+        gm_before = copy.deepcopy(original_graph)
+
         p = FuseMulTensorIntoQuantPass()
-        converted_graph = cast(PassResult, p(original_graph)).graph_module
+        result = cast(PassResult, p(original_graph))
+        self.assertTrue(result.modified)
+        converted_graph = result.graph_module
+
+        # Validate numerical accuracy
+        validate_numerics(
+            gm_before, converted_graph, (x_input,), "FuseMulTensorIntoQuantPass"
+        )
 
         # verify that the mul and full ops were removed
         self.check_op_counts(
@@ -723,12 +753,6 @@ def test_fuse_mul_into_quant(self) -> None:
             new_quant_scale = node.args[1]
         self.assertEqual(new_quant_scale, quant_scale / mul_value)
 
-        # verify the math is correct
-        inp = torch.randn(4, 32, dtype=torch.float32)
-        original_out = original_graph(inp)[0]
-        new_out = converted_graph(inp)[0]
-        assert torch.equal(original_out, new_out)
-
     def test_fuse_then_transpose_pass(self) -> None:
         # Create a graph with full -> transpose -> permute -> view.
         builder = GraphBuilder()
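
The new assertions call a validate_numerics helper whose definition is not part of this diff. A minimal sketch of what such a helper plausibly does is below; the signature, tolerance, and output handling are assumptions, not taken from the commit.

from typing import Tuple

import torch


def validate_numerics(
    gm_before: torch.fx.GraphModule,
    gm_after: torch.fx.GraphModule,
    inputs: Tuple[torch.Tensor, ...],
    pass_name: str,
) -> None:
    # Hypothetical sketch: run both graph modules on the same inputs
    # and compare each output pair.
    outs_before = gm_before(*inputs)
    outs_after = gm_after(*inputs)
    for before, after in zip(outs_before, outs_after):
        # Cast to float so integer (quantized) outputs can be compared too.
        diff = (before.float() - after.float()).abs().max().item()
        assert torch.allclose(before.float(), after.float()), (
            f"{pass_name} changed numerics: max diff {diff}"
        )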

backends/xnnpack/test/TARGETS

Lines changed: 22 additions & 0 deletions
@@ -27,6 +27,28 @@ runtime.python_test(
     ],
 )
 
+runtime.python_test(
+    name = "test_xnnpack_fragments",
+    srcs = glob([
+        "fragments/*.py",
+    ]) + [
+        "test_xnnpack_utils.py",
+    ],
+    deps = [
+        "//executorch/backends/xnnpack/partition:xnnpack_partitioner",
+        "//executorch/backends/xnnpack/quantizer:xnnpack_quantizer",
+        "//executorch/backends/xnnpack/test/tester:tester",
+        "//executorch/devtools:lib",
+        "//executorch/devtools/bundled_program:config",
+        "//executorch/devtools/bundled_program/serialize:lib",
+        "//executorch/exir/passes:constant_prop_pass",
+        "//pytorch/ao:torchao",  # @manual
+    ],
+    external_deps = [
+        "libtorch",
+    ],
+)
+
 runtime.python_test(
     name = "test_xnnpack_ops",
     srcs = glob([
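
Assuming the repository's usual Buck workflow, the new suite would be run the same way as the neighboring targets, e.g. "buck2 test //executorch/backends/xnnpack/test:test_xnnpack_fragments"; the invocation path is inferred from the target name and is not shown in this diff.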
