Skip to content

Commit 53d71cb

Browse files
committed
enumerated-shape-test
1 parent dc2e02a commit 53d71cb

File tree

1 file changed

+95
-0
lines changed

1 file changed

+95
-0
lines changed
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import coremltools as ct
2+
import numpy as np
3+
import torch
4+
from executorch.examples.models.llama.export_llama_lib import (
5+
_prepare_for_llama_export,
6+
build_args_parser,
7+
)
8+
from numpy import dtype
9+
10+
# Parse command-line flags with the argument parser provided by
# export_llama_lib.
parser = build_args_parser()
args = parser.parse_args()

# Build the "llama2" model through the export helper, then put the module in
# eval mode before it is traced further down.
model_manager = _prepare_for_llama_export("llama2", args)
model = model_manager.model
model.eval()
17+
18+
19+
def get_example_inputs(max_batch_size, args, coreml=False, use_enumerated_shapes=False):
    """Build example model inputs as torch tensors or as Core ML input specs.

    Args:
        max_batch_size: Size of the second token dimension; the token input
            has shape ``(1, max_batch_size)``.
        args: Parsed command-line arguments. Only ``args.use_kv_cache`` is
            read here.
        coreml: If true, return ``ct.TensorType`` input descriptions suitable
            for ``ct.convert`` instead of torch tensors.
        use_enumerated_shapes: Only consulted when ``coreml`` is true. If
            true, the token input is declared with the enumerated shapes
            ``[1, 1]`` and ``[1, max_batch_size]`` (the latter is the
            default); otherwise it has the fixed shape
            ``[1, max_batch_size]``.

    Returns:
        A tuple ``(tokens,)``, or ``(tokens, input_pos)`` when
        ``args.use_kv_cache`` is set. Elements are ``torch.long`` tensors
        when ``coreml`` is false and ``ct.TensorType`` objects otherwise.
    """
    if not coreml:
        # Torch path: no coremltools objects are constructed here.
        tokens = torch.ones((1, max_batch_size), dtype=torch.long)
        if args.use_kv_cache:
            # NOTE: input_pos always has size 1. Per the original author's
            # notes, torch.jit.trace wants input_pos sized like the token
            # batch while ct.convert wants it to be size 1, so KV-cache
            # export is expected to be used with a single token per call.
            input_pos = torch.tensor([0], dtype=torch.long)
            return (tokens, input_pos)
        return (tokens,)

    # Core ML path: describe the same inputs as TensorType specs.
    if use_enumerated_shapes:
        # NOTE(review): with max_batch_size == 1 both enumerated shapes are
        # [1, 1] -- confirm coremltools accepts the duplicate entry.
        ct_tokens_shape = ct.EnumeratedShapes(
            shapes=[
                [1, 1],
                [1, max_batch_size],
            ],
            default=[1, max_batch_size],
        )
    else:
        ct_tokens_shape = ct.Shape([1, max_batch_size])

    ct_tokens = ct.TensorType(shape=ct_tokens_shape, dtype=np.int64)

    if args.use_kv_cache:
        ct_input_pos = ct.TensorType(shape=ct.Shape([1]), dtype=np.int64)
        return (ct_tokens, ct_input_pos)
    return (ct_tokens,)
50+
51+
52+
# Batched input together with the KV cache runs into issues: either input_pos
# must have size batch_size for torch.jit.trace to export, or it must have
# size 1 for ct.convert to export. A possible refactor is to make the model
# traceable when input_pos has size 1 (interpreted as the starting position).
# Until then, KV-cache export uses a single token per call.
max_batch_size = 1 if args.use_kv_cache else 128

example_inputs = get_example_inputs(max_batch_size, args)
print("Example input shapes: ", [t.shape for t in example_inputs])

traced_model = torch.jit.trace(model, example_inputs)

# With the KV cache enabled, every buffer whose name ends in "_cache" is
# declared to Core ML as a state with the buffer's shape.
if args.use_kv_cache:
    states = [
        ct.StateType(
            wrapped_type=ct.TensorType(shape=buffer.shape),
            name=buffer_name,
        )
        for buffer_name, buffer in traced_model.named_buffers()
        if buffer_name.endswith("_cache")
    ]
else:
    states = None

# Core ML input specs for the same inputs, with enumerated token shapes.
coreml_inputs = list(
    get_example_inputs(
        max_batch_size=max_batch_size,
        args=args,
        coreml=True,
        use_enumerated_shapes=True,
    )
)

mlmodel = ct.convert(
    traced_model,
    inputs=coreml_inputs,
    outputs=[ct.TensorType(name="op")],
    states=states,
    minimum_deployment_target=ct.target.iOS18,
    compute_units=ct.ComputeUnit.CPU_AND_NE,
)

mlmodel.save(args.output_name)

0 commit comments

Comments
 (0)