
Commit bc0ad82

Switch ANE llama model to use to_edge_transform_and_lower + torchao quantization APIs (#12665)
This switches the ANE model to use to_edge_transform_and_lower and the torchao quantization APIs. To use to_edge_transform_and_lower, we first need to land #12629; to use the torchao quantization APIs, we first need to land #12648 and #12664. This PR contains all of the changes from those PRs because it is rebased on them. I will rebase on main once those PRs land to make this easier to review.
1 parent b183830 commit bc0ad82
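For context, here is a minimal sketch of the torchao weight-only quantization flow this commit adopts. The API names are taken from the diff below; the toy model and its sizes are illustrative only:

import torch
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_
from torchao.utils import unwrap_tensor_subclass

# Toy stand-in for the llama model; the real export script quantizes the
# transformer's Embedding and Linear modules the same way.
model = torch.nn.Sequential(
    torch.nn.Embedding(128, 64),
    torch.nn.Linear(64, 64),
)

# 4-bit weight-only quantization of Linear layers, per-group with group
# size 32, mirroring the "b4w" branch of the diff.
quantize_(
    model,
    IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(32)),
)

# 4-bit per-axis quantization of Embedding layers, selected with a filter
# function, mirroring the --embedding-quantize branch of the diff.
quantize_(
    model,
    IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerAxis(0)),
    lambda m, fqn: isinstance(m, torch.nn.Embedding),
)

# Fold the quantized tensor subclasses back into plain tensors so that
# torch.export can trace the model.
model = unwrap_tensor_subclass(model)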

File tree

2 files changed: +45 -54 lines changed

.ci/scripts/test_ane_static_llama.sh

Lines changed: 1 addition & 1 deletion
@@ -28,6 +28,6 @@ pushd $EXECUTORCH_ROOT/examples/apple/coreml/llama
 # Download stories llama110m artifacts
 download_stories_model_artifacts
 
-python export.py -n model.pte -p params.json -c stories110M.pt --seq_length 32 --max_seq_length 64 --dtype fp16 --coreml-quantize c4w
+python export.py -n model.pte -p params.json -c stories110M.pt --seq_length 32 --max_seq_length 64 --dtype fp16 --coreml-quantize c4w --embedding-quantize 4,32
 
 popd
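The new --embedding-quantize flag takes the form bitwidth,group_size. Per the parsing logic in export.py below, the bitwidth must be 4 or 8 and a group size of 0 selects per-axis (per-channel) granularity, so --embedding-quantize 4,32 requests 4-bit embedding weights quantized in groups of 32.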

examples/apple/coreml/llama/export.py

Lines changed: 44 additions & 53 deletions
@@ -18,18 +18,19 @@
 from executorch.examples.apple.coreml.llama.utils import (
     replace_linear_with_split_linear,
 )
-from executorch.examples.models.llama.source_transformation.quantize import (
-    EmbeddingQuantHandler,
-)
 
+from executorch.exir import to_edge_transform_and_lower
 from executorch.exir.backend.utils import format_delegated_graph
 from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig
 from executorch.exir.passes import MemoryPlanningPass
 from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
-from executorch.exir.program._program import to_edge
 from executorch.extension.export_util.utils import save_pte_program
 
+from torchao.quantization.granularity import PerAxis, PerGroup
+from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_
+from torchao.utils import unwrap_tensor_subclass
+
 
 def main() -> None:
     parser = argparse.ArgumentParser()
@@ -115,19 +116,8 @@ def main() -> None:
         export_args.dtype
     ]  # dtype for model/inputs
 
-    if export_args.embedding_quantize:
-        bitwidth, group_size = export_args.embedding_quantize.split(",")
-        if group_size == "none" or group_size == "None" or group_size == "0":
-            group_size = None
-        else:
-            group_size = int(group_size)
-        bitwidth = int(bitwidth)
-        model = EmbeddingQuantHandler(
-            model,
-            bitwidth=bitwidth,
-            group_size=group_size,
-            packed=(bitwidth in [2, 4]),
-        ).quantized_model()
+    model.eval()
+    model.to(float_dtype)
 
     if export_args.target_split_size is not None:
         replace_linear_with_split_linear(
@@ -140,24 +130,40 @@ def main() -> None:
             in_max_splits=1,
         )
 
-    model.eval()
-    model.to(float_dtype)
+    # Quantization
+    if export_args.embedding_quantize:
+        bitwidth, group_size = export_args.embedding_quantize.split(",")
+        bitwidth = int(bitwidth)
+        assert bitwidth in [4, 8], "CoreML only supports 4-bit and 8-bit quantization"
+        group_size = int(group_size)
+        if group_size == 0:
+            granularity = PerAxis(0)
+        else:
+            granularity = PerGroup(group_size)
+        weight_dtype = getattr(torch, f"int{bitwidth}")
+
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(weight_dtype=weight_dtype, granularity=granularity),
+            lambda m, fqn: isinstance(m, torch.nn.Embedding),
+        )
 
-    op_linear_quantizer_config = None
     if export_args.coreml_quantize == "b4w":
-        op_linear_quantizer_config = {
-            "mode": "linear_symmetric",
-            "dtype": "int4",
-            "granularity": "per_block",
-            "block_size": 32,
-            "weight_threshold": 512,
-        }
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(
+                weight_dtype=torch.int4,
+                granularity=PerGroup(32),
+            ),
+        )
     elif export_args.coreml_quantize == "c4w":
-        op_linear_quantizer_config = {
-            "mode": "linear_symmetric",
-            "dtype": "int4",
-            "granularity": "per_channel",
-        }
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(
+                weight_dtype=torch.int4,
+                granularity=PerAxis(0),
+            ),
+        )
 
     compile_specs = CoreMLBackend.generate_compile_specs(  # pyre-fixme[16]
         minimum_deployment_target=ct.target.iOS18,
@@ -167,15 +173,11 @@ def main() -> None:
         }[float_dtype],
         compute_unit=ct.ComputeUnit.CPU_AND_NE,
         model_type=CoreMLBackend.MODEL_TYPE.MODEL,  # pyre-fixme[16]
-        op_linear_quantizer_config=op_linear_quantizer_config,
     )
     partitioner = CoreMLPartitioner(  # pyre-fixme[16]
         compile_specs=compile_specs,
         take_over_mutable_buffer=False,
-        skip_ops_for_coreml_delegation=[
-            "quantized_decomposed.embedding_4bit.dtype",
-            "aten.embedding.default",
-        ],
+        skip_ops_for_coreml_delegation=[],
     )
 
     input_manager = InputManager(
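Note that skip_ops_for_coreml_delegation is now empty: with embedding quantization expressed through torchao rather than the EmbeddingQuantHandler source transformation, there is presumably no longer a quantized_decomposed.embedding_4bit op that must be kept out of the CoreML delegate.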
@@ -192,33 +194,22 @@ def main() -> None:
     )
     example_inputs = input_manager.get_inputs(tokens=[0])
 
+    model = unwrap_tensor_subclass(model)
+
     ep = torch.export.export(model, example_inputs, strict=True)
     print("Exported program")
     print(ep)
 
-    edge_manager = to_edge(
+    edge_manager = to_edge_transform_and_lower(
         ep,
+        partitioner=[partitioner],
         compile_config=EdgeCompileConfig(
-            _check_ir_validity=False,
+            # TODO: fix lowering when dim_order is enabled
             _skip_dim_order=True,
-            preserve_ops=[
-                torch.ops.aten.scaled_dot_product_attention.default,
-                # preserve norm op for numerical stability
-                torch.ops.aten.linalg_vector_norm.default,
-                torch.ops.aten.reciprocal.default,
-            ],
         ),
     )
-    print("Edge program")
-    print(edge_manager.exported_program())
-
-    for node in edge_manager.exported_program().graph_module.graph.nodes:
-        print(node.name, node.target, node.args, node.kwargs)
-
-    edge_manager = edge_manager.to_backend(partitioner)
 
     print("Delegated program")
-
     print(format_delegated_graph(edge_manager.exported_program().graph_module))
 
     executorch_program = edge_manager.to_executorch(
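Taken together, the new export path reduces to roughly the following shape. This is a sketch assembled from the hunks above; argument values and the to_executorch config are elided with "..." placeholders:

model.eval()
model.to(float_dtype)

# torchao weight-only quantization of embeddings and/or linears
quantize_(model, IntxWeightOnlyConfig(...), ...)
model = unwrap_tensor_subclass(model)

# export, then transform and lower in one step; the CoreML partitioner
# performs the delegation that previously required a separate to_backend call
ep = torch.export.export(model, example_inputs, strict=True)
edge_manager = to_edge_transform_and_lower(
    ep,
    partitioner=[partitioner],
    compile_config=EdgeCompileConfig(_skip_dim_order=True),
)
executorch_program = edge_manager.to_executorch(...)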
