Skip to content

Commit a27f9a0

Browse files
committed
add enumeratred shape to argparse
1 parent 8b1ba93 commit a27f9a0

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

examples/models/llama/coreml_enumerated_shape.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from numpy import dtype
99

1010
parser = build_args_parser()
11+
parser.add_argument('--use_enumerated_shapes', action='store_true')
1112
args = parser.parse_args()
1213

1314
model_manager = _prepare_for_llama_export("llama2", args)
@@ -83,7 +84,7 @@ def get_example_inputs(max_batch_size, args, coreml=False, use_enumerated_shapes
8384
max_batch_size=max_batch_size,
8485
args=args,
8586
coreml=True,
86-
use_enumerated_shapes=True,
87+
use_enumerated_shapes=args.use_enumerated_shapes,
8788
)
8889
),
8990
outputs=[ct.TensorType(name="op")],

0 commit comments

Comments
 (0)