@@ -819,68 +819,76 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
819819
820820
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class FuseMulScalarIntoDequantPass(RemoveOrReplacePassInterface):
    """
    Looks for the pattern where aten.mul.Scalar is multiplying the
    outputs of dequantize. If found, updates the dequant scale
    to reflect the multiplication and removes the mul node.
    """

    @property
    def targets(self) -> list[EdgeOpOverload]:
        # Match on the mul node; its dequantize producer is validated in
        # maybe_remove_or_replace below.
        return [exir_ops.edge.aten.mul.Scalar]

    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
        """Fold a mul.Scalar into the scale of its producing dequantize node.

        Returns True iff the graph was mutated (the mul node erased and the
        dequant scale updated); False when the pattern does not match.
        """
        mul_node = node

        # The mul must have exactly one tensor input, and that input must
        # feed only this mul — otherwise other users of the dequant would
        # observe the rescaled output.
        if len(mul_node.all_input_nodes) != 1 or len(mul_node.all_input_nodes[0].users) != 1:
            return False

        dequant_node = mul_node.all_input_nodes[0]

        # The producer must actually be a dequantize op. Without this guard,
        # any single-use producer feeding a mul.Scalar would have its second
        # argument clobbered as if it were a dequant scale.
        if dequant_node.target not in {
            exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
            exir_ops.edge.cadence.dequantize_per_tensor.default,
        }:
            return False

        old_scale = dequant_node.args[1]
        assert isinstance(old_scale, Number)
        assert isinstance(mul_node.args[1], Number)

        # Scaling the dequant output by k is equivalent to scaling its
        # scale by k: (x - zp) * (scale * k) == ((x - zp) * scale) * k.
        new_deq_args = list(dequant_node.args)
        # pyre-ignore[58]: Unsupported operand *
        new_deq_args[1] = old_scale * mul_node.args[1]

        # Log before mutating so the "from" scale is the pre-fusion value
        # and the mul node has not yet been erased.
        logging.debug(
            f"Fused {dequant_node} and {mul_node} into {dequant_node}. "
            f"Updated scale from {old_scale} to {new_deq_args[1]}"
        )

        # Rewire consumers of the mul to the dequant node, fold the
        # multiplier into the dequant scale, then drop the now-dead mul.
        mul_node.replace_all_uses_with(dequant_node)
        dequant_node.args = tuple(new_deq_args)
        mul_node.graph.erase_node(mul_node)
        return True
869859
870860
871861@register_cadence_pass (CadencePassAttribute (opt_level = 1 ))
872- class FuseMulTensorIntoQuantPass (ExportPass ):
862+ class FuseMulTensorIntoQuantPass (RemoveOrReplacePassInterface ):
873863 """
874864 Looks for the pattern where aten.mul.Tensor is followed by quant node.
875865 If found, updates the quant scale to reflect the multiplication and
876866 removes the mul node.
877867 """
878868
879- def attempt_fusion (
880- self , graph_module : torch .fx .GraphModule , mul_node : torch .fx .Node
881- ) -> None :
882- if len (mul_node .args ) != 2 or len (mul_node .users ) != 1 :
883- return
869+ @property
870+ def targets (self ) -> list [EdgeOpOverload ]:
871+ return [exir_ops .edge .aten .mul .Tensor ]
872+ # return [exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, exir_ops.edge.cadence.quantize_per_tensor.default]
873+
874+ def maybe_remove_or_replace (self , node : torch .fx .Node ) -> bool :
875+
876+ mul_node = node
877+ if len (mul_node .users ) != 1 :
878+ return False
879+
880+ user = next (iter (mul_node .users ))
881+ if len (user .all_input_nodes ) != 1 :
882+ return False
883+
884+ if user .target not in [
885+ exir_ops .edge .quantized_decomposed .quantize_per_tensor .default ,
886+ exir_ops .edge .cadence .quantize_per_tensor .default ,
887+ ]:
888+ return False
889+
890+ # Alias for readability.
891+ quant_node = user
884892
885893 first_arg = cast (torch .fx .Node , mul_node .args [0 ])
886894 second_arg = cast (torch .fx .Node , mul_node .args [1 ])
@@ -896,22 +904,11 @@ def attempt_fusion(
896904 input_node = second_arg
897905 else :
898906 # Full node is not found, skip.
899- return
907+ return False
900908
901909 # Ensure that the mul op does not do any broadcasting.
902- if input_node .meta ["val" ].shape != mul_node .meta ["val" ].shape :
903- return
904-
905- mul_user = list (mul_node .users .keys ())[0 ]
906-
907- # Ensure only the expected quant ops are using the current mul op.
908- if mul_user .target not in {
909- exir_ops .edge .quantized_decomposed .quantize_per_tensor .default ,
910- exir_ops .edge .cadence .quantize_per_tensor .default ,
911- }:
912- return
913-
914- quant_node = mul_user
910+ if input_node .meta ["val" ].shape != node .meta ["val" ].shape :
911+ return False
915912
916913 # Calculate the new scale value.
917914 old_scale = quant_node .args [1 ]
@@ -925,42 +922,41 @@ def attempt_fusion(
925922 new_scale = old_scale / mul_scalar
926923 q = zp + x / new_scale
927924 """
925+
926+ # Cannot fuse if either value is zero:
927+ # - mul_scalar == 0 would cause division by zero computing new_scale
928+ # - old_scale == 0 would result in new_scale = 0, causing division by zero during quantization
929+ if mul_scalar == 0 or old_scale == 0 :
930+ return False
928931 new_scale = float (old_scale ) / float (mul_scalar )
929932
930933 logging .debug (
931- f"Fused { mul_node } and { full_node } into { quant_node } . Updated scale from { quant_node .args [1 ]} to { new_scale } "
934+ f"Fused { node } and { full_node } into { quant_node } . Updated scale from { quant_node .args [1 ]} to { new_scale } "
932935 )
933936
934937 # Update quant node input and scale.
935938 old_quant_input = cast (torch .fx .Node , quant_node .args [0 ])
936- new_quant_input = cast ( torch . fx . Node , mul_node . args [ 0 ])
939+ new_quant_input = input_node
937940 quant_node .replace_input_with (old_quant_input , new_quant_input )
938941 quant_node .update_arg (1 , new_scale )
939942
940- def call (self , graph_module : torch .fx .GraphModule ) -> PassResult :
941- for node in graph_module .graph .find_nodes (
942- op = "call_function" , target = exir_ops .edge .aten .mul .Tensor
943- ):
944- self .attempt_fusion (graph_module , node )
945- graph_module .graph .eliminate_dead_code ()
946- return super ().call (graph_module )
943+ return True
947944
948945
949946@register_cadence_pass (CadencePassAttribute (opt_level = 1 ))
950- class FuseMulTensorIntoDequantPass (ExportPass ):
947+ class FuseMulTensorIntoDequantPass (RemoveOrReplacePassInterface ):
951948 """
952949 Looks for the pattern where aten.mul is multiplying the outputs of dequantize
953950 and aten.full, or vice versa. If found, updates the dequant scale to reflect
954951 the multiplication and removes the full and mul nodes.
955952 """
956953
957- def attempt_fusion (
958- self , graph_module : torch .fx .GraphModule , node : torch .fx .Node
959- ) -> None :
960- if node .target != exir_ops .edge .aten .mul .Tensor :
961- return
954+ @property
955+ def targets (self ) -> list [EdgeOpOverload ]:
956+ return [exir_ops .edge .aten .mul .Tensor ]
962957
963- # ensure that one of the args to mul is dequantize and the other is aten.full
958+ def maybe_remove_or_replace (self , node : torch .fx .Node ) -> bool :
959+ # Ensure that one of the args to mul is dequantize and the other is aten.full
964960 dequant_nodes = [
965961 arg
966962 for arg in node .args
@@ -980,14 +976,14 @@ def attempt_fusion(
980976 ]
981977
982978 if len (dequant_nodes ) != 1 or len (multiplier_nodes ) != 1 :
983- return
979+ return False
984980
985981 deq_node = dequant_nodes [0 ]
986982 mplier_node = multiplier_nodes [0 ]
987983
988- # ensure that dequant and full don't have any other users
984+ # Ensure that dequant and full don't have any other users
989985 if len (deq_node .users ) > 1 or len (mplier_node .users ) > 1 :
990- return
986+ return False
991987
992988 new_deq_args = list (deq_node .args )
993989 assert isinstance (deq_node .args [1 ], Number )
@@ -999,18 +995,16 @@ def attempt_fusion(
999995 f"Fused { node } and { mplier_node } into { deq_node } . Updated scale from { deq_node .args [1 ]} to { new_deq_args [1 ]} "
1000996 )
1001997
998+ # Replace all uses of the mul node with the dequant node
1002999 node .replace_all_uses_with (deq_node )
1000+ # Update the dequant node's args with the new scale
10031001 deq_node .args = tuple (new_deq_args )
10041002
1005- graph_module . graph . erase_node ( node )
1006- graph_module .graph .erase_node (mplier_node )
1007- graph_module . recompile ( )
1003+ # Erase the mul and full nodes
1004+ node .graph .erase_node (node )
1005+ node . graph . erase_node ( mplier_node )
10081006
1009- def call (self , graph_module : torch .fx .GraphModule ) -> PassResult :
1010- for node in graph_module .graph .nodes :
1011- self .attempt_fusion (graph_module , node )
1012- result = super ().call (graph_module )
1013- return result
1007+ return True
10141008
10151009
10161010@register_cadence_pass (CadencePassAttribute (opt_level = 1 ))
0 commit comments