|
| 1 | +# Copied from https://apple.github.io/coremltools/docs-guides/source/stateful-models.html#example-toy-attention-model-with-stateful-kv-cache |
| 2 | + |
| 3 | +import torch |
| 4 | +import torch.nn as nn |
| 5 | + |
| 6 | + |
| 7 | +class SimpleAttention(nn.Module): |
| 8 | + def __init__(self, embed_size): |
| 9 | + super().__init__() |
| 10 | + self.query = nn.Linear(embed_size, embed_size) |
| 11 | + self.key = nn.Linear(embed_size, embed_size) |
| 12 | + self.value = nn.Linear(embed_size, embed_size) |
| 13 | + |
| 14 | + def forward(self, x): |
| 15 | + Q = self.query(x) # (batch_size, seq_len, embed_size) |
| 16 | + K = self.key(x) # (batch_size, seq_len, embed_size) |
| 17 | + V = self.value(x) # (batch_size, seq_len, embed_size) |
| 18 | + return torch.nn.functional.scaled_dot_product_attention(Q, K, V) |
| 19 | + |
| 20 | + |
| 21 | +class ToyModel(nn.Module): |
| 22 | + def __init__(self, vocab_size, embed_size): |
| 23 | + super().__init__() |
| 24 | + self.embedding = nn.Embedding(vocab_size, embed_size) |
| 25 | + self.attention = SimpleAttention(embed_size) |
| 26 | + self.fc = nn.Linear(embed_size, embed_size) |
| 27 | + |
| 28 | + def forward(self, x): |
| 29 | + embedded = self.embedding(x) |
| 30 | + attention_output = self.attention(embedded) |
| 31 | + return self.fc(attention_output) |
| 32 | + |
| 33 | + |
| 34 | +vocab_size = 32000 |
| 35 | +embed_size = 1024 |
| 36 | +batch_size = 1 |
| 37 | +seq_len = 5 |
| 38 | +max_seq_len = 1024 |
| 39 | +num_iterations = 100 |
| 40 | + |
| 41 | +import coremltools as ct |
| 42 | +import numpy as np |
| 43 | + |
| 44 | +torch_model = ToyModel(vocab_size, embed_size) |
| 45 | +torch_model.eval() |
| 46 | +input_ids = torch.randint(0, vocab_size, (batch_size, seq_len)) |
| 47 | +torch_output = torch_model(input_ids).detach().numpy() |
| 48 | +traced_model = torch.jit.trace(torch_model, [input_ids]) |
| 49 | + |
| 50 | +############################################################################################################ |
| 51 | +# Using |
| 52 | +# query_length = ct.RangeDim(lower_bound=1, upper_bound=max_seq_len, default=1) |
| 53 | +# leads to mlpackage file that crashes on phone. Changing it to static |
| 54 | +# query_length = 1 |
| 55 | +# leads to mlpackage file that runs on phone. |
| 56 | +############################################################################################################ |
| 57 | +query_length = ct.RangeDim(lower_bound=1, upper_bound=max_seq_len, default=1) |
| 58 | +# query_length = 1 |
| 59 | + |
| 60 | + |
| 61 | +inputs = [ |
| 62 | + ct.TensorType(shape=(batch_size, query_length), dtype=np.int32, name="input_ids") |
| 63 | +] |
| 64 | +outputs = [ct.TensorType(dtype=np.float16, name="output")] |
| 65 | + |
| 66 | +converted_model = ct.convert( |
| 67 | + traced_model, |
| 68 | + inputs=inputs, |
| 69 | + outputs=outputs, |
| 70 | + minimum_deployment_target=ct.target.iOS18, |
| 71 | + compute_units=ct.ComputeUnit.CPU_AND_GPU, |
| 72 | +) |
| 73 | + |
| 74 | +converted_model.save("/Users/scroy/Desktop/coreml.mlpackage") |
0 commit comments