
Commit 6417869

init

1 parent 92ee522

File tree

4 files changed (+42, -16 lines)


examples/models/llama/llama_transformer.py

Lines changed: 19 additions & 1 deletion
@@ -23,6 +23,23 @@
 
 from torch import nn
 
+@torch.library.custom_op("coreml::sdpa", mutates_args=())
+def sdpa(
+    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: torch.Tensor
+) -> torch.Tensor:
+    """Same as F.scaled_dot_product_attention, but with custom op to avoid lowering during dialect conversion."""
+    return torch.ops.aten.scaled_dot_product_attention.default(
+        q, k, v, attn_mask=attn_mask
+    )
+
+@torch.library.register_fake("coreml::sdpa")
+def _(
+    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: torch.Tensor
+) -> torch.Tensor:
+    """Fake implementation with the right output shape, which is required for torch.compile/export/fx tracing."""
+    expected_shape = list(q.shape)
+    expected_shape[-1] = v.shape[-1]
+    return q.new_empty(expected_shape)
 
 class RMSNorm(torch.nn.Module):
     def __init__(self, dim: int, eps: float = 1e-6):
@@ -351,7 +368,8 @@ def forward(
 
         mask = self.mask[:seqlen, :seqlen]
 
-        output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+        # output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+        output = torch.ops.coreml.sdpa(q, k, v, mask)
 
         output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
 
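A quick, hedged sanity check of the new op (not part of the commit): it assumes PyTorch 2.4+ for torch.library.custom_op/register_fake, and that importing llama_transformer performs the registration above; the import path below is an assumption and may need adjusting to your checkout.

import torch
import torch.nn.functional as F

# Importing the module runs the @torch.library.custom_op registration shown above.
# Module path is assumed; adjust to match your ExecuTorch install/checkout.
from executorch.examples.models.llama import llama_transformer  # noqa: F401

q = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)
mask = torch.zeros(16, 16)     # additive attention mask; all zeros = no masking

# The custom op should be numerically identical to the eager SDPA it wraps.
torch.testing.assert_close(
    torch.ops.coreml.sdpa(q, k, v, mask),
    F.scaled_dot_product_attention(q, k, v, attn_mask=mask),
)

# Because a fake kernel is registered, the op also survives torch.export
# as a single coreml.sdpa node instead of being decomposed.
class Wrapper(torch.nn.Module):
    def forward(self, q, k, v, mask):
        return torch.ops.coreml.sdpa(q, k, v, mask)

ep = torch.export.export(Wrapper(), (q, k, v, mask))
print(ep.graph)  # expect a call to coreml.sdpa in the exported graph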
examples/models/llama/model.py

Lines changed: 1 addition & 3 deletions
@@ -244,9 +244,7 @@ def get_example_inputs(self):
             return self.get_example_inputs_kvcache_sdpa()
         else:
             return (
-                torch.tensor(
-                    [[1, 2, 3]], dtype=torch.long
-                ),  # tokens, with kv cache our input token length is always just 1 token.
+                torch.ones(size=(1, self.max_seq_len), dtype=torch.long),
             )
 
     # assumption is the custom op doesn't support dynamic shape right now. It might, but it's untested, so let's first get static shape working
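For context, the replacement example input is just a full-length dummy token tensor, matching the static-shape assumption stated in the comment above. A minimal sketch (max_seq_len is an illustrative stand-in for the model's configured value):

import torch

max_seq_len = 128  # illustrative; the real value comes from the model configuration

# Static-shape example input: a (1, max_seq_len) tensor of token ids,
# replacing the previous 3-token example.
example_inputs = (torch.ones(size=(1, max_seq_len), dtype=torch.long),)
print(example_inputs[0].shape)  # torch.Size([1, 128])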

extension/llm/export/builder.py

Lines changed: 10 additions & 11 deletions
@@ -156,18 +156,17 @@ def source_transform(
     def _get_dynamic_shape(self) -> Any:
         if self.dynamic_shapes:
             return self.dynamic_shapes
-
-        dim = torch.export.Dim("token_dim", max=self.max_seq_len - 1)
-
-        if not self.use_kv_cache:
-            # Only one input argument: tokens
-            self.dynamic_shapes = ({1: dim},)
-        elif self.enable_dynamic_shape:
-            # Two input arguments: tokens and input_pos but input_pos is static shape
-            self.dynamic_shapes = ({1: dim}, {0: 1})
-        else:
-            # Two input arguments: tokens and input_pos but both are of static shape
+
+        if not self.enable_dynamic_shape:
             self.dynamic_shapes = None
+        else:
+            dim = torch.export.Dim("token_dim", max=self.max_seq_len - 1)
+            if not self.use_kv_cache:
+                # Only one input argument: tokens
+                self.dynamic_shapes = ({1: dim},)
+            else:
+                # Two input arguments: tokens and input_pos but input_pos is static shape
+                self.dynamic_shapes = ({1: dim}, {0: 1})
         return self.dynamic_shapes
 
     def _get_edge_config(self) -> EdgeCompileConfig:
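To illustrate what those dynamic_shapes tuples mean to torch.export, here is a hedged, standalone sketch of the no-kv-cache branch; TokenModel and the sizes are made up for illustration. When enable_dynamic_shape is off, the method now returns None, i.e. a fully static export.

import torch

class TokenModel(torch.nn.Module):
    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return tokens * 2

max_seq_len = 128  # illustrative
dim = torch.export.Dim("token_dim", max=max_seq_len - 1)

# One input argument (tokens) with dimension 1 (sequence length) dynamic,
# mirroring self.dynamic_shapes = ({1: dim},) above.
dynamic_shapes = ({1: dim},)

example_inputs = (torch.ones(1, 16, dtype=torch.long),)
ep = torch.export.export(TokenModel(), example_inputs, dynamic_shapes=dynamic_shapes)
print(ep)  # the token dimension should appear as a symbol bounded by max_seq_len - 1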

extension/llm/export/partitioner_lib.py

Lines changed: 12 additions & 1 deletion
@@ -128,11 +128,19 @@ def _validate_ios_version() -> None:
             "block_size": 32,
             "weight_threshold": 512,
         }
+
+    assert ios == 18
+    print("OVERRIDING CONFIG TO BE 4B PER_CHANNEL")
+    op_linear_quantizer_config = {
+        "mode": "linear_symmetric",
+        "dtype": "int4",
+        "granularity": "per_channel",
+    }
     compile_specs = CoreMLBackend.generate_compile_specs(  # pyre-fixme[16]
         minimum_deployment_target=minimum_deployment_target,
         compute_precision=ct.precision(ct.precision.FLOAT16.value),
         # using `ComputeUnit.ALL` can increase the model load time, default to `ComputeUnit.CPU_AND_GPU`
-        compute_unit=ct.ComputeUnit[ct.ComputeUnit.CPU_AND_GPU.name.upper()],
+        compute_unit=ct.ComputeUnit[ct.ComputeUnit.CPU_AND_NE.name.upper()],
         model_type=CoreMLBackend.MODEL_TYPE.MODEL,  # pyre-fixme[16]
         op_linear_quantizer_config=op_linear_quantizer_config,
     )
@@ -142,6 +150,9 @@ def _validate_ios_version() -> None:
     return CoreMLPartitioner(  # pyre-fixme[16]
         compile_specs=compile_specs,
         take_over_mutable_buffer=take_over_mutable_buffer,
+        skip_ops_for_coreml_delegation=[
+            "aten.embedding.default",
+        ],
     )
 
 
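Taken together, these hunks configure the CoreML delegate for iOS 18, FP16 compute on CPU plus the Neural Engine, 4-bit per-channel symmetric weight quantization, and an embedding op that stays out of the delegate. Below is a standalone sketch of roughly the same configuration; the import paths and the hard-coded take_over_mutable_buffer=True are assumptions for illustration.

import coremltools as ct

from executorch.backends.apple.coreml.compiler import CoreMLBackend
from executorch.backends.apple.coreml.partition import CoreMLPartitioner

# Forced override from the diff: 4-bit, per-channel, symmetric linear weights.
op_linear_quantizer_config = {
    "mode": "linear_symmetric",
    "dtype": "int4",
    "granularity": "per_channel",
}

compile_specs = CoreMLBackend.generate_compile_specs(
    minimum_deployment_target=ct.target.iOS18,
    compute_precision=ct.precision.FLOAT16,
    compute_unit=ct.ComputeUnit.CPU_AND_NE,  # CPU + Apple Neural Engine
    model_type=CoreMLBackend.MODEL_TYPE.MODEL,
    op_linear_quantizer_config=op_linear_quantizer_config,
)

partitioner = CoreMLPartitioner(
    compile_specs=compile_specs,
    take_over_mutable_buffer=True,  # illustrative; the helper passes its own flag through
    skip_ops_for_coreml_delegation=[
        "aten.embedding.default",  # keep the embedding lookup in the non-delegated graph
    ],
)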