Skip to content

Commit 7702a9e

Browse files
committed
add smaller repro
1 parent 0d70105 commit 7702a9e

File tree

1 file changed

+74
-0
lines changed

1 file changed

+74
-0
lines changed
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Copied from https://apple.github.io/coremltools/docs-guides/source/stateful-models.html#example-toy-attention-model-with-stateful-kv-cache
2+
3+
import torch
4+
import torch.nn as nn
5+
6+
7+
class SimpleAttention(nn.Module):
8+
def __init__(self, embed_size):
9+
super().__init__()
10+
self.query = nn.Linear(embed_size, embed_size)
11+
self.key = nn.Linear(embed_size, embed_size)
12+
self.value = nn.Linear(embed_size, embed_size)
13+
14+
def forward(self, x):
15+
Q = self.query(x) # (batch_size, seq_len, embed_size)
16+
K = self.key(x) # (batch_size, seq_len, embed_size)
17+
V = self.value(x) # (batch_size, seq_len, embed_size)
18+
return torch.nn.functional.scaled_dot_product_attention(Q, K, V)
19+
20+
21+
class ToyModel(nn.Module):
22+
def __init__(self, vocab_size, embed_size):
23+
super().__init__()
24+
self.embedding = nn.Embedding(vocab_size, embed_size)
25+
self.attention = SimpleAttention(embed_size)
26+
self.fc = nn.Linear(embed_size, embed_size)
27+
28+
def forward(self, x):
29+
embedded = self.embedding(x)
30+
attention_output = self.attention(embedded)
31+
return self.fc(attention_output)
32+
33+
34+
vocab_size = 32000
35+
embed_size = 1024
36+
batch_size = 1
37+
seq_len = 5
38+
max_seq_len = 1024
39+
num_iterations = 100
40+
41+
import coremltools as ct
42+
import numpy as np
43+
44+
torch_model = ToyModel(vocab_size, embed_size)
45+
torch_model.eval()
46+
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
47+
torch_output = torch_model(input_ids).detach().numpy()
48+
traced_model = torch.jit.trace(torch_model, [input_ids])
49+
50+
############################################################################################################
51+
# Using
52+
# query_length = ct.RangeDim(lower_bound=1, upper_bound=max_seq_len, default=1)
53+
# leads to mlpackage file that crashes on phone. Changing it to static
54+
# query_length = 1
55+
# leads to mlpackage file that runs on phone.
56+
############################################################################################################
57+
query_length = ct.RangeDim(lower_bound=1, upper_bound=max_seq_len, default=1)
58+
# query_length = 1
59+
60+
61+
inputs = [
62+
ct.TensorType(shape=(batch_size, query_length), dtype=np.int32, name="input_ids")
63+
]
64+
outputs = [ct.TensorType(dtype=np.float16, name="output")]
65+
66+
converted_model = ct.convert(
67+
traced_model,
68+
inputs=inputs,
69+
outputs=outputs,
70+
minimum_deployment_target=ct.target.iOS18,
71+
compute_units=ct.ComputeUnit.CPU_AND_GPU,
72+
)
73+
74+
converted_model.save("/Users/scroy/Desktop/coreml.mlpackage")

0 commit comments

Comments
 (0)