Skip to content

Commit 174e73d

Browse files
committed
Update README and weights saving imports
1 parent 32d3fa0 commit 174e73d

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

models/turbine_models/custom_models/torchbench/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,5 +32,5 @@ cd ..
3232
### Export and compile
3333

3434
```shell
35-
python ./export.py --model_id=All --target=gfx942 --device=hip --compile_to=vmfb --accuracy --inference
35+
python ./export.py --target=gfx942 --device=rocm --compile_to=vmfb --performance --inference --precision=fp16 --float16 --external_weights=safetensors --external_weights_dir=./torchbench_weights/
3636
```

models/turbine_models/custom_models/torchbench/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import os
55
import safetensors
66
import safetensors.numpy as safe_numpy
7+
import safetensors.torch as safe_torch
78
import re
89
import glob
910

@@ -461,7 +462,7 @@ def save_external_weights(
461462
mod_params = vae_params
462463
if external_weight_file and not os.path.isfile(external_weight_file):
463464
if not force_format:
464-
safetensors.torch.save_file(mod_params, external_weight_file)
465+
safe_torch.save_file(mod_params, external_weight_file)
465466
else:
466467
for x in mod_params.keys():
467468
mod_params[x] = mod_params[x].numpy()

0 commit comments

Comments
 (0)