@@ -818,30 +818,32 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
-class FuseMulScalarIntoDequantPass(ExportPass):
+class FuseMulScalarIntoDequantPass(RemoveOrReplacePassInterface):
     """
     Looks for the pattern where aten.mul.Scalar is multiplying the
     outputs of dequantize. If found, updates the dequant scale
     to reflect the multiplication and removes the mul node.
     """
 
-    def attempt_fusion(
-        self, graph_module: torch.fx.GraphModule, node: torch.fx.Node
-    ) -> None:
-        if node.target not in {
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [
             exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
             exir_ops.edge.cadence.dequantize_per_tensor.default,
-        }:
-            return
+        ]
+
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
+        # Ensure that the single user of dequant is aten.mul.Scalar
+        if len(node.users) != 1:
+            return False
 
-        # ensure that the single user of dequant is aten.mul.Scalar
         user = list(node.users.keys())[0]
-        if len(node.users) != 1 or user.target != exir_ops.edge.aten.mul.Scalar:
-            return
+        if user.target != exir_ops.edge.aten.mul.Scalar:
+            return False
 
-        # ensure that the other arg to mul is a node (i.e. not a constant)
+        # Ensure that the other arg to mul is not a node (i.e. it's a constant)
         if len(user.args) > 1 and isinstance(user.args[1], torch.fx.Node):
-            return
+            return False
 
         new_deq_args = list(node.args)
         assert isinstance(node.args[1], Number)
@@ -853,36 +855,36 @@ def attempt_fusion(
             f"Fused {node} and {user} into {node}. Updated scale from {node.args[1]} to {new_deq_args[1]}"
         )
 
+        # Replace all uses of mul with the dequant node
         user.replace_all_uses_with(node)
+        # Update the dequant node's args with the new scale
         node.args = tuple(new_deq_args)
 
-        graph_module.graph.erase_node(user)
-
-        graph_module.recompile()
+        # Erase the mul node
+        node.graph.erase_node(user)
 
-    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
-        for node in graph_module.graph.nodes:
-            self.attempt_fusion(graph_module, node)
-        result = super().call(graph_module)
-        return result
+        return True
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
-class FuseMulTensorIntoQuantPass(ExportPass):
+class FuseMulTensorIntoQuantPass(RemoveOrReplacePassInterface):
     """
     Looks for the pattern where aten.mul.Tensor is followed by quant node.
     If found, updates the quant scale to reflect the multiplication and
     removes the mul node.
     """
 
-    def attempt_fusion(
-        self, graph_module: torch.fx.GraphModule, mul_node: torch.fx.Node
-    ) -> None:
-        if len(mul_node.args) != 2 or len(mul_node.users) != 1:
-            return
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [exir_ops.edge.aten.mul.Tensor]
 
-        first_arg = cast(torch.fx.Node, mul_node.args[0])
-        second_arg = cast(torch.fx.Node, mul_node.args[1])
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
+        # Check that mul has exactly 2 args and 1 user
+        if len(node.args) != 2 or len(node.users) != 1:
+            return False
+
+        first_arg = cast(torch.fx.Node, node.args[0])
+        second_arg = cast(torch.fx.Node, node.args[1])
 
         input_node = first_arg
         full_node = second_arg
@@ -895,20 +897,20 @@ def attempt_fusion(
             input_node = second_arg
         else:
             # Full node is not found, skip.
-            return
+            return False
 
         # Ensure that the mul op does not do any broadcasting.
-        if input_node.meta["val"].shape != mul_node.meta["val"].shape:
-            return
+        if input_node.meta["val"].shape != node.meta["val"].shape:
+            return False
 
-        mul_user = list(mul_node.users.keys())[0]
+        mul_user = list(node.users.keys())[0]
 
         # Ensure only the expected quant ops are using the current mul op.
         if mul_user.target not in {
             exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
             exir_ops.edge.cadence.quantize_per_tensor.default,
         }:
-            return
+            return False
 
         quant_node = mul_user
 
@@ -927,39 +929,32 @@ def attempt_fusion(
         new_scale = float(old_scale) / float(mul_scalar)
 
         logging.debug(
-            f"Fused {mul_node} and {full_node} into {quant_node}. Updated scale from {quant_node.args[1]} to {new_scale}"
+            f"Fused {node} and {full_node} into {quant_node}. Updated scale from {quant_node.args[1]} to {new_scale}"
         )
 
         # Update quant node input and scale.
         old_quant_input = cast(torch.fx.Node, quant_node.args[0])
-        new_quant_input = cast(torch.fx.Node, mul_node.args[0])
+        new_quant_input = input_node
         quant_node.replace_input_with(old_quant_input, new_quant_input)
         quant_node.update_arg(1, new_scale)
 
-    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
-        for node in graph_module.graph.find_nodes(
-            op="call_function", target=exir_ops.edge.aten.mul.Tensor
-        ):
-            self.attempt_fusion(graph_module, node)
-        graph_module.graph.eliminate_dead_code()
-        return super().call(graph_module)
+        return True
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
-class FuseMulTensorIntoDequantPass(ExportPass):
+class FuseMulTensorIntoDequantPass(RemoveOrReplacePassInterface):
     """
     Looks for the pattern where aten.mul is multiplying the outputs of dequantize
     and aten.full, or vice versa. If found, updates the dequant scale to reflect
     the multiplication and removes the full and mul nodes.
     """
 
-    def attempt_fusion(
-        self, graph_module: torch.fx.GraphModule, node: torch.fx.Node
-    ) -> None:
-        if node.target != exir_ops.edge.aten.mul.Tensor:
-            return
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [exir_ops.edge.aten.mul.Tensor]
 
-        # ensure that one of the args to mul is dequantize and the other is aten.full
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
+        # Ensure that one of the args to mul is dequantize and the other is aten.full
         dequant_nodes = [
             arg
             for arg in node.args
@@ -979,14 +974,14 @@ def attempt_fusion(
         ]
 
         if len(dequant_nodes) != 1 or len(multiplier_nodes) != 1:
-            return
+            return False
 
         deq_node = dequant_nodes[0]
         mplier_node = multiplier_nodes[0]
 
-        # ensure that dequant and full don't have any other users
+        # Ensure that dequant and full don't have any other users
        if len(deq_node.users) > 1 or len(mplier_node.users) > 1:
-            return
+            return False
 
         new_deq_args = list(deq_node.args)
         assert isinstance(deq_node.args[1], Number)
@@ -998,18 +993,16 @@ def attempt_fusion(
             f"Fused {node} and {mplier_node} into {deq_node}. Updated scale from {deq_node.args[1]} to {new_deq_args[1]}"
         )
 
+        # Replace all uses of the mul node with the dequant node
         node.replace_all_uses_with(deq_node)
+        # Update the dequant node's args with the new scale
         deq_node.args = tuple(new_deq_args)
 
-        graph_module.graph.erase_node(node)
-        graph_module.graph.erase_node(mplier_node)
-        graph_module.recompile()
+        # Erase the mul and full nodes
+        node.graph.erase_node(node)
+        node.graph.erase_node(mplier_node)
 
-    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
-        for node in graph_module.graph.nodes:
-            self.attempt_fusion(graph_module, node)
-        result = super().call(graph_module)
-        return result
+        return True
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
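
Note on why the dequant-side fusions above are sound: multiplying the output of a dequantize by a constant is the same as scaling the dequant scale by that constant, which is exactly what `FuseMulScalarIntoDequantPass` and `FuseMulTensorIntoDequantPass` fold into the node args. A minimal numerical sketch, using a hand-rolled `dequantize` helper as a stand-in for the per-tensor dequant op (the helper and its affine semantics are assumptions for illustration, not the actual edge op):

```python
import torch

# Illustrative stand-in for dequantize_per_tensor semantics (assumed affine
# dequant: (q - zero_point) * scale).
def dequantize(q: torch.Tensor, scale: float, zero_point: int) -> torch.Tensor:
    return (q.to(torch.float32) - zero_point) * scale

q = torch.randint(-128, 128, (4, 8), dtype=torch.int8)
scale, zero_point, c = 0.05, 3, 2.5

# mul after dequant folds into the dequant scale:
#   dequantize(q, scale) * c == dequantize(q, scale * c)
assert torch.allclose(
    dequantize(q, scale, zero_point) * c,
    dequantize(q, scale * c, zero_point),
)
```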
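Similarly, for `FuseMulTensorIntoQuantPass`: multiplying the quantizer's input by a constant is equivalent to dividing the quant scale by that constant, hence `new_scale = float(old_scale) / float(mul_scalar)` above. Another small sketch under the same assumptions (hand-rolled `quantize` helper, not the actual edge op):

```python
import torch

# Illustrative stand-in for quantize_per_tensor semantics (assumed
# clamp(round(x / scale) + zero_point, qmin, qmax)).
def quantize(x, scale, zero_point, qmin=-128, qmax=127):
    return torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax).to(torch.int8)

x = torch.randn(4, 8)
scale, zero_point, c = 0.1, 0, 2.5

# quantize(x * c, scale) == quantize(x, scale / c), since (x * c) / scale
# equals x / (scale / c) up to floating-point rounding.
lhs = quantize(x * c, scale, zero_point)
rhs = quantize(x, scale / c, zero_point)
print("mismatching elements:", (lhs != rhs).sum().item())  # 0 barring rare rounding ties
```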