
Commit e7f25a3

init
1 parent a5eeef0 commit e7f25a3

3 files changed: +33 -2 lines changed

3 files changed

+33
-2
lines changed

examples/models/llama/export_llama_lib.py

Lines changed: 12 additions & 0 deletions
@@ -229,6 +229,18 @@ def build_args_parser() -> argparse.ArgumentParser:
         action="store_true",
         help="Whether or not to export a model using kv cache",
     )
+    parser.add_argument(
+        "--prefill_return_kv",
+        default=False,
+        action="store_true",
+        help="Whether or not to return kv values from the prefill model",
+    )
+    parser.add_argument(
+        "--prefill_seq_length",
+        type=int,
+        default=3,
+        help="Sequence length for the prefill model",
+    )
     parser.add_argument(
         "--quantize_kv_cache",
         default=False,
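
A minimal sketch of how the two new options parse (not part of the commit): the parser below is a standalone mirror of just the flags added in this hunk, since only build_args_parser and the flag names come from the diff.

# Standalone mirror of the two new prefill flags (illustrative only; the real
# definitions live in build_args_parser above).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--prefill_return_kv",
    default=False,
    action="store_true",
    help="Whether or not to return kv values from the prefill model",
)
parser.add_argument(
    "--prefill_seq_length",
    type=int,
    default=3,
    help="Sequence length for the prefill model",
)

args = parser.parse_args(["--prefill_return_kv", "--prefill_seq_length", "128"])
print(args.prefill_return_kv)   # True -> prefill model should also return kv values
print(args.prefill_seq_length)  # 128  -> token length used for the prefill example input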

examples/models/llama/llama_transformer.py

Lines changed: 20 additions & 1 deletion
@@ -23,6 +23,25 @@
 
 from torch import nn
 
+@torch.library.custom_op("coreml::sdpa", mutates_args=())
+def sdpa(
+    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: torch.Tensor
+) -> torch.Tensor:
+    """Same as F.scaled_dot_product_attention, but with custom op to avoid lowering during dialect conversion."""
+    return torch.ops.aten.scaled_dot_product_attention.default(
+        q, k, v, attn_mask=attn_mask
+    )
+
+
+@torch.library.register_fake("coreml::sdpa")
+def _(
+    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: torch.Tensor
+) -> torch.Tensor:
+    """Fake implementation with the right output shape, which is required for torch.compile/export/fx tracing."""
+    expected_shape = list(q.shape)
+    expected_shape[-1] = v.shape[-1]
+    return q.new_empty(expected_shape)
+
 
 class RMSNorm(torch.nn.Module):
     def __init__(self, dim: int, eps: float = 1e-6):
@@ -431,7 +450,7 @@ def forward(
 
         mask = self.mask[:seqlen, :seqlen]
 
-        output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+        output = torch.ops.coreml.sdpa(q, k, v, mask)
 
         output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
 
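
A standalone sanity-check sketch for the custom op (not part of the commit). It repeats the registration from the hunk above so the snippet runs on its own, and assumes PyTorch >= 2.4 for torch.library.custom_op / register_fake; the shapes are arbitrary.

import torch
import torch.nn.functional as F


# Same registration as in llama_transformer.py, repeated here so the sketch is
# self-contained (skip this part if that module is already imported).
@torch.library.custom_op("coreml::sdpa", mutates_args=())
def sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor:
    return torch.ops.aten.scaled_dot_product_attention.default(q, k, v, attn_mask=attn_mask)


@torch.library.register_fake("coreml::sdpa")
def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor:
    expected_shape = list(q.shape)
    expected_shape[-1] = v.shape[-1]
    return q.new_empty(expected_shape)


q, k, v = (torch.randn(1, 8, 16, 64) for _ in range(3))  # (batch, heads, seq, head_dim)
mask = torch.zeros(16, 16)                               # additive attention mask

# Eager mode: the custom op should match the stock SDPA it wraps.
torch.testing.assert_close(
    torch.ops.coreml.sdpa(q, k, v, mask),
    F.scaled_dot_product_attention(q, k, v, attn_mask=mask),
)


# Export: the fake kernel supplies output shapes, so coreml.sdpa stays in the
# traced graph as a single node instead of being decomposed.
class Attn(torch.nn.Module):
    def forward(self, q, k, v, mask):
        return torch.ops.coreml.sdpa(q, k, v, mask)


ep = torch.export.export(Attn(), (q, k, v, mask))
print(ep.graph)  # expect a call to coreml.sdpa in the printed graph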

examples/models/llama/model.py

Lines changed: 1 addition & 1 deletion
@@ -273,7 +273,7 @@ def get_example_inputs(self):
         else:
             return (
                 torch.tensor(
-                    [[1, 2, 3]], dtype=torch.long
+                    [[0 for _ in range(self.args.get("prefill_seq_length", 3))]], dtype=torch.long
                 ), # tokens, with kv cache our input token length is always just 1 token.
             )
 
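
A small sketch of what the changed example input looks like (not part of the commit); the helper name is hypothetical and a plain dict stands in for self.args.

import torch


# Hypothetical helper mirroring the changed branch of get_example_inputs.
def example_prefill_tokens(args: dict) -> torch.Tensor:
    seq_len = args.get("prefill_seq_length", 3)  # falls back to the previous 3-token input
    return torch.tensor([[0] * seq_len], dtype=torch.long)


print(example_prefill_tokens({}).shape)                           # torch.Size([1, 3])
print(example_prefill_tokens({"prefill_seq_length": 128}).shape)  # torch.Size([1, 128])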
