Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions examples/models/llama/export_llama_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,8 +353,8 @@ def build_args_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--coreml-quantize",
default=None,
choices=["b4w"],
help="This option is only for coreml: Use coreml quantization, e.g. b4w (for blockwise 4 bit weight)",
choices=["b4w", "c4w"],
help="This option is only for coreml: Use coreml quantization, e.g. b4w (for blockwise 4 bit weight), c4w (for channelwise 4 bit weight)",
)
parser.add_argument(
"--coreml-ios",
Expand All @@ -363,6 +363,13 @@ def build_args_parser() -> argparse.ArgumentParser:
choices=(15, 16, 17, 18),
help="This option is only for coreml: The minimum iOS version to deploy",
)
parser.add_argument(
"--coreml-compute-units",
type=str,
default="cpu_only",
choices=("cpu_only", "cpu_and_gpu", "cpu_and_ne", "all"),
help="This option is only for coreml: the compute units to use when running the model",
)
parser.add_argument(
"--qnn",
action="store_true",
Expand Down Expand Up @@ -703,6 +710,7 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
args.embedding_quantize,
args.pt2e_quantize,
args.coreml_quantize,
args.coreml_compute_units,
)
partitioners.append(coreml_partitioner)
modelname = f"coreml_{modelname}"
Expand Down
25 changes: 22 additions & 3 deletions extension/llm/export/partitioner_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def get_coreml_partitioner(
embedding_quantize: Optional[str] = None,
pt2e_quantize: Optional[str] = None,
coreml_quantize: Optional[str] = None,
coreml_compute_units: Optional[str] = None,
):
try:
import coremltools as ct
Expand Down Expand Up @@ -119,6 +120,19 @@ def _validate_ios_version() -> None:
17: ct.target.iOS17,
18: ct.target.iOS18,
}[ios]

if coreml_compute_units is None:
# using `ComputeUnit.ALL` can increase the model load time
# On iPhone 15 Pro, CPU decode model is over 8x faster than GPU for stories110M,
# so default to CPU_ONLY
coreml_compute_units = "cpu_only"
coreml_compute_units = {
"cpu_only": ct.ComputeUnit.CPU_ONLY,
"cpu_and_ne": ct.ComputeUnit.CPU_AND_NE,
"cpu_and_gpu": ct.ComputeUnit.CPU_AND_GPU,
"all": ct.ComputeUnit.ALL,
}[coreml_compute_units.lower()]

op_linear_quantizer_config = None
if coreml_quantize == "b4w":
op_linear_quantizer_config = {
Expand All @@ -128,17 +142,22 @@ def _validate_ios_version() -> None:
"block_size": 32,
"weight_threshold": 512,
}
elif coreml_quantize == "c4w":
op_linear_quantizer_config = {
"mode": "linear_symmetric",
"dtype": "int4",
"granularity": "per_channel",
}

compile_specs = CoreMLBackend.generate_compile_specs( # pyre-fixme[16]
minimum_deployment_target=minimum_deployment_target,
compute_precision=ct.precision(ct.precision.FLOAT16.value),
# using `ComputeUnit.ALL` can increase the model load time, default to `ComputeUnit.CPU_AND_GPU`
compute_unit=ct.ComputeUnit[ct.ComputeUnit.CPU_AND_GPU.name.upper()],
compute_unit=coreml_compute_units,
model_type=CoreMLBackend.MODEL_TYPE.MODEL, # pyre-fixme[16]
op_linear_quantizer_config=op_linear_quantizer_config,
)

take_over_mutable_buffer = minimum_deployment_target >= ct.target.iOS18

return CoreMLPartitioner( # pyre-fixme[16]
compile_specs=compile_specs,
take_over_mutable_buffer=take_over_mutable_buffer,
Expand Down