@@ -856,19 +856,23 @@ class FuseMulTensorIntoQuantPass(ExportPass):
856856 def attempt_fusion (
857857 self , graph_module : torch .fx .GraphModule , mul_node : torch .fx .Node
858858 ) -> None :
859- full_nodes = [
860- arg
861- for arg in mul_node .args
862- if isinstance (arg , torch .fx .Node )
863- and arg .target == exir_ops .edge .aten .full .default
864- ]
859+ if len (mul_node .args ) != 2 or len (mul_node .users ) != 1 :
860+ return
861+
862+ second_arg = cast (torch .fx .Node , mul_node .args [1 ])
863+ input_index = 0 if second_arg .target == exir_ops .edge .aten .full .default else 1
865864
866- if len (full_nodes ) != 1 or len (mul_node .users ) != 1 :
865+ input_node = cast (torch .fx .Node , mul_node .args [input_index ])
866+ full_node = cast (torch .fx .Node , mul_node .args [1 - input_index ])
867+ output_node = list (mul_node .users .keys ())[0 ]
868+
869+ # Ensure that the mul op does not do any broadcasting.
870+ if input_node .meta ["val" ].shape != output_node .meta ["val" ].shape :
867871 return
868872
869- full_node = full_nodes [0 ]
870873 mul_user = list (mul_node .users .keys ())[0 ]
871874
875+ # Ensure only the expected quant ops are using the current mul op.
872876 if mul_user .target not in {
873877 exir_ops .edge .quantized_decomposed .quantize_per_tensor .default ,
874878 exir_ops .edge .cadence .quantize_per_tensor .default ,
@@ -878,33 +882,27 @@ def attempt_fusion(
878882 quant_node = mul_user
879883
880884 # Calculate the new scale value.
881- prev_scale = quant_node .args [1 ]
882- assert isinstance (prev_scale , (int , float ))
885+ old_scale = quant_node .args [1 ]
886+ assert isinstance (old_scale , (int , float ))
883887 mul_scalar = full_node .args [1 ]
884888 assert isinstance (mul_scalar , (int , float ))
885- new_scale = float (prev_scale ) * float (mul_scalar )
889+ # Why the new scale is the old scale divided by the mul value:
890+ # y = x * mul_scalar
891+ # q = zp + y / old_scale
892+ # q = zp + x * mul_scalar / old_scale
893+ # new_scale = old_scale / mul_scalar
894+ # q = zp + x / new_scale
895+ new_scale = float (old_scale ) / float (mul_scalar )
886896
887897 logging .debug (
888898 f"Fused { mul_node } and { full_node } into { quant_node } . Updated scale from { quant_node .args [1 ]} to { new_scale } "
889899 )
890900
891- # Replace the input first
892- quant_node .replace_input_with (
893- cast (torch .fx .Node , quant_node .args [0 ]),
894- cast (torch .fx .Node , mul_node .args [0 ]),
895- )
896-
897- # Now update the scale in the args
898- new_quant_args = list (quant_node .args )
899- new_quant_args [1 ] = new_scale
900- quant_node .args = tuple (new_quant_args )
901-
902- # Clean up the mul_node
903- mul_node .args = ()
904- mul_node .users = {}
905-
906- graph_module .graph .erase_node (mul_node )
907- graph_module .graph .erase_node (full_node )
901+ # Update quant node input and scale.
902+ old_quant_input = cast (torch .fx .Node , quant_node .args [0 ])
903+ new_quant_input = cast (torch .fx .Node , mul_node .args [0 ])
904+ quant_node .replace_input_with (old_quant_input , new_quant_input )
905+ quant_node .update_arg (1 , new_scale )
908906
909907 def call (self , graph_module : torch .fx .GraphModule ) -> PassResult :
910908 for node in graph_module .graph .find_nodes (
0 commit comments