@@ -856,19 +856,32 @@ class FuseMulTensorIntoQuantPass(ExportPass):
def attempt_fusion(
    self, graph_module: torch.fx.GraphModule, mul_node: torch.fx.Node
) -> None:
    """Try to fold a multiply-by-constant into the quantize op that consumes it.

    The fusion is applied only when every condition below holds; otherwise
    the graph is left untouched and the method returns silently:

    * ``mul_node`` has exactly two args and exactly one user;
    * both args are ``torch.fx.Node`` objects and one of them is an
      ``aten.full`` node (in either operand position);
    * the non-constant operand already has the output shape of the mul
      (i.e. the mul does not broadcast it);
    * the single user is one of the supported ``quantize_per_tensor`` ops;
    * the quant scale and the ``full`` fill value are plain numbers and the
      fill value is non-zero.

    On success the quantize node is rewired to read the mul's non-constant
    operand directly and its scale (arg 1) is replaced by
    ``old_scale / mul_scalar``.

    Args:
        graph_module: Graph module that owns ``mul_node``. Not used by this
            method; kept so the signature stays unchanged for callers.
        mul_node: Candidate multiply node to fuse.

    Returns:
        None. The graph is mutated in place when the fusion applies.
    """
    if len(mul_node.args) != 2 or len(mul_node.users) != 1:
        return

    first_arg, second_arg = mul_node.args
    # typing.cast performs no runtime check, so verify explicitly that both
    # operands are graph nodes before touching `.target` / `.meta`.
    if not isinstance(first_arg, torch.fx.Node) or not isinstance(
        second_arg, torch.fx.Node
    ):
        return

    full_target = exir_ops.edge.aten.full.default
    if second_arg.target == full_target:
        # Most common case: mul(input, full).
        input_node, full_node = first_arg, second_arg
    elif first_arg.target == full_target:
        # Operands are swapped: mul(full, input).
        input_node, full_node = second_arg, first_arg
    else:
        # No full node among the operands, nothing to fuse.
        return

    # Ensure that the mul op does not do any broadcasting.
    if input_node.meta["val"].shape != mul_node.meta["val"].shape:
        return

    mul_user = next(iter(mul_node.users))

    # Ensure only the expected quant ops are using the current mul op.
    if mul_user.target not in {
        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
        exir_ops.edge.cadence.quantize_per_tensor.default,
    }:
        return

    quant_node = mul_user

    # Calculate the new scale value. Skip the fusion (rather than assert,
    # which is stripped under -O) when either value is not a plain number.
    old_scale = quant_node.args[1]
    if not isinstance(old_scale, (int, float)):
        return
    mul_scalar = full_node.args[1]
    if not isinstance(mul_scalar, (int, float)):
        return
    if mul_scalar == 0:
        # old_scale / 0 is undefined; leave the graph as it is.
        return

    # Why the old scale is divided by the mul value:
    #   y = x * mul_scalar
    #   q = zp + y / old_scale
    #   q = zp + x * mul_scalar / old_scale
    #   new_scale = old_scale / mul_scalar
    #   q = zp + x / new_scale
    new_scale = float(old_scale) / float(mul_scalar)

    logging.debug(
        "Fused %s and %s into %s. Updated scale from %s to %s",
        mul_node,
        full_node,
        quant_node,
        old_scale,
        new_scale,
    )

    # Update quant node input and scale. The new input must be the
    # non-constant operand (`input_node`), not blindly `mul_node.args[0]`,
    # which is the full node when the operands are swapped.
    old_quant_input = cast(torch.fx.Node, quant_node.args[0])
    quant_node.replace_input_with(old_quant_input, input_node)
    quant_node.update_arg(1, new_scale)
    # NOTE(review): mul_node and full_node are not erased here; presumably
    # they are removed by dead-code elimination elsewhere - confirm.
908916
909917 def call (self , graph_module : torch .fx .GraphModule ) -> PassResult :
910918 for node in graph_module .graph .find_nodes (
0 commit comments