@@ -24,12 +24,22 @@ def __init__(self) -> None:
2424 def call (self , graph_module : torch .fx .GraphModule ) -> PassResult :
2525 graph = graph_module .graph
2626 for node in graph .nodes :
27- if node .target in decomp_set and "alpha" in node .kwargs :
27+ if (
28+ node .target in decomp_set
29+ and "alpha" in node .kwargs
30+ and node .kwargs ["alpha" ] != 1
31+ ):
2832 alpha = node .kwargs ["alpha" ]
2933 # Remove alpha from immutable dict
3034 node .kwargs = {k : v for k , v in node .kwargs .items () if k != "alpha" }
35+ input2_node = node .args [1 ]
36+ # If input2 is constant, we can just multiply the value for optimization
37+ if isinstance (input2_node , (int , float )):
38+ arg_list = list (node .args )
39+ arg_list [1 ] = input2_node * alpha
40+ node .args = tuple (arg_list )
41+ continue
3142 with graph .inserting_before (node ):
32- input2_node = node .args [1 ]
3343 mul_op = torch .ops .aten .mul .Scalar
3444 mul_node = graph .create_node (
3545 "call_function" ,
@@ -40,7 +50,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
4050 ),
4151 )
4252 mul_node .meta = copy_meta (node .meta )
43- mul_node . users = { node : None }
53+ node . replace_input_with ( input2_node , mul_node )
4454 node .args = (
4555 node .args [0 ],
4656 mul_node ,
0 commit comments