         "dim": 32,
         "buffer_prefix": "albert"
     },
-    "hf_Bart": {
-        "dim": 16,
-        "buffer_prefix": "bart"
-    },
-    "hf_Bert": {
-        "dim": 16,
-        "buffer_prefix": "bert"
-    },
-    "hf_GPT2": {
-        "dim": 16,
-        "buffer_prefix": "gpt2"
-    },
-    "hf_T5": {
-        "dim": 4,
-        "buffer_prefix": "t5"
-    },
+    # "hf_Bart": {
+    #     "dim": 16,
+    # },
+    # "hf_Bert": {
+    #     "dim": 16,
+    #     "buffer_prefix": "bert"
+    # },
+    # "hf_GPT2": {
+    #     "dim": 16,
+    #     "buffer_prefix": "gpt2"
+    # },
+    # "hf_T5": {
+    #     "dim": 4,
+    #     "buffer_prefix": "t5"
+    # },
     "mnasnet1_0": {
         "dim": 256,
     },
@@ -182,30 +181,21 @@ def export_torchbench_model(
 
     _, model_name, model, forward_args, _ = get_model_and_inputs(model_id, batch_size, tb_dir, tb_args)
 
+    for idx, i in enumerate(forward_args.values()):
+        np.save(f"input{idx}", i.clone().detach().cpu())
     if dtype == torch.float16:
         model = model.half()
     model.to("cuda:0")
 
     if not isinstance(forward_args, dict):
         forward_args = [i.type(dtype) for i in forward_args]
-    elif "hf" in model_id:
-        forward_args["head_mask"] = torch.zeros(model.config.num_hidden_layers, device="cuda:0")
 
     mapper = {}
     if (external_weights_dir is not None):
         if not os.path.exists(external_weights_dir):
             os.mkdir(external_weights_dir)
-        external_weight_path = os.path.join(external_weights_dir, f"{model_id}_{precision}.{external_weights}")
-        if os.path.exists(external_weight_path):
-            print("External weights for this module already exist at {external_weight_path}. Will not overwrite.")
-        utils.save_external_weights(
-            mapper,
-            model,
-            external_weights,
-            external_weight_path,
-        )
-        if weights_only:
-            return external_weight_path
+        external_weight_path = os.path.join(external_weights_dir, f"{model_id}_{precision}.irpa")
+
 
     decomp_list = [torch.ops.aten.reflection_pad2d]
     if decomp_attn == True:
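Note on the hunk above (not part of the patch): the two added lines dump every forward input to a .npy file before the model is cast and moved to CUDA, and external weight saving is deferred so that only the .irpa path is computed here. A minimal sketch of the same dump-and-reload pattern, using made-up tensor names rather than the real torchbench inputs:

    # Sketch only: mirrors the np.save loop added above; tensor names are illustrative.
    import numpy as np
    import torch

    forward_args = {
        "input_ids": torch.ones(1, 128, dtype=torch.int64),
        "attention_mask": torch.ones(1, 128, dtype=torch.int64),
    }
    for idx, t in enumerate(forward_args.values()):
        np.save(f"input{idx}", t.clone().detach().cpu())  # writes input0.npy, input1.npy
    reloaded = [np.load(f"input{i}.npy") for i in range(len(forward_args))]
    print([arr.shape for arr in reloaded])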
@@ -225,18 +215,20 @@ def __init__(self, model):
             self.mod = model
 
         def forward(self, inp):
-            return self.mod(**inp, return_dict=False)
-    # In transformers, the position ids buffer is registered as non-persistent,
-    # which makes it fail to globalize in the FX import.
-    # Add them manually to the state dict here.
-
-    prefix = torchbench_models_dict[model_id]["buffer_prefix"]
-    getattr(model, prefix).embeddings.register_buffer(
-        "position_ids",
-        getattr(model, prefix).embeddings.position_ids,
-        persistent=True,
-    )
+            return self.mod(**inp)
+
+    if "Bart" not in model_id:
+        # In some transformers models, the position ids buffer is registered as non-persistent,
+        # which makes it fail to globalize in the FX import.
+        # Add them manually to the state dict here.
 
+        prefix = torchbench_models_dict[model_id]["buffer_prefix"]
+        getattr(model, prefix).embeddings.register_buffer(
+            "position_ids",
+            getattr(model, prefix).embeddings.position_ids,
+            persistent=True,
+        )
+        breakpoint()
     fxb = FxProgramsBuilder(HF_M(model))
     @fxb.export_program(args=(forward_args,))
     def _forward(module: HF_M(model), inputs):
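Background on the buffer workaround in the hunk above: PyTorch buffers registered with persistent=False are omitted from state_dict, which is why the FX import fails to globalize them; re-registering the same tensor with persistent=True puts it back. A standalone sketch of that behavior on a toy module (names are illustrative, not taken from the exporter):

    # Sketch: why the exporter re-registers position_ids as a persistent buffer.
    import torch

    class Toy(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # Non-persistent buffers are excluded from state_dict.
            self.register_buffer("position_ids", torch.arange(8), persistent=False)

    m = Toy()
    print("position_ids" in m.state_dict())  # False

    # Same move as the patch: re-register the existing tensor as persistent.
    m.register_buffer("position_ids", m.position_ids, persistent=True)
    print("position_ids" in m.state_dict())  # True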
@@ -252,6 +244,7 @@ class CompiledTorchbenchModel(CompiledModule):
 
     if external_weights:
         externalize_module_parameters(model)
+        save_module_parameters(external_weight_path, model)
 
     inst = CompiledTorchbenchModel(context=Context(), import_to="IMPORT")
 
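For context on the last hunk: externalize_module_parameters marks the model's parameters as externally supplied in the exported program, and the added save_module_parameters call writes them to the .irpa archive at the external_weight_path computed earlier. A hedged sketch, assuming both helpers are the shark_turbine.aot parameter utilities this exporter imports; the call forms mirror the patch above:

    # Sketch only: the import path is an assumption, not confirmed by this diff.
    import torch
    from shark_turbine.aot import externalize_module_parameters, save_module_parameters

    model = torch.nn.Linear(4, 4)  # stand-in for the torchbench model
    externalize_module_parameters(model)                 # reference weights externally
    save_module_parameters("example_fp16.irpa", model)   # write the parameter archive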