Skip to content

Commit e0c4a07

Browse files
committed
init
1 parent ee7d388 commit e0c4a07

File tree

3 files changed

+686
-0
lines changed

3 files changed

+686
-0
lines changed
Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
# pyre-strict
4+
5+
import argparse
6+
import json
7+
8+
import coremltools as ct
9+
import torch
10+
from executorch.backends.apple.coreml.compiler import CoreMLBackend # pyre-ignore
11+
from executorch.backends.apple.coreml.partition import CoreMLPartitioner # pyre-ignore
12+
from executorch.examples.models.llama.source_transformation.quantize import (
13+
EmbeddingQuantHandler,
14+
)
15+
16+
from executorch.exir.backend.utils import format_delegated_graph
17+
from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig
18+
from executorch.exir.passes import MemoryPlanningPass
19+
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
20+
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
21+
from executorch.extension.export_util.utils import export_to_edge, save_pte_program
22+
23+
import sys
24+
sys.path.insert(0, "..")
25+
from llama.llama_transformer import (
26+
ModelArgs,
27+
Transformer,
28+
)
29+
30+
31+
32+
def main() -> None:
33+
parser = argparse.ArgumentParser()
34+
parser.add_argument(
35+
"-n",
36+
"--output_name",
37+
default="model.pte",
38+
help="Override the output filename of the saved pte model file.",
39+
)
40+
parser.add_argument(
41+
"-p",
42+
"--params",
43+
help="config.json",
44+
)
45+
parser.add_argument(
46+
"-c",
47+
"--checkpoint",
48+
help="checkpoint path",
49+
)
50+
parser.add_argument(
51+
"--static_seq_length",
52+
type=int,
53+
default=1, # set to 1 for decode
54+
help="length sequence to evaluate",
55+
)
56+
parser.add_argument(
57+
"--max_seq_length",
58+
type=int,
59+
default=128,
60+
help="maximum length sequence to evaluate",
61+
)
62+
parser.add_argument(
63+
"-E",
64+
"--embedding-quantize",
65+
default=None,
66+
type=str,
67+
help="type of embedding quantization, '<bitwidth>,<groupsize>', e.g., '8,1024'.",
68+
)
69+
parser.add_argument(
70+
"--coreml-quantize",
71+
default="c4w",
72+
choices=["b4w", "c4w"],
73+
help="This option is only for coreml: Use coreml quantization, e.g. b4w (for blockwise 4 bit weight), c4w (for channelwise 4 bit weight)",
74+
)
75+
76+
export_args = parser.parse_args()
77+
params_path = export_args.params
78+
checkpoint_path = export_args.checkpoint
79+
80+
# Load model args
81+
with open(params_path, "r") as f:
82+
params = json.loads(f.read())
83+
84+
args = ModelArgs(
85+
max_seq_len=export_args.max_seq_length,
86+
generate_full_logits=False,
87+
**params,
88+
)
89+
90+
with torch.device("meta"):
91+
model = Transformer(args)
92+
93+
checkpoint = torch.load(checkpoint_path, map_location="cpu", mmap=True)
94+
if "model" in checkpoint:
95+
checkpoint = checkpoint["model"]
96+
97+
missing, unexpected = model.load_state_dict(
98+
checkpoint,
99+
strict=False,
100+
assign=True,
101+
)
102+
print("Missing keys: ", missing)
103+
print("Unexpected keys: ", unexpected)
104+
105+
float_dtype = torch.float16 # dtype for model/inputs
106+
107+
assert export_args.static_seq_length < args.max_seq_len
108+
109+
cache_shape = (
110+
args.n_layers,
111+
args.max_batch_size,
112+
args.n_kv_heads,
113+
args.max_seq_len - export_args.static_seq_length,
114+
args.head_dim,
115+
)
116+
attn_mask_shape = (export_args.static_seq_length, args.max_seq_len)
117+
118+
example_inputs = (
119+
torch.tensor(
120+
[0 for _ in range(export_args.static_seq_length)], dtype=torch.long
121+
).reshape(1, -1), # tokens
122+
torch.tensor([0], dtype=torch.long), # input_pos
123+
torch.zeros(cache_shape, dtype=float_dtype), # k_cache
124+
torch.zeros(cache_shape, dtype=float_dtype), # v_cache
125+
torch.zeros(attn_mask_shape, dtype=float_dtype), # attn_mask
126+
)
127+
model.eval()
128+
model.to(float_dtype)
129+
130+
if export_args.embedding_quantize:
131+
bitwidth, group_size = export_args.embedding_quantize.split(",")
132+
if group_size == "none" or group_size == "None" or group_size == "0":
133+
group_size = None
134+
else:
135+
group_size = int(group_size)
136+
bitwidth = int(bitwidth)
137+
model = EmbeddingQuantHandler(
138+
model,
139+
bitwidth=bitwidth,
140+
group_size=group_size,
141+
packed=(bitwidth in [2, 4]),
142+
).quantized_model()
143+
144+
if export_args.coreml_quantize == "b4w":
145+
op_linear_quantizer_config = {
146+
"mode": "linear_symmetric",
147+
"dtype": "int4",
148+
"granularity": "per_block",
149+
"block_size": 32,
150+
"weight_threshold": 512,
151+
}
152+
elif export_args.coreml_quantize == "c4w":
153+
op_linear_quantizer_config = {
154+
"mode": "linear_symmetric",
155+
"dtype": "int4",
156+
"granularity": "per_channel",
157+
}
158+
else:
159+
raise ValueError("Invalid coreml_quantize arg")
160+
161+
compile_specs = CoreMLBackend.generate_compile_specs( # pyre-fixme[16]
162+
minimum_deployment_target=ct.target.iOS18,
163+
compute_precision=ct.precision(ct.precision.FLOAT16.value),
164+
compute_unit=ct.ComputeUnit.CPU_AND_NE,
165+
model_type=CoreMLBackend.MODEL_TYPE.MODEL, # pyre-fixme[16]
166+
op_linear_quantizer_config=op_linear_quantizer_config,
167+
)
168+
partitioner = CoreMLPartitioner( # pyre-fixme[16]
169+
compile_specs=compile_specs,
170+
take_over_mutable_buffer=False,
171+
skip_ops_for_coreml_delegation=[
172+
"quantized_decomposed.embedding_4bit.dtype",
173+
"aten.embedding.default",
174+
],
175+
)
176+
177+
edge_manager = export_to_edge(
178+
model,
179+
example_inputs,
180+
edge_compile_config=EdgeCompileConfig(
181+
_check_ir_validity=False,
182+
_skip_type_promotion=(float_dtype == torch.float16),
183+
_skip_dim_order=True,
184+
),
185+
)
186+
print("Edge program")
187+
print(edge_manager.exported_program())
188+
189+
edge_manager = edge_manager.to_backend(partitioner)
190+
191+
print("Delegated program")
192+
193+
print(format_delegated_graph(edge_manager.exported_program().graph_module))
194+
195+
executorch_program = edge_manager.to_executorch(
196+
ExecutorchBackendConfig(
197+
extract_delegate_segments=True,
198+
passes=[
199+
QuantFusionPass(),
200+
],
201+
memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
202+
sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
203+
)
204+
)
205+
206+
filename = save_pte_program(executorch_program, export_args.output_name)
207+
print(f"Saved Executorch program to local {filename}")
208+
209+
if __name__ == "__main__":
210+
main() # pragma: no cover
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import coremltools as ct
2+
import argparse
3+
import os
4+
import subprocess
5+
import shutil
6+
7+
if __name__ == "__main__":
8+
"""
9+
Extract mlpackage from two CoreML pte files, and combine them into one mlpackage using multifunction
10+
"""
11+
parser = argparse.ArgumentParser()
12+
parser.add_argument(
13+
"-m1",
14+
"--model1_path",
15+
type=str,
16+
help="Model1 path.",
17+
)
18+
parser.add_argument(
19+
"-m2",
20+
"--model2_path",
21+
type=str,
22+
help="Model2 path.",
23+
)
24+
parser.add_argument(
25+
"-o",
26+
"--output_dir",
27+
type=str,
28+
help="Output path to save combined model",
29+
)
30+
31+
args = parser.parse_args()
32+
model1_path = str(args.model1_path)
33+
model2_path = str(args.model2_path)
34+
output_dir = str(args.output_dir)
35+
36+
if os.path.exists(output_dir):
37+
shutil.rmtree(output_dir)
38+
os.makedirs(output_dir)
39+
40+
extract_script_path = os.path.join(os.path.dirname(__file__), "../scripts/extract_coreml_models.py")
41+
extracted_path = "extracted_coreml_models/model_1/lowered_module/model.mlpackage"
42+
43+
subprocess.run(["python", extract_script_path, "--model", model1_path])
44+
items = os.listdir("extracted_coreml_models")
45+
assert len(items) == 1, "Expected one CoreML partition"
46+
shutil.copytree(extracted_path, f"{output_dir}/model1.mlpackage")
47+
shutil.rmtree("extracted_coreml_models")
48+
49+
subprocess.run(["python", extract_script_path, "--model", model2_path])
50+
items = os.listdir("extracted_coreml_models")
51+
assert len(items) == 1, "Expected one CoreML partition"
52+
shutil.copytree(extracted_path, f"{output_dir}/model2.mlpackage")
53+
shutil.rmtree("extracted_coreml_models")
54+
55+
56+
desc = ct.utils.MultiFunctionDescriptor()
57+
58+
desc.add_function(
59+
f"{output_dir}/model1.mlpackage",
60+
src_function_name="main",
61+
target_function_name="model1"
62+
)
63+
desc.add_function(
64+
f"{output_dir}/model2.mlpackage",
65+
src_function_name="main",
66+
target_function_name="model2"
67+
)
68+
desc.default_function_name = "model1"
69+
ct.utils.save_multifunction(desc, f"{output_dir}/combined.mlpackage")

0 commit comments

Comments
 (0)