@@ -160,6 +160,20 @@ def get_args_and_kwargs_layer_norm(
160160 ),
161161 {"dtype" : torch .float32 },
162162 )
163+ if len (inputs_inputs ) > 0 :
164+ if "val" in inputs_inputs [0 ].meta :
165+ fake_mode = inputs_inputs [0 ].meta ["val" ].fake_mode
166+ if fake_mode is not None :
167+ with fake_mode :
168+ fake_weight = torch .full (
169+ other_inputs [0 ], 1 , dtype = torch .float32
170+ )
171+ weight .meta ["val" ] = fake_weight
172+ else :
173+ weight .meta ["val" ] = torch .full (
174+ other_inputs [0 ], 1 , dtype = torch .float32
175+ )
176+ copy_node_metadata (weight , inputs_inputs [0 ])
163177
164178 bias = other_inputs [2 ] if len (other_inputs ) > 2 else None
165179
@@ -172,6 +186,18 @@ def get_args_and_kwargs_layer_norm(
172186 ),
173187 {"dtype" : torch .float32 },
174188 )
189+ if len (inputs_inputs ) > 0 :
190+ if "val" in inputs_inputs [0 ].meta :
191+ fake_mode = inputs_inputs [0 ].meta ["val" ].fake_mode
192+ if fake_mode is not None :
193+ with fake_mode :
194+ fake_bias = torch .full (other_inputs [0 ], 0 , dtype = torch .float32 )
195+ bias .meta ["val" ] = fake_bias
196+ else :
197+ bias .meta ["val" ] = torch .full (
198+ other_inputs [0 ], 0 , dtype = torch .float32
199+ )
200+ copy_node_metadata (bias , inputs_inputs [0 ])
175201
176202 # Make the args and kwargs for the replacement op
177203 args = tuple (inputs_inputs + [scale , zero_point ])
@@ -347,6 +373,16 @@ def get_args_and_kwargs_softmax(
347373 ),
348374 {"dtype" : torch .int32 },
349375 )
376+ if len (inputs_inputs ) > 0 :
377+ if "val" in inputs_inputs [0 ].meta :
378+ fake_mode = inputs_inputs [0 ].meta ["val" ].fake_mode
379+ if fake_mode is not None :
380+ with fake_mode :
381+ fake_mask = torch .full (mask_shape , 0.0 , dtype = torch .int32 )
382+ mask_tensor .meta ["val" ] = fake_mask
383+ else :
384+ mask_tensor .meta ["val" ] = torch .full (mask_shape , 0.0 , dtype = torch .int32 )
385+ copy_node_metadata (mask_tensor , inputs_inputs [0 ])
350386 # Make the scale and zero_point tensors
351387 in_scale = dequants_inputs [0 ].args [1 ]
352388 in_zero_point = dequants_inputs [0 ].args [2 ]
0 commit comments