@@ -862,6 +862,73 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
862862 result = super ().call (graph_module )
863863 return result
864864
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class FuseMulTensorIntoQuantPass(ExportPass):
    """
    Looks for the pattern where aten.mul.Tensor by a constant (an
    aten.full node) is followed by a quant node. If found, folds the constant
    into the quant scale and removes the mul node.

    quantize_per_tensor computes ``q = zp + round(y / scale)``. With
    ``y = x * c`` this equals ``zp + round(x / (scale / c))``, so the fused
    quant node must use ``scale / c`` as its scale.
    """

    def attempt_fusion(
        self, graph_module: torch.fx.GraphModule, mul_node: torch.fx.Node
    ) -> None:
        """
        Fuse ``mul_node`` into the quant node that consumes it, if possible.

        The fusion is applied only when all of the following hold; otherwise
        the graph is left untouched:
          * ``mul_node`` is an aten.mul.Tensor with exactly one aten.full
            operand and exactly one user,
          * that user is a quantize_per_tensor op taking ``mul_node`` as its
            first argument,
          * the quant scale and the fill value are plain int/float constants
            and the fill value is non-zero,
          * the mul does not change the shape of its non-constant operand
            (checked only when shape metadata is available).

        Args:
            graph_module: Graph module owning ``mul_node``; modified in place
                and recompiled when a fusion happens.
            mul_node: Candidate node to fuse.
        """
        if mul_node.target != exir_ops.edge.aten.mul.Tensor:
            return

        full_nodes = [
            arg
            for arg in mul_node.args
            if isinstance(arg, torch.fx.Node)
            and arg.target == exir_ops.edge.aten.full.default
        ]

        if len(full_nodes) != 1 or len(mul_node.users) != 1:
            return

        full_node = full_nodes[0]

        # The constant may be either operand of the mul; the tensor being
        # scaled is whichever operand is left over.
        other_args = [arg for arg in mul_node.args if arg is not full_node]
        if len(other_args) != 1 or not isinstance(other_args[0], torch.fx.Node):
            return
        input_node = other_args[0]

        # Removing the mul is only shape-preserving if it did not broadcast
        # the input. NOTE(review): relies on meta["val"] carrying the traced
        # tensor; when the metadata is missing the check is skipped.
        input_shape = getattr(input_node.meta.get("val"), "shape", None)
        output_shape = getattr(mul_node.meta.get("val"), "shape", None)
        if (
            input_shape is not None
            and output_shape is not None
            and input_shape != output_shape
        ):
            return

        quant_node = next(iter(mul_node.users))

        if quant_node.target not in {
            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
            exir_ops.edge.cadence.quantize_per_tensor.default,
        }:
            return

        # Only fuse when the mul result is the tensor being quantized and the
        # scale is available positionally.
        if len(quant_node.args) < 2 or quant_node.args[0] is not mul_node:
            return

        old_scale = quant_node.args[1]
        mul_scalar = full_node.args[1] if len(full_node.args) > 1 else None
        if not isinstance(old_scale, (int, float)) or not isinstance(
            mul_scalar, (int, float)
        ):
            return
        # A zero multiplier cannot be expressed as a finite scale.
        if mul_scalar == 0:
            return

        # y = x * c; q = zp + y / s  =>  q = zp + x / (s / c)
        new_scale = old_scale / mul_scalar

        logging.debug(
            f"Fused {mul_node} and {full_node} into {quant_node} . Updated scale from {old_scale} to {new_scale} "
        )

        # Bypass the mul: the quant node now reads the mul's tensor input.
        quant_node.replace_input_with(mul_node, input_node)

        # Now update the scale in the args
        new_quant_args = list(quant_node.args)
        new_quant_args[1] = new_scale
        quant_node.args = tuple(new_quant_args)

        # The mul has no users left; erasing it also drops its use of the
        # full node. The full node is erased only if nothing else reads it.
        graph_module.graph.erase_node(mul_node)
        if not full_node.users:
            graph_module.graph.erase_node(full_node)
        graph_module.recompile()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Try the fusion on every node, then run the base ExportPass."""
        # Iterate over a snapshot because attempt_fusion erases nodes.
        for node in list(graph_module.graph.nodes):
            self.attempt_fusion(graph_module, node)
        result = super().call(graph_module)
        return result
931+
865932
866933@register_cadence_pass (CadencePassAttribute (opt_level = 1 ))
867934class FuseMulTensorIntoDequantPass (ExportPass ):
0 commit comments