3333)
3434from executorch .backends .cadence .aot .quantizer .utils import (
3535 check_out_zero_point_is_min_range ,
36- copy_node_metadata ,
3736 create_zero_bias_int32 ,
3837 find_sequential_partitions_aten ,
3938 get_conv_args ,
@@ -160,8 +159,6 @@ def get_args_and_kwargs_layer_norm(
160159 ),
161160 {"dtype" : torch .float32 },
162161 )
163- if len (inputs_inputs ) > 0 :
164- copy_node_metadata (weight , inputs_inputs [0 ])
165162
166163 bias = other_inputs [2 ] if len (other_inputs ) > 2 else None
167164
@@ -174,8 +171,6 @@ def get_args_and_kwargs_layer_norm(
174171 ),
175172 {"dtype" : torch .float32 },
176173 )
177- if len (inputs_inputs ) > 0 :
178- copy_node_metadata (bias , inputs_inputs [0 ])
179174
180175 # Make the args and kwargs for the replacement op
181176 args = tuple (inputs_inputs + [scale , zero_point ])
@@ -351,8 +346,6 @@ def get_args_and_kwargs_softmax(
351346 ),
352347 {"dtype" : torch .int32 },
353348 )
354- if len (inputs_inputs ) > 0 :
355- copy_node_metadata (mask_tensor , inputs_inputs [0 ])
356349 # Make the scale and zero_point tensors
357350 in_scale = dequants_inputs [0 ].args [1 ]
358351 in_zero_point = dequants_inputs [0 ].args [2 ]
@@ -402,13 +395,10 @@ def get_args_and_kwargs_mixed_w8a32_conv(
402395 torch .ops .aten .permute .default ,
403396 (other_inputs [0 ], [0 , 2 , 1 ]), # NCL -> NLC
404397 )
405- copy_node_metadata (transposed_inputs , other_inputs [0 ])
406-
407398 transposed_weights = graph_module .graph .call_function (
408399 torch .ops .aten .permute .default ,
409400 (weights_inputs [0 ], [2 , 0 , 1 ]), # NCL -> LNC
410401 )
411- copy_node_metadata (transposed_weights , weights_inputs [0 ])
412402
413403 args = (
414404 transposed_inputs ,
@@ -592,26 +582,6 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
592582 torch .ops .aten .transpose .int ,
593583 (weights_inputs [0 ], 0 , 1 ),
594584 )
595- if "val" in weights_inputs [0 ].meta :
596- original_val = weights_inputs [0 ].meta ["val" ]
597- fake_mode = original_val .fake_mode
598- if fake_mode is not None :
599- with fake_mode :
600- transposed_val = torch .ops .aten .transpose .int (
601- original_val , 0 , 1
602- )
603- transposed_weights .meta ["val" ] = transposed_val
604- else :
605- transposed_shape = list (original_val .shape )
606- transposed_shape [0 ], transposed_shape [1 ] = (
607- transposed_shape [1 ],
608- transposed_shape [0 ],
609- )
610- transposed_weights .meta ["val" ] = torch .zeros (
611- transposed_shape , dtype = original_val .dtype
612- )
613- copy_node_metadata (transposed_weights , weights_inputs [0 ])
614-
615585 # Call linear with transposed weight
616586 args , kwargs = get_args_and_kwargs_linear (
617587 graph_module ,
@@ -684,19 +654,6 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
684654
685655 legalize_graph (graph_module )
686656 graph_module .graph .eliminate_dead_code ()
687- nodes_list = list (graph_module .graph .nodes )
688-
689- if len (nodes_list ) > 0 and nodes_list [- 1 ].op != "output" :
690- output_nodes = [n for n in nodes_list if n .op == "output" ]
691- output_arg = output_nodes [0 ].args [0 ]
692- original_meta = output_nodes [0 ].meta .copy ()
693-
694- for out_node in output_nodes :
695- graph_module .graph .erase_node (out_node )
696-
697- new_output_node = graph_module .graph .output (output_arg )
698- new_output_node .meta .update (original_meta )
699-
700657 graph_module .recompile ()
701658 return PassResult (graph_module , True )
702659
0 commit comments