Skip to content

Commit 9afac3c

Browse files
danielssonsimonbcgsimondanielsson
authored andcommitted
Initial commit
1 parent 68f201a commit 9afac3c

File tree

1 file changed

+23
-6
lines changed

1 file changed

+23
-6
lines changed

optimum/commands/export/onnx.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,16 @@
3434
def parse_args_onnx(parser):
3535
required_group = parser.add_argument_group("Required arguments")
3636
required_group.add_argument(
37-
"-m", "--model", type=str, required=True, help="Model ID on huggingface.co or path on disk to load model from."
37+
"-m",
38+
"--model",
39+
type=str,
40+
required=True,
41+
help="Model ID on huggingface.co or path on disk to load model from.",
3842
)
3943
required_group.add_argument(
40-
"output", type=Path, help="Path indicating the directory where to store the generated ONNX model."
44+
"output",
45+
type=Path,
46+
help="Path indicating the directory where to store the generated ONNX model.",
4147
)
4248

4349
optional_group = parser.add_argument_group("Optional arguments")
@@ -127,7 +133,10 @@ def parse_args_onnx(parser):
127133
help="If specified, the absolute difference tolerance when validating the model. Otherwise, the default atol for the model will be used.",
128134
)
129135
optional_group.add_argument(
130-
"--cache_dir", type=str, default=HUGGINGFACE_HUB_CACHE, help="Path indicating where to store cache."
136+
"--cache_dir",
137+
type=str,
138+
default=HUGGINGFACE_HUB_CACHE,
139+
help="Path indicating where to store cache.",
131140
)
132141
optional_group.add_argument(
133142
"--trust-remote-code",
@@ -148,12 +157,16 @@ def parse_args_onnx(parser):
148157
type=str,
149158
choices=["transformers", "diffusers", "timm", "sentence_transformers"],
150159
default=None,
151-
help=("The library on the model. If not provided, will attempt to infer the local checkpoint's library"),
160+
help=(
161+
"The library on the model. If not provided, will attempt to infer the local checkpoint's library"
162+
),
152163
)
153164
optional_group.add_argument(
154165
"--model-kwargs",
155166
type=json.loads,
156-
help=("Any kwargs passed to the model forward, or used to customize the export for a given model."),
167+
help=(
168+
"Any kwargs passed to the model forward, or used to customize the export for a given model."
169+
),
157170
)
158171
optional_group.add_argument(
159172
"--legacy",
@@ -164,7 +177,9 @@ def parse_args_onnx(parser):
164177
),
165178
)
166179
optional_group.add_argument(
167-
"--no-dynamic-axes", action="store_true", help="Disable dynamic axes during ONNX export"
180+
"--no-dynamic-axes",
181+
action="store_true",
182+
help="Disable dynamic axes during ONNX export",
168183
)
169184
optional_group.add_argument(
170185
"--no-constant-folding",
@@ -271,6 +286,8 @@ def parse_args(parser: ArgumentParser):
271286
def run(self):
272287
from optimum.exporters.onnx import main_export
273288

289+
print("Hello from optimum-onnx/onnx.py")
290+
274291
# Get the shapes to be used to generate dummy inputs
275292
input_shapes = {}
276293
for input_name in DEFAULT_DUMMY_SHAPES:

0 commit comments

Comments
 (0)