@@ -856,19 +856,25 @@ class FuseMulTensorIntoQuantPass(ExportPass):
856856 def attempt_fusion (
857857 self , graph_module : torch .fx .GraphModule , mul_node : torch .fx .Node
858858 ) -> None :
859- full_nodes = [
860- arg
861- for arg in mul_node .args
862- if isinstance (arg , torch .fx .Node )
863- and arg .target == exir_ops .edge .aten .full .default
864- ]
859+ # import fbvscode
860+ # fbvscode.set_trace()
861+ if len (mul_node .args ) != 2 or len (mul_node .users ) != 1 :
862+ return
863+
864+ second_arg = cast (torch .fx .Node , mul_node .args [1 ])
865+ input_index = 0 if second_arg .target == exir_ops .edge .aten .full .default else 1
865866
866- if len (full_nodes ) != 1 or len (mul_node .users ) != 1 :
867+ input_node = cast (torch .fx .Node , mul_node .args [input_index ])
868+ full_node = cast (torch .fx .Node , mul_node .args [1 - input_index ])
869+ output_node = list (mul_node .users .keys ())[0 ]
870+
871+ # Ensure that the mul op does not do any broadcasting.
872+ if input_node .meta ["val" ].shape != output_node .meta ["val" ].shape :
867873 return
868874
869- full_node = full_nodes [0 ]
870875 mul_user = list (mul_node .users .keys ())[0 ]
871876
877+ # Ensure only the expected quant ops are using the current mul op.
872878 if mul_user .target not in {
873879 exir_ops .edge .quantized_decomposed .quantize_per_tensor .default ,
874880 exir_ops .edge .cadence .quantize_per_tensor .default ,
@@ -878,33 +884,27 @@ def attempt_fusion(
878884 quant_node = mul_user
879885
880886 # Calculate the new scale value.
881- prev_scale = quant_node .args [1 ]
882- assert isinstance (prev_scale , (int , float ))
887+ old_scale = quant_node .args [1 ]
888+ assert isinstance (old_scale , (int , float ))
883889 mul_scalar = full_node .args [1 ]
884890 assert isinstance (mul_scalar , (int , float ))
885- new_scale = float (prev_scale ) * float (mul_scalar )
891+ # The reason why we divide old scale by the mul value to get a new scale:
892+ # y = x * mul_scalar
893+ # q = zp + y / old_scale
894+ # q = zp + x * mul_scalar / old_scale
895+ # new_scale = old_scale / mul_scalar
896+ # q = zp + x / new_scale
897+ new_scale = float (old_scale ) / float (mul_scalar )
886898
887899 logging .debug (
888900 f"Fused { mul_node } and { full_node } into { quant_node } . Updated scale from { quant_node .args [1 ]} to { new_scale } "
889901 )
890902
891- # Replace the input first
892- quant_node .replace_input_with (
893- cast (torch .fx .Node , quant_node .args [0 ]),
894- cast (torch .fx .Node , mul_node .args [0 ]),
895- )
896-
897- # Now update the scale in the args
898- new_quant_args = list (quant_node .args )
899- new_quant_args [1 ] = new_scale
900- quant_node .args = tuple (new_quant_args )
901-
902- # Clean up the mul_node
903- mul_node .args = ()
904- mul_node .users = {}
905-
906- graph_module .graph .erase_node (mul_node )
907- graph_module .graph .erase_node (full_node )
903+ # Update quant node input and scale.
904+ old_quant_input = cast (torch .fx .Node , quant_node .args [0 ])
905+ new_quant_input = cast (torch .fx .Node , mul_node .args [0 ])
906+ quant_node .replace_input_with (old_quant_input , new_quant_input )
907+ quant_node .update_arg (1 , new_scale )
908908
909909 def call (self , graph_module : torch .fx .GraphModule ) -> PassResult :
910910 for node in graph_module .graph .find_nodes (
0 commit comments