File tree Expand file tree Collapse file tree 2 files changed +8
-7
lines changed
models/turbine_models/custom_models Expand file tree Collapse file tree 2 files changed +8
-7
lines changed Original file line number Diff line number Diff line change @@ -244,7 +244,7 @@ def __init__(
244244 "exit_on_vmfb" : False ,
245245 "pipeline_dir" : pipeline_dir ,
246246 "input_mlir" : None ,
247- "attn_spec" : None ,
247+ "attn_spec" : attn_spec ,
248248 "external_weights" : None ,
249249 "external_weight_path" : None ,
250250 }
Original file line number Diff line number Diff line change @@ -182,12 +182,13 @@ def export_unet_model(
182182 submodel_name = "punet"
183183 else :
184184 submodel_name = "unet"
185- if (not decomp_attn ) and use_punet :
186- attn_spec = "punet"
187- elif (not decomp_attn ) and "gfx9" in target :
188- attn_spec = "mfma"
189- elif (not decomp_attn ) and "gfx11" in target :
190- attn_spec = "wmma"
185+ if not attn_spec :
186+ if (not decomp_attn ) and use_punet :
187+ attn_spec = "punet"
188+ elif (not decomp_attn ) and "gfx9" in target :
189+ attn_spec = "mfma"
190+ elif (not decomp_attn ) and "gfx11" in target :
191+ attn_spec = "wmma"
191192 safe_name = utils .create_safe_name (
192193 hf_model_name ,
193194 f"_bs{ batch_size } _{ max_length } _{ height } x{ width } _{ precision } _{ submodel_name } " ,
You can’t perform that action at this time.
0 commit comments