
Commit e1ab893

Di Xu (SWE) authored and facebook-github-bot committed
Make oss CoreML Llama support both list and tensor KV cache inputs (#9225)
Summary: Pull Request resolved: #9225

Make oss CoreML Llama support both list and tensor KV cache inputs:
- Reorder the input argument sequence to match the other code.
- Make static_seq_len for the input tokens an optional parameter.

Differential Revision: D71081340
1 parent: cb3ec19
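
The heart of the change described in the summary is that the model now rebinds its forward method at construction time based on use_cache_list, instead of branching per layer with k_caches[i] if self.use_cache_list else k_caches[i, :, :, :, :]. Below is a minimal sketch of that dispatch pattern, assuming a stripped-down module: only use_cache_list, forward, and forward_use_cache_list are names from this commit; the class name and method bodies are illustrative.

    import torch
    from typing import List, Tuple


    class KVCacheDispatch(torch.nn.Module):
        """Hypothetical stand-in showing the forward-rebinding pattern used in this commit."""

        def __init__(self, use_cache_list: bool):
            super().__init__()
            self.use_cache_list = use_cache_list
            if self.use_cache_list:
                # Rebind the instance's forward so callers (and exporters) see
                # list-typed KV cache inputs instead of stacked tensors.
                self.forward = self.forward_use_cache_list

        def forward(
            self, tokens: torch.Tensor, k_caches: torch.Tensor, v_caches: torch.Tensor
        ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            # Tensor layout: each cache is one stacked (n_layers, ...) tensor,
            # sliced per layer as k_caches[i, :, :, :, :] in the real model.
            return tokens, k_caches, v_caches

        def forward_use_cache_list(
            self,
            tokens: torch.Tensor,
            k_caches: List[torch.Tensor],
            v_caches: List[torch.Tensor],
        ):
            # List layout: one cache tensor per layer, indexed as k_caches[i].
            return tokens, k_caches, v_caches


    tensor_model = KVCacheDispatch(use_cache_list=False)
    list_model = KVCacheDispatch(use_cache_list=True)

Rebinding once in __init__ keeps each exported graph specialized to a single cache layout, rather than carrying a per-layer layout branch inside forward.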

File tree: 2 files changed (+65, -12 lines)


examples/apple/coreml/llama/llama_transformer.py

Lines changed: 62 additions & 9 deletions
@@ -42,6 +42,7 @@ class ModelArgs:
     ffn_dim_multiplier: Optional[float] = None
     norm_eps: float = 1e-5
     max_batch_size: int = 1
+    static_seq_len: int = 32
     max_seq_len: int = 128
     max_context_len: int = 2048
     moe: bool = False  # True to enable the MoE (Mixture of Experts)
@@ -398,15 +399,18 @@ def __init__(self, params: ModelArgs):
         self.input_prune_map = params.input_prune_map
         self.output_prune_map = params.output_prune_map
         self.use_cache_list = params.use_cache_list
+        if self.use_cache_list:
+            # pyre-ignore: Incompatible attribute type [8]
+            self.forward = self.forward_use_cache_list

-    def forward(
+    def forward_use_cache_list(
         self,
         tokens: torch.LongTensor,  # tokens
         input_pos: torch.LongTensor,
-        input_length: torch.LongTensor,  # input_length
         k_caches: List[torch.FloatTensor],
         v_caches: List[torch.FloatTensor],
-        attn_mask: torch.LongTensor,
+        attn_mask: torch.FloatTensor,
+        input_length: torch.LongTensor,  # input_length
         h: Optional[torch.FloatTensor] = None,  # embeddings
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         if (tokens is None) ^ (h is not None):
@@ -425,8 +429,8 @@ def forward(
                 h,
                 freqs_cos,
                 freqs_sin,
-                k_caches[i] if self.use_cache_list else k_caches[i, :, :, :, :],
-                v_caches[i] if self.use_cache_list else v_caches[i, :, :, :, :],
+                k_caches[i],
+                v_caches[i],
                 attn_mask,
             )
             k_out.append(new_k)
@@ -445,15 +449,64 @@ def forward(
             v_out = torch.stack(v_out, dim=0)
         return logits, k_out, v_out  # pyre-ignore[7]

+    def forward(
+        self,
+        tokens: torch.LongTensor,  # tokens
+        input_pos: torch.LongTensor,
+        k_caches: torch.FloatTensor,
+        v_caches: torch.FloatTensor,
+        attn_mask: torch.FloatTensor,
+        input_length: torch.LongTensor,  # input_length
+        h: Optional[torch.FloatTensor] = None,  # embeddings
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        if (tokens is None) ^ (h is not None):
+            raise ValueError(
+                "You cannot specify both tokens and h at the same time, and must specify either one"
+            )
+        if tokens is not None and h is None:
+            h = self.tok_embeddings(tokens)
+        seqlen = h.shape[1]
+        freqs_cos, freqs_sin = self.rope.get_freqs(input_pos, seqlen)

-def load_model(checkpoint_path, params_path, max_seq_length, use_cache_list):
+        k_out = []
+        v_out = []
+        for i, layer in enumerate(self.layers):
+            h, new_k, new_v = layer(
+                h,
+                freqs_cos,
+                freqs_sin,
+                k_caches[i, :, :, :, :],
+                v_caches[i, :, :, :, :],
+                attn_mask,
+            )
+            k_out.append(new_k)
+            v_out.append(new_v)
+
+        if not self.generate_full_logits:
+            # Only the last logit is used for the new generated token
+            h = h[:, input_length - 1, :].squeeze(1)
+
+        h = self.norm(h)
+
+        logits = self.output(h)
+
+        if not self.use_cache_list:
+            k_out = torch.stack(k_out, dim=0)
+            v_out = torch.stack(v_out, dim=0)
+        return logits, k_out, v_out  # pyre-ignore[7]
+
+
+def load_model(
+    checkpoint_path, params_path, max_seq_length, use_cache_list, static_seq_len=32
+):
     import json

     with open(params_path, "r") as f:
         params = json.loads(f.read())

     args = ModelArgs(
         max_seq_len=max_seq_length,
+        static_seq_len=static_seq_len,
         generate_full_logits=False,
         use_cache_list=use_cache_list,
         **params,
@@ -618,14 +671,14 @@ def get_inputs(self, tokens: List[int]):
             ).reshape(1, -1),
             # input_pos
             torch.tensor([self.input_pos], dtype=torch.long),
-            # input_length
-            torch.tensor([input_length], dtype=torch.long),
             # k_cache
             self.k_caches,
             # v_cache
             self.v_caches,
             # attn_mask
-            self.attn_mask,
+            torch.zeros(self.attn_mask.shape, dtype=torch.float16),
+            # input_length
+            torch.tensor([input_length], dtype=torch.long),
         )

     def get_inputs_and_remaining_tokens(self, tokens: List[int]):
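
For reference, the two cache layouts the reordered signature accepts look roughly like this at call time. This is a sketch under assumed shapes: n_layers, n_kv_heads, head_dim, cache_len, and seq_len are illustrative numbers, not values taken from this diff, and the model construction itself is omitted.

    import torch

    # Illustrative sizes only; the real ones come from ModelArgs / the checkpoint.
    n_layers, batch, n_kv_heads, cache_len, head_dim, seq_len = 2, 1, 4, 64, 32, 8

    # Tensor layout (use_cache_list=False): one stacked tensor per cache,
    # indexed inside forward as k_caches[i, :, :, :, :].
    k_caches = torch.zeros(
        n_layers, batch, n_kv_heads, cache_len, head_dim, dtype=torch.float16
    )
    v_caches = torch.zeros_like(k_caches)

    # List layout (use_cache_list=True): one tensor per layer, indexed as k_caches[i].
    k_cache_list = [k_caches[i] for i in range(n_layers)]
    v_cache_list = [v_caches[i] for i in range(n_layers)]

    tokens = torch.zeros(1, seq_len, dtype=torch.long)
    input_pos = torch.tensor([0], dtype=torch.long)
    attn_mask = torch.zeros(seq_len, cache_len, dtype=torch.float16)  # zero mask, as in get_inputs above
    input_length = torch.tensor([seq_len], dtype=torch.long)

    # Argument order after this commit: the caches follow input_pos; attn_mask and
    # input_length come last. The call is commented out since no model is built here.
    # logits, k_out, v_out = model(tokens, input_pos, k_caches, v_caches, attn_mask, input_length)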

examples/apple/coreml/llama/utils.py

Lines changed: 3 additions & 3 deletions
@@ -21,12 +21,12 @@ def __init__(
         self.out_split_sizes = self._get_split_sizes(
             out_features, out_target_split_size, out_max_splits
         )
-        self.in_split_sizes = self._get_split_sizes(
-            in_features, in_target_split_size, in_max_splits
-        )
         print(
             f"Splitting out_features={out_features} into {len(self.out_split_sizes)} of size {self.out_split_sizes[0]}."
         )
+        self.in_split_sizes = self._get_split_sizes(
+            in_features, in_target_split_size, in_max_splits
+        )
         print(
             f"Splitting in_features={in_features} into {len(self.in_split_sizes)} of size {self.in_split_sizes[0]}."
         )
