
Commit c962278

Commit message: up
1 parent e4ac33d commit c962278

6 files changed, +146 -40 lines changed

backends/apple/coreml/compiler/coreml_preprocess.py

Lines changed: 3 additions & 0 deletions
@@ -449,5 +449,8 @@ def preprocess(
             op_type_configs={"gather": None},
         )
         mlmodel = cto.coreml.linear_quantize_weights(mlmodel, config=config)
+
+        print("MIL program:")
+        print(mlmodel._mil_program)

         return CoreMLBackend.preprocess_model(mlmodel, model_type=model_type)

examples/models/llama/export_llama_lib.py

Lines changed: 12 additions & 0 deletions
@@ -229,6 +229,18 @@ def build_args_parser() -> argparse.ArgumentParser:
         action="store_true",
         help="Whether or not to export a model using kv cache",
     )
+    parser.add_argument(
+        "--decode_kv_cache_as_io",
+        default=False,
+        action="store_true",
+        help="Whether the decode model accepts the KV cache as IO",
+    )
+    parser.add_argument(
+        "--use_additive_kv_cache_update",
+        default=False,
+        action="store_true",
+        help="Whether to use additive KV cache updates",
+    )
     parser.add_argument(
         "--prefill_return_kv",
         default=False,

examples/models/llama/llama_transformer.py

Lines changed: 99 additions & 37 deletions
@@ -115,6 +115,8 @@ class ModelArgs:
     num_activated_experts: int = 2  # Number of experts to activate
     use_kv_cache: bool = False  # Use key/value cache
     prefill_return_kv: bool = False  # Return kv cache for prefill
+    decode_kv_cache_as_io: bool = False  # Decode uses KV caches as IO
+    use_additive_kv_cache_update: bool = False  # Additive KV cache update
     use_sdpa_with_kv_cache_op: bool = (
         False  # Use custom sdpa op that updates kv cache in-place
     )
@@ -367,6 +369,9 @@ class Attention(nn.Module):
     def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
+        self.decode_kv_cache_as_io = args.decode_kv_cache_as_io
+        self.use_additive_kv_cache_update = args.use_additive_kv_cache_update
+        self.return_kv_values = (args.prefill_return_kv or args.decode_kv_cache_as_io)
         self.n_heads = args.n_heads
         self.n_kv_heads = self.n_heads if args.n_kv_heads is None else args.n_kv_heads
         assert self.n_heads % self.n_kv_heads == 0
@@ -397,7 +402,7 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         )
         self.register_buffer("mask", causal_mask, persistent=False)

-        if self.use_kv_cache:
+        if self.use_kv_cache and not self.decode_kv_cache_as_io:
             self.kv_cache = KVCache(
                 args.max_batch_size,
                 args.max_seq_len,
@@ -421,10 +426,19 @@ def forward(
         freqs_cos: torch.Tensor,
         freqs_sin: torch.Tensor,
         input_pos: Optional[torch.Tensor] = None,
-        return_kv: bool = False,
+        k_cache: Optional[torch.Tensor] = None,
+        v_cache: Optional[torch.Tensor] = None,
+        cache_pos_mask: Optional[torch.Tensor] = None,
     ):
-        if return_kv:
-            assert self.use_kv_cache == False, "Can't return kv when use_kv_cache is True"
+        if self.decode_kv_cache_as_io:
+            assert self.use_kv_cache
+            assert k_cache is not None
+            assert v_cache is not None
+            assert self.return_kv_values
+
+        if self.use_additive_kv_cache_update:
+            assert self.decode_kv_cache_as_io
+            assert cache_pos_mask is not None

         bsz, seqlen, _ = x.shape

@@ -438,34 +452,53 @@ def forward(
         # RoPE relative positional embeddings
         q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)

-        if self.use_kv_cache:
+        if self.use_kv_cache and not self.decode_kv_cache_as_io:
             assert input_pos is not None
+            assert not self.return_kv_values
             output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
             return self.wo(output)

         q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)

-        if return_kv:
+        if self.return_kv_values:
             k_ret = k
             v_ret = v
-
-        # grouped multiquery attention: expand out keys and values
-        k = k.repeat_interleave(self.n_rep, dim=1)
-        v = v.repeat_interleave(self.n_rep, dim=1)
-
+
         assert hasattr(self, "mask")
+        if self.decode_kv_cache_as_io:
+            assert self.use_kv_cache
+            mask = self.mask[None, None, input_pos]
+            if self.use_additive_kv_cache_update:
+                assert cache_pos_mask is not None
+                assert seqlen == 1
+                k_update = cache_pos_mask * k
+                v_update = cache_pos_mask * v
+                k = k_cache + k_update
+                v = v_cache + v_update
+                assert k.shape == k_cache.shape
+                assert v.shape == v_cache.shape
+            else:
+                k = torch.ops.aten.index_put(k_cache, [None, None, input_pos, None], k)
+                v = torch.ops.aten.index_put(v_cache, [None, None, input_pos, None], v)
+        else:
+            assert not self.use_kv_cache
+            mask = self.mask[:seqlen, :seqlen]

-        mask = self.mask[:seqlen, :seqlen]
+        # grouped multiquery attention: expand out keys and values
+        if self.n_rep > 1:
+            k = k.repeat_interleave(self.n_rep, dim=1)
+            v = v.repeat_interleave(self.n_rep, dim=1)

         output = torch.ops.coreml.sdpa(q, k, v, mask)

         output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

         output = self.wo(output)

-        if return_kv:
+        if self.return_kv_values:
             return output, k_ret, v_ret
         return output
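Note on the additive branch above: it replaces the in-place index_put write into the cache with a masked add. A minimal standalone sketch of the equivalence is below; the concrete shapes and the one-hot construction of cache_pos_mask are illustrative assumptions (the commit does not spell out how the mask is built), and the add only matches an explicit write while the target cache row is still zero.

import torch

# Illustrative shapes (assumed): batch 1, 4 KV heads, cache length 8, head dim 16.
bsz, n_kv_heads, max_seq_len, head_dim = 1, 4, 8, 16
input_pos = 3  # position of the token being decoded

k_cache = torch.zeros(bsz, n_kv_heads, max_seq_len, head_dim)
k_new = torch.randn(bsz, n_kv_heads, 1, head_dim)  # seqlen == 1 during decode

# Assumed mask semantics: 1.0 at the write position, 0.0 everywhere else.
cache_pos_mask = torch.zeros(bsz, n_kv_heads, max_seq_len, head_dim)
cache_pos_mask[:, :, input_pos, :] = 1.0

# Additive update: broadcasting spreads k_new over the sequence axis,
# and the mask keeps only the row at input_pos.
k_additive = k_cache + cache_pos_mask * k_new

# Reference: an explicit write into a copy of the cache.
k_indexed = k_cache.clone()
k_indexed[:, :, input_pos : input_pos + 1, :] = k_new

assert torch.allclose(k_additive, k_indexed)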

@@ -533,6 +566,8 @@ class TransformerBlock(nn.Module):
     def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
+        self.decode_kv_cache_as_io = args.decode_kv_cache_as_io
+        self.return_kv_values = (args.prefill_return_kv or args.decode_kv_cache_as_io)
         self.n_heads = args.n_heads
         self.dim = args.dim
         self.head_dim = args.head_dim
@@ -544,14 +579,19 @@ def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

-    def forward(self, x, freqs_cos, freqs_sin, input_pos=None, return_kv=False):  # x: 1xN
-        if not return_kv:
+    def forward(self, x, freqs_cos, freqs_sin, input_pos=None, k_cache=None, v_cache=None, cache_pos_mask=None):  # x: 1xN
+        if self.decode_kv_cache_as_io:
+            assert self.use_kv_cache
+            assert k_cache is not None
+            assert v_cache is not None
+
+        if not self.return_kv_values:
             h = self.attention.forward(
-                self.attention_norm(x), freqs_cos, freqs_sin, input_pos, return_kv=False,
+                self.attention_norm(x), freqs_cos, freqs_sin, input_pos, k_cache, v_cache, cache_pos_mask,
             )
         else:
             h, k, v = self.attention.forward(
-                self.attention_norm(x), freqs_cos, freqs_sin, input_pos, return_kv=True,
+                self.attention_norm(x), freqs_cos, freqs_sin, input_pos, k_cache, v_cache, cache_pos_mask,
             )

         h = x + h
@@ -560,7 +600,7 @@ def forward(self, x, freqs_cos, freqs_sin, input_pos=None, return_kv=False): #
         else:
             out = h + self.feed_forward(self.ffn_norm(h))

-        if return_kv:
+        if self.return_kv_values:
             return out, k, v
         return out

@@ -580,49 +620,71 @@ def __init__(self, params: ModelArgs):
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
         self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
         self.use_kv_cache = params.use_kv_cache
+        self.decode_kv_cache_as_io = params.decode_kv_cache_as_io
         self.generate_full_logits = params.generate_full_logits
         self.max_seq_len = params.max_seq_len
         self.input_prune_map = params.input_prune_map
         self.output_prune_map = params.output_prune_map
-        self.prefill_return_kv = params.prefill_return_kv
+
+        # Whether model returns newly computed KV values
+        self.return_kv_values = (params.prefill_return_kv or params.decode_kv_cache_as_io)

     def forward(
         self,
         tokens: Optional[torch.LongTensor] = None,  # tokens
         input_pos: Optional[
             torch.LongTensor
         ] = None,  # Scalar tensor indicating size of window of the caches
-        h: Optional[torch.FloatTensor] = None,  # embeddings
+        k_cache: Optional[torch.FloatTensor] = None,
+        v_cache: Optional[torch.FloatTensor] = None,
+        cache_pos_mask: Optional[torch.FloatTensor] = None,
     ) -> torch.Tensor:
-        if (tokens is None) ^ (h is not None):
-            raise ValueError(
-                "You cannot specify both tokens and h at the same time, and must specify either one"
-            )
-        if tokens is not None and h is None:
-            h = self.tok_embeddings(tokens)
+        h = self.tok_embeddings(tokens)
+        if self.decode_kv_cache_as_io:
+            assert self.use_kv_cache
+            assert k_cache is not None
+            assert v_cache is not None
+
         seqlen = h.shape[1]
         freqs_cos, freqs_sin = self.rope.get_freqs(input_pos, seqlen)

-        if not self.prefill_return_kv:
+        if not self.return_kv_values:
             for layer in self.layers:
                 h = layer(
                     h,
                     freqs_cos,
                     freqs_sin,
                     input_pos,
-                    return_kv=False,
+                    k_cache,
+                    v_cache,
+                    cache_pos_mask,
                 )
         else:
             k_caches = []
             v_caches = []
-            for layer in self.layers:
-                h, k, v = layer(
-                    h,
-                    freqs_cos,
-                    freqs_sin,
-                    input_pos,
-                    return_kv=True,
-                )
+            for i, layer in enumerate(self.layers):
+                if not self.decode_kv_cache_as_io:
+                    h, k, v = layer(
+                        h,
+                        freqs_cos,
+                        freqs_sin,
+                        input_pos,
+                        k_cache,
+                        v_cache,
+                        cache_pos_mask,
+                    )
+                else:
+                    h, k, v = layer(
+                        h,
+                        freqs_cos,
+                        freqs_sin,
+                        input_pos,
+                        k_cache[i, :, :, :, :],
+                        v_cache[i, :, :, :, :],
+                        cache_pos_mask,
+                    )
                 k_caches.append(k)
                 v_caches.append(v)
             k_ret = torch.stack(k_caches, dim=0)
@@ -658,6 +720,6 @@ def forward(
             expanded_logits[:, list(self.output_prune_map.values())] = logits
             logits = expanded_logits

-        if self.prefill_return_kv:
+        if self.return_kv_values:
             return logits, k_ret, v_ret
         return logits
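For orientation, a hedged sketch of one decode step through this eager Transformer with the caches passed as IO. `model` is assumed to be an instance built with use_kv_cache, decode_kv_cache_as_io, and use_additive_kv_cache_update enabled; the concrete sizes are placeholders, and the one-hot mask convention is an assumption consistent with the Attention code above, not something this commit defines.

import torch

# Placeholder sizes for illustration only; real shapes come from ModelArgs.
n_layers, bsz, n_kv_heads, max_seq_len, head_dim = 2, 1, 4, 128, 64

tokens = torch.tensor([[1]], dtype=torch.long)   # one token per decode step
input_pos = torch.tensor([0], dtype=torch.long)  # current write position
k_cache = torch.zeros(n_layers, bsz, n_kv_heads, max_seq_len, head_dim)
v_cache = torch.zeros(n_layers, bsz, n_kv_heads, max_seq_len, head_dim)

# Assumed mask convention: one-hot along the sequence axis at input_pos.
cache_pos_mask = torch.zeros(bsz, n_kv_heads, max_seq_len, head_dim)
cache_pos_mask[:, :, 0, :] = 1.0

# Each layer i reads k_cache[i] / v_cache[i]; the extra outputs stack the
# per-step keys/values across layers: (n_layers, bsz, n_kv_heads, 1, head_dim).
logits, k_ret, v_ret = model(tokens, input_pos, k_cache, v_cache, cache_pos_mask)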

examples/models/llama/model.py

Lines changed: 23 additions & 1 deletion
@@ -55,6 +55,8 @@ def __init__(self, **kwargs):
         self.args = kwargs.get("args", None)
         self.prefill_seq_length = self.args.prefill_seq_length
         self.prefill_return_kv = self.args.prefill_return_kv
+        self.decode_kv_cache_as_io = self.args.decode_kv_cache_as_io
+        self.use_additive_kv_cache_update = self.args.use_additive_kv_cache_update

         # The example is using a dummy small model with random weights for demo purpose only.
         # Follow the instruction in https://github.com/facebookresearch/llama to download the model.
@@ -146,9 +148,16 @@ def __init__(self, **kwargs):
             output_prune_map=output_prune_map,
             enable_dynamic_shape=self.enable_dynamic_shape,
             prefill_return_kv=self.prefill_return_kv,
+            decode_kv_cache_as_io=self.decode_kv_cache_as_io,
+            use_additive_kv_cache_update=self.use_additive_kv_cache_update,
             **params,
         )

+        # Cache shapes used when decode_kv_cache_as_io / use_additive_kv_cache_update are set.
+        self._cache_shape = (model_args.n_layers, model_args.max_batch_size, model_args.n_kv_heads, model_args.max_seq_len, model_args.head_dim)
+        self._cache_pos_mask_shape = (model_args.max_batch_size, model_args.n_kv_heads, model_args.max_seq_len, model_args.head_dim)
+
         if model_args.use_scaled_rope:
             # Older models don't have use_scaled_rope configuration
             assert self.args.model not in ["llama2", "stories110m"]
@@ -288,14 +297,27 @@ def get_example_inputs_kvcache_sdpa(self):
                 torch.tensor([0], dtype=torch.long),
             )
         else:
-            return (
+            args = (
                 torch.tensor(
                     [[1]], dtype=torch.long
                 ),  # tokens, with kv cache our input token length is always just 1 token.
                 torch.tensor(
                     [0], dtype=torch.long
                 ),  # start_pos, what token of output are we on.
             )
+            if self.decode_kv_cache_as_io:
+                args = args + (
+                    # (n_layers, max_batch_size, n_kv_heads, max_seq_len, head_dim)
+                    torch.zeros(self._cache_shape, dtype=torch.float16),  # k-cache
+                    torch.zeros(self._cache_shape, dtype=torch.float16),  # v-cache
+                )
+
+            if self.use_additive_kv_cache_update:
+                args = args + (
+                    torch.zeros(self._cache_pos_mask_shape, dtype=torch.float16),
+                )
+            return args

     def _transform_for_pre_quantization(self, checkpoint, model_args):
         assert hasattr(self.args, "preq_mode"), "preq_mode must be specified"

extension/llm/export/builder.py

Lines changed: 1 addition & 0 deletions
@@ -201,6 +201,7 @@ def export(self) -> "LLMEdgeManager":
         logging.info(f"inputs: {self.example_inputs}")
         logging.info(f"kwargs: {self.example_kwarg_inputs}")
         logging.info(f"dynamic shapes: {dynamic_shape}")
+        print("EVALUATED", self.model(*self.example_inputs))
         exported_module = export_for_training(
             self.model,
             self.example_inputs,

model_export_script.sh

Lines changed: 8 additions & 2 deletions
@@ -6,13 +6,19 @@ export PARAMS=$HOME/models/stories110M/params.json
 export MODEL_OUT_DIR=$HOME/models/stories110M
 export MODEL_OUT_PREFILL=$MODEL_OUT_DIR/prefill_model.pte
 export MODEL_OUT_DECODE=$MODEL_OUT_DIR/decode_model.pte
+export MODEL_OUT_DECODE_KV_IO=$MODEL_OUT_DIR/decode_kv_io_model.pte
+export MODEL_OUT_DECODE_KV_IO_ADDITIVE=$MODEL_OUT_DIR/decode_kv_io_additive_model.pte

-python -m examples.models.llama.export_llama -c $MODEL_IN -p $PARAMS --output_name=$MODEL_OUT_PREFILL -E "4,32" --prefill_seq_length 512 --disable_dynamic_shape --coreml --coreml-ios 18 --coreml-quantize c4w --coreml-compute-units cpu_only --max_seq_length 1024 --prefill_return_kv --dtype fp16
-python -m examples.models.llama.export_llama -c $MODEL_IN -p $PARAMS --output_name=$MODEL_OUT_DECODE -E "4,32" -kv --disable_dynamic_shape --coreml --coreml-ios 18 --coreml-quantize c4w --coreml-compute-units cpu_only --max_seq_length 1024
+python -m examples.models.llama.export_llama -c $MODEL_IN -p $PARAMS --output_name=$MODEL_OUT_PREFILL -E "4,32" --prefill_seq_length 512 --disable_dynamic_shape --coreml --coreml-ios 18 --coreml-quantize c4w --coreml-compute-units cpu_and_ne --max_seq_length 1024 --prefill_return_kv --dtype fp16
+python -m examples.models.llama.export_llama -c $MODEL_IN -p $PARAMS --output_name=$MODEL_OUT_DECODE -E "4,32" -kv --disable_dynamic_shape --coreml --coreml-ios 18 --coreml-quantize c4w --coreml-compute-units cpu_and_ne --max_seq_length 1024
+python -m examples.models.llama.export_llama -c $MODEL_IN -p $PARAMS --output_name=$MODEL_OUT_DECODE_KV_IO -E "4,32" -kv --disable_dynamic_shape --coreml --coreml-ios 18 --coreml-quantize c4w --coreml-compute-units cpu_and_ne --max_seq_length 1024 --decode_kv_cache_as_io --dtype fp16
+python -m examples.models.llama.export_llama -c $MODEL_IN -p $PARAMS --output_name=$MODEL_OUT_DECODE_KV_IO_ADDITIVE -E "4,32" -kv --disable_dynamic_shape --coreml --coreml-ios 18 --coreml-quantize c4w --coreml-compute-units cpu_and_ne --max_seq_length 1024 --decode_kv_cache_as_io --use_additive_kv_cache_update --dtype fp16

 python examples/apple/coreml/scripts/extract_coreml_models.py -m $MODEL_OUT_PREFILL -o "${MODEL_OUT_DIR}/prefill"
 python examples/apple/coreml/scripts/extract_coreml_models.py -m $MODEL_OUT_DECODE -o "${MODEL_OUT_DIR}/decode"
+python examples/apple/coreml/scripts/extract_coreml_models.py -m $MODEL_OUT_DECODE_KV_IO -o "${MODEL_OUT_DIR}/decode_kv_io"
+python examples/apple/coreml/scripts/extract_coreml_models.py -m $MODEL_OUT_DECODE_KV_IO_ADDITIVE -o "${MODEL_OUT_DIR}/decode_kv_io_additive"

 python combine_coreml_models.py -m1 "${MODEL_OUT_DIR}/prefill/extracted_coreml_models/model_1/lowered_module/model.mlpackage" -m2 "${MODEL_OUT_DIR}/decode/extracted_coreml_models/model_1/lowered_module/model.mlpackage" -o "${MODEL_OUT_DIR}/combined.mlpackage"
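The commit does not show how a runner keeps the KV caches up to date between decode steps. One plausible eager-mode loop under the additive convention is sketched below, building on the single-step call shown earlier; `model` and the `sample_next_token` helper are hypothetical stand-ins, and the mask construction and fold-back rule are assumptions consistent with the Attention code above rather than an API defined by this change.

import torch

# Placeholder sizes; in practice they come from ModelArgs / _cache_shape.
n_layers, bsz, n_kv_heads, max_seq_len, head_dim = 2, 1, 4, 128, 64

k_cache = torch.zeros(n_layers, bsz, n_kv_heads, max_seq_len, head_dim)
v_cache = torch.zeros(n_layers, bsz, n_kv_heads, max_seq_len, head_dim)
token = torch.tensor([[1]], dtype=torch.long)  # first decode token (e.g. after prefill)

for pos in range(max_seq_len):
    input_pos = torch.tensor([pos], dtype=torch.long)

    # One-hot mask at the current write position (assumed convention).
    cache_pos_mask = torch.zeros(bsz, n_kv_heads, max_seq_len, head_dim)
    cache_pos_mask[:, :, pos, :] = 1.0

    logits, k_step, v_step = model(token, input_pos, k_cache, v_cache, cache_pos_mask)

    # Fold the per-step keys/values back into the caller-owned caches.
    # k_step / v_step: (n_layers, bsz, n_kv_heads, 1, head_dim); broadcasting against
    # the mask writes them into row `pos` of every layer's cache.
    k_cache = k_cache + cache_pos_mask * k_step
    v_cache = v_cache + cache_pos_mask * v_step

    token = sample_next_token(logits)  # hypothetical sampling helper, not defined here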

0 commit comments