
Commit 2414e61

updates
1 parent e7f25a3 commit 2414e61

7 files changed, +138 -20 lines changed

combine_coreml_models.py

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+import coremltools as ct
+import argparse
+
+
+if __name__ == "__main__":
+    """
+    Combines two CoreML models together
+    """
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-m1",
+        "--model1_path",
+        type=str,
+        help="Model1 path.",
+    )
+    parser.add_argument(
+        "-m2",
+        "--model2_path",
+        type=str,
+        help="Model2 path.",
+    )
+    parser.add_argument(
+        "-o",
+        "--output_path",
+        type=str,
+        help="Output path to save combined model",
+    )
+
+    args = parser.parse_args()
+    model1_path = str(args.model1_path)
+    model2_path = str(args.model2_path)
+    output_path = str(args.output_path)
+
+
+    desc = ct.utils.MultiFunctionDescriptor()
+
+    desc.add_function(
+        model1_path,
+        src_function_name="main",
+        target_function_name="model1"
+    )
+    desc.add_function(
+        model2_path,
+        src_function_name="main",
+        target_function_name="model2"
+    )
+    desc.default_function_name = "model1"
+    ct.utils.save_multifunction(desc, output_path)
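
Not part of the commit, but for context: a minimal sketch of how the saved multifunction package can be consumed, assuming coremltools 8+ (which provides MultiFunctionDescriptor and save_multifunction) and a placeholder path "combined.mlpackage".

import coremltools as ct

# Load each function of the combined package by name. "model1" is the default
# function set by the script above; "model2" is the second model.
prefill_model = ct.models.MLModel("combined.mlpackage", function_name="model1")  # path is a placeholder
decode_model = ct.models.MLModel("combined.mlpackage", function_name="model2")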

examples/apple/coreml/scripts/extract_coreml_models.py

Lines changed: 9 additions & 3 deletions
@@ -23,7 +23,7 @@
 )
 
 
-def extract_coreml_models(pte_data: bytes):
+def extract_coreml_models(pte_data: bytes, output_dir: str = "."):
     program = deserialize_pte_binary(pte_data)
     delegates: List[BackendDelegate] = sum(
         [execution_plan.delegates for execution_plan in program.execution_plan], []
@@ -45,7 +45,7 @@ def extract_coreml_models(pte_data: bytes):
                 AssertionError("The loaded Program must have inline data.")
 
             model_name: str = f"model_{model_index}"
-            model_path: Path = Path() / "extracted_coreml_models" / model_name
+            model_path: Path = Path() / output_dir / "extracted_coreml_models" / model_name
            if model_path.exists():
                 shutil.rmtree(model_path.absolute())
             os.makedirs(model_path.absolute())
@@ -72,9 +72,15 @@ def extract_coreml_models(pte_data: bytes):
         required=True,
         help="Input must be a .pte file.",
     )
+    parser.add_argument(
+        "-o",
+        "--output_dir",
+        default=".",
+        help="Output directory to save the extracted Core ML models.",
+    )
 
     args = parser.parse_args()
     model_path = str(args.model_path)
     with open(model_path, mode="rb") as pte_file:
         pte_data = pte_file.read()
-    extract_coreml_models(pte_data)
+    extract_coreml_models(pte_data, args.output_dir)
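
For orientation (not in the diff): a minimal usage sketch of the new output_dir parameter, assuming the script is importable as a module and that a "prefill_model.pte" file exists; both the import path and the .pte path are assumptions.

from examples.apple.coreml.scripts.extract_coreml_models import extract_coreml_models

with open("prefill_model.pte", mode="rb") as pte_file:  # placeholder path
    pte_data = pte_file.read()
# Writes each extracted Core ML payload under prefill/extracted_coreml_models/model_<i>/
extract_coreml_models(pte_data, output_dir="prefill")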

examples/models/llama/export_llama_lib.py

Lines changed: 4 additions & 3 deletions
@@ -22,7 +22,7 @@
 import pkg_resources
 import torch
 
-from executorch.devtools.etrecord import generate_etrecord
+
 
 from executorch.extension.llm.export.builder import DType, LLMEdgeManager
 
@@ -237,8 +237,8 @@ def build_args_parser() -> argparse.ArgumentParser:
     )
     parser.add_argument(
         "--prefill_seq_length",
-        default=False,
-        action="store_true",
+        type=int,
+        default=32,
         help="Sequence length for prefill model",
     )
     parser.add_argument(
@@ -781,6 +781,7 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
         logging.info(f"--> {partitioner.__class__.__name__}")
 
     if args.generate_etrecord:
+        from executorch.devtools.etrecord import generate_etrecord
         if not builder_exported_to_edge.edge_manager:
             raise ValueError("Unable to generate etrecord due to missing edge manager.")
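
The change to --prefill_seq_length turns it from a boolean store_true flag into an integer option, which is what the export script further below relies on. A small self-contained sketch of the corrected argparse behavior; the flag name and defaults match the diff, while the value 512 is just an example.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--prefill_seq_length",
    type=int,
    default=32,
    help="Sequence length for prefill model",
)
args = parser.parse_args(["--prefill_seq_length", "512"])
assert args.prefill_seq_length == 512  # parsed as an int, not a boolean flag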

examples/models/llama/llama_transformer.py

Lines changed: 50 additions & 11 deletions
@@ -114,6 +114,7 @@ class ModelArgs:
     num_experts: int = 8 # Number of experts
     num_activated_experts: int = 2 # Number of experts to activate
     use_kv_cache: bool = False # Use key/value cache
+    prefill_return_kv: bool = False # Return kv cache for prefill
     use_sdpa_with_kv_cache_op: bool = (
         False # Use custom sdpa op that updates kv cache in-place
     )
@@ -420,7 +421,11 @@ def forward(
         freqs_cos: torch.Tensor,
         freqs_sin: torch.Tensor,
         input_pos: Optional[torch.Tensor] = None,
+        return_kv: bool = False,
     ):
+        if return_kv:
+            assert self.use_kv_cache == False, "Can't return kv when use_kv_cache is True"
+
         bsz, seqlen, _ = x.shape
 
         # QKV
@@ -442,6 +447,10 @@
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
 
+        if return_kv:
+            k_ret = k
+            v_ret = v
+
         # grouped multiquery attention: expand out keys and values
         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
@@ -456,6 +465,8 @@
 
         output = self.wo(output)
 
+        if return_kv:
+            return output, k_ret, v_ret
         return output
 
 
@@ -533,16 +544,24 @@ def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
 
-    def forward(self, x, freqs_cos, freqs_sin, input_pos=None): # x: 1xN
-        h = self.attention.forward(
-            self.attention_norm(x), freqs_cos, freqs_sin, input_pos
-        )
+    def forward(self, x, freqs_cos, freqs_sin, input_pos=None, return_kv=False): # x: 1xN
+        if not return_kv:
+            h = self.attention.forward(
+                self.attention_norm(x), freqs_cos, freqs_sin, input_pos, return_kv=False,
+            )
+        else:
+            h, k, v = self.attention.forward(
+                self.attention_norm(x), freqs_cos, freqs_sin, input_pos, return_kv=True,
+            )
 
         h = x + h
         if hasattr(self, "block_sparse_moe"):
             out = h + self.block_sparse_moe(self.ffn_norm(h))
         else:
             out = h + self.feed_forward(self.ffn_norm(h))
+
+        if return_kv:
+            return out, k, v
         return out
 
 
@@ -565,6 +584,7 @@ def __init__(self, params: ModelArgs):
         self.max_seq_len = params.max_seq_len
         self.input_prune_map = params.input_prune_map
         self.output_prune_map = params.output_prune_map
+        self.prefill_return_kv = params.prefill_return_kv
 
     def forward(
         self,
@@ -583,13 +603,30 @@ def forward(
         seqlen = h.shape[1]
         freqs_cos, freqs_sin = self.rope.get_freqs(input_pos, seqlen)
 
-        for layer in self.layers:
-            h = layer(
-                h,
-                freqs_cos,
-                freqs_sin,
-                input_pos,
-            )
+        if not self.prefill_return_kv:
+            for layer in self.layers:
+                h = layer(
+                    h,
+                    freqs_cos,
+                    freqs_sin,
+                    input_pos,
+                    return_kv=False,
+                )
+        else:
+            k_caches = []
+            v_caches = []
+            for layer in self.layers:
+                h, k, v = layer(
+                    h,
+                    freqs_cos,
+                    freqs_sin,
+                    input_pos,
+                    return_kv=True,
+                )
+                k_caches.append(k)
+                v_caches.append(v)
+            k_ret = torch.stack(k_caches, dim=0)
+            v_ret = torch.stack(v_caches, dim=0)
 
         if not self.generate_full_logits:
             # Only the last logit is used for the new generated token
@@ -621,4 +658,6 @@ def forward(
             expanded_logits[:, list(self.output_prune_map.values())] = logits
             logits = expanded_logits
 
+        if self.prefill_return_kv:
+            return logits, k_ret, v_ret
         return logits
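
To illustrate the shapes involved (a toy sketch, not code from the diff): each attention layer returns its K/V before repeat_interleave, so they carry n_kv_heads rather than n_heads, and the Transformer stacks them across layers with torch.stack along a new leading dimension. The sizes below are assumptions for illustration only.

import torch

n_layers, bsz, n_kv_heads, seqlen, head_dim = 2, 1, 4, 8, 16  # toy sizes
k_caches, v_caches = [], []
for _ in range(n_layers):
    # Stand-ins for the per-layer k_ret / v_ret returned when return_kv=True.
    k_caches.append(torch.zeros(bsz, n_kv_heads, seqlen, head_dim))
    v_caches.append(torch.zeros(bsz, n_kv_heads, seqlen, head_dim))
k_ret = torch.stack(k_caches, dim=0)  # (n_layers, bsz, n_kv_heads, seqlen, head_dim)
v_ret = torch.stack(v_caches, dim=0)
print(k_ret.shape)  # torch.Size([2, 1, 4, 8, 16])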

examples/models/llama/model.py

Lines changed: 4 additions & 1 deletion
@@ -53,6 +53,8 @@ def __init__(self, **kwargs):
         self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
         self.max_seq_len = kwargs.get("max_seq_len", 128)
         self.args = kwargs.get("args", None)
+        self.prefill_seq_length = self.args.prefill_seq_length
+        self.prefill_return_kv = self.args.prefill_return_kv
 
         # The example is using a dummy small model with random weights for demo purpose only.
         # Follow the instruction in https://github.com/facebookresearch/llama to download the model.
@@ -143,6 +145,7 @@ def __init__(self, **kwargs):
             input_prune_map=input_prune_map,
             output_prune_map=output_prune_map,
             enable_dynamic_shape=self.enable_dynamic_shape,
+            prefill_return_kv=self.prefill_return_kv,
             **params,
         )
 
@@ -273,7 +276,7 @@ def get_example_inputs(self):
         else:
             return (
                 torch.tensor(
-                    [[0 for _ in range(self.args.get("prefill_seq_length", 3))]], dtype=torch.long
+                    [[0 for _ in range(self.prefill_seq_length)]], dtype=torch.long
                 ), # tokens, with kv cache our input token length is always just 1 token.
             )
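
For reference (not in the diff): with this change the prefill example input becomes a fixed-length batch of token ids whose length comes from --prefill_seq_length. A minimal sketch, assuming a value of 512 as used in the export script below.

import torch

prefill_seq_length = 512  # assumed value, set via --prefill_seq_length
example_tokens = torch.tensor([[0 for _ in range(prefill_seq_length)]], dtype=torch.long)
print(example_tokens.shape)  # torch.Size([1, 512])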

extension/llm/export/builder.py

Lines changed: 5 additions & 2 deletions
@@ -160,8 +160,11 @@ def _get_dynamic_shape(self) -> Any:
         dim = torch.export.Dim("token_dim", max=self.max_seq_len - 1)
 
         if not self.use_kv_cache:
-            # Only one input argument: tokens
-            self.dynamic_shapes = ({1: dim},)
+            if not self.enable_dynamic_shape:
+                self.dynamic_shapes = None
+            else:
+                # Only one input argument: tokens
+                self.dynamic_shapes = ({1: dim},)
         elif self.enable_dynamic_shape:
             # Two input arguments: tokens and input_pos but input_pos is static shape
             self.dynamic_shapes = ({1: dim}, {0: 1})
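
For context (not part of the diff): dynamic_shapes here is the value later handed to torch.export, so setting it to None when dynamic shapes are disabled exports the tokens input with a fixed length. A minimal sketch of the two cases, with max_seq_len = 1024 assumed as in the export script below; model and tokens are placeholders.

import torch

max_seq_len = 1024  # assumed value
dim = torch.export.Dim("token_dim", max=max_seq_len - 1)

enable_dynamic_shape = False  # as with --disable_dynamic_shape
if not enable_dynamic_shape:
    dynamic_shapes = None         # tokens input keeps its example (static) length
else:
    dynamic_shapes = ({1: dim},)  # dim 1 of the single tokens input is dynamic
# torch.export.export(model, (tokens,), dynamic_shapes=dynamic_shapes)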

model_export_script.sh

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+set -e
+
+export MODEL_IN=$HOME/models/stories110M/stories110M.pt
+export TOKENIZER=$HOME/models/stories110M/tokenizer.bin
+export PARAMS=$HOME/models/stories110M/params.json
+export MODEL_OUT_DIR=$HOME/models/stories110M
+export MODEL_OUT_PREFILL=$MODEL_OUT_DIR/prefill_model.pte
+export MODEL_OUT_DECODE=$MODEL_OUT_DIR/decode_model.pte
+
+python -m examples.models.llama.export_llama -c $MODEL_IN -p $PARAMS --output_name=$MODEL_OUT_PREFILL -E "4,32" --prefill_seq_length 512 --disable_dynamic_shape --coreml --coreml-ios 18 --coreml-quantize c4w --coreml-compute-units cpu_only --max_seq_length 1024 --prefill_return_kv --dtype fp16
+
+python -m examples.models.llama.export_llama -c $MODEL_IN -p $PARAMS --output_name=$MODEL_OUT_DECODE -E "4,32" -kv --disable_dynamic_shape --coreml --coreml-ios 18 --coreml-quantize c4w --coreml-compute-units cpu_only --max_seq_length 1024
+
+
+python examples/apple/coreml/scripts/extract_coreml_models.py -m $MODEL_OUT_PREFILL -o "${MODEL_OUT_DIR}/prefill"
+python examples/apple/coreml/scripts/extract_coreml_models.py -m $MODEL_OUT_DECODE -o "${MODEL_OUT_DIR}/decode"
+
+python combine_coreml_models.py -m1 "${MODEL_OUT_DIR}/prefill/extracted_coreml_models/model_1/lowered_module/model.mlpackage" -m2 "${MODEL_OUT_DIR}/decode/extracted_coreml_models/model_1/lowered_module/model.mlpackage" -o "${MODEL_OUT_DIR}/combined.mlpackage"
