Skip to content

Commit 32d3fa0

Browse files
committed
Add torchbench exports.
1 parent 5eb013d commit 32d3fa0

File tree

4 files changed

+991
-0
lines changed

4 files changed

+991
-0
lines changed
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# SHARK torchbench exports and benchmarks
2+
3+
### Setup
4+
5+
- pip install torch+rocm packages:
6+
```shell
7+
pip install --pre torch==2.5.0.dev20240801+rocm6.1 torchvision==0.20.0.dev20240801+rocm6.1 torchaudio==2.4.0.dev20240801+rocm6.1 --index-url https://download.pytorch.org/whl/nightly/rocm6.1
8+
9+
```
10+
- Work around the amdsmi error in pre-release pytorch+rocm:
11+
```shell
12+
sudo apt install amd-smi-lib
13+
sudo chown -R $USER:$USER /opt/rocm/share/amd_smi
14+
python3 -m pip install /opt/rocm/share/amd_smi
15+
```
16+
- Clone torch and expose benchmarking code as a relative module:
17+
```shell
18+
git clone https://github.com/pytorch/pytorch
19+
cd pytorch/benchmarks
20+
touch __init__.py
21+
cd ../..
22+
```
23+
- Clone and install pytorch benchmark modules:
24+
```shell
25+
git clone https://github.com/pytorch/benchmark
26+
cd benchmark
27+
python3 install.py --models BERT_pytorch Background_Matting LearningToPaint alexnet dcgan densenet121 hf_Albert hf_Bart hf_Bert hf_GPT2 hf_T5 mnasnet1_0 mobilenet_v2 mobilenet_v3_large nvidia_deeprecommender pytorch_unet resnet18 resnet50 resnet50_32x4d shufflenet_v2_x1_0 squeezenet1_1 timm_nfnet timm_efficientnet timm_regnet timm_resnest timm_vision_transformer timm_vovnet vgg16
28+
pip install -e .
29+
cd ..
30+
```
31+
32+
### Export and compile
33+
34+
```shell
35+
python ./export.py --model_id=All --target=gfx942 --device=hip --compile_to=vmfb --accuracy --inference
36+
```
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
import argparse
2+
import os
3+
from pathlib import Path
4+
5+
6+
def path_expand(s) -> Path:
    """Return ``s`` as a fully resolved ``Path`` with ``~`` expanded.

    Args:
        s: A string or path-like value accepted by ``pathlib.Path``.

    Returns:
        The path after ``expanduser()`` (leading ``~`` replaced by the home
        directory) and ``resolve()`` (made absolute, with ``..`` components
        and symlinks normalised).
    """
    return Path(s).expanduser().resolve()
8+
9+
10+
def is_valid_file(arg):
    """Return ``arg`` unchanged if it exists on disk, otherwise ``None``.

    The check is ``os.path.exists``, so despite the name an existing
    directory is accepted as well as a regular file.

    Args:
        arg: Path to check.

    Returns:
        ``arg`` itself when the path exists, ``None`` when it does not.
    """
    return arg if os.path.exists(arg) else None
15+
16+
17+
# Note: this is where command-line options for the scripts in this directory
18+
# are defined along with their defaults. Thus, they should not be referenced
19+
# within modelling or inference code, only at the entry point to the script.
20+
21+
# We should consider separating out the options that are "model configs" from
22+
# the options that control the compiler, runtime, and script behavior,
23+
# when applicable, as the former would best be kept in a separate
24+
# config or imported from huggingface.
25+
26+
# Shared parser for the scripts in this directory. The defaults formatter
# appends "(default: ...)" to every option's help text.
p = argparse.ArgumentParser(
    description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
)

##############################################################################
# general options
##############################################################################

p.add_argument(
    "--hf_auth_token",
    type=str,
    help="The Hugging Face auth token, if required",
    default=None,
)
p.add_argument(
    "--model_id",
    type=str,
    help="model ID as it appears in the torchbench models text file lists, or 'all' for batch export",
    default="all",
)
p.add_argument(
    "--external_weights_dir",
    type=str,
    default="",
    help="Path to external weights file, for jobs with one weights filepath. When importing, this is used to specify where to save the model weights, and at runtime, this is used to specify where to load the model weights from.",
)
p.add_argument(
    "--vmfbs_dir", type=str, default="", help="path to vmfb containing compiled module"
)
p.add_argument(
    "--benchmark",
    type=str,
    default=None,
    help="A comma-separated list of submodel IDs for which to report benchmarks for, or 'all' for all components.",
)
p.add_argument(
    "--save_outputs",
    type=str,
    default=None,
    help="A comma-separated list of submodel IDs for which to save output .npys for, or 'all' for all components.",
)
p.add_argument("--compile_to", type=str, default="mlir", help="torch, linalg, vmfb")
p.add_argument(
    "--external_weights",
    type=str,
    default=None,
    # None is only the default. It is deliberately not listed in `choices`:
    # argparse compares choices against the str-converted command-line value,
    # so a literal None could never match and only cluttered the help text.
    choices=["safetensors", "irpa", "gguf"],
    help="Externalizes model weights from the torch dialect IR and its successors",
)
75+
76+
##############################################################################
77+
# Modeling and Export Options
78+
# These options are used to control model defining parameters.
79+
# These are MLIR-changing variables! If you change them, you will need
80+
# to import/download and recompile the model.
81+
##############################################################################
82+
83+
p.add_argument("--batch_size", type=int, default=1, help="Batch size for inference")
p.add_argument(
    "--precision",
    type=str,
    default="fp16",
    # Help text previously said "Stable Diffusion", carried over from another
    # pipeline; these options drive torchbench model exports.
    help="Precision of model weights and graph.",
)
p.add_argument(
    "--decomp_attn",
    default=False,
    action="store_true",
    help="Decompose attention at fx graph level",
)

# See --external_weights_dir to specify where to save the model weights.

p.add_argument(
    "--compare_vs_torch",
    action="store_true",
    help="Runs both turbine vmfb and a torch model to compare results",
)
p.add_argument(
    "--input_mlir",
    type=str,
    default=None,
    help="Path to input mlir file to compile. Comma-separate paths to provide more than one input to pipelines.",
)
110+
111+
112+
##############################################################################
113+
# IREE Compiler Options
114+
##############################################################################
115+
116+
p.add_argument(
    "--device",
    type=str,
    default="local-task",
    help="local-task, local-sync, vulkan://0, rocm://0, cuda://0, etc.",
)
p.add_argument(
    "--target",
    type=str,
    default="gfx942",
    help="Usually a rocm chip arch or llvmcpu target triple, e.g. gfx942 or x86_64-linux-gnu.",
)
p.add_argument("--ireec_flags", type=str, default="", help="extra iree-compile options")
p.add_argument(
    "--attn_spec",
    type=str,
    default=None,
    help="extra iree-compile options for models with sdpa ops.",
)


# Parsed once at import time. parse_known_args() is used instead of
# parse_args() so that flags this parser does not define are collected in
# `unknown` rather than aborting with an error.
args, unknown = p.parse_known_args()

0 commit comments

Comments
 (0)