Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions examples/models/llama2/export_llama_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,13 @@ def build_args_parser() -> argparse.ArgumentParser:
choices=["b4w"],
help="This option is only for coreml: Use coreml quantization, e.g. b4w (for blockwise 4 bit weight)",
)
parser.add_argument(
"--coreml-ios",
type=int,
default=15,
choices=(15, 16, 17, 18),
help="This option is only for coreml: The minimum iOS version to deploy",
)
parser.add_argument(
"--qnn",
action="store_true",
Expand Down Expand Up @@ -533,8 +540,7 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901

if args.coreml:
coreml_partitioner = get_coreml_partitioner(
args.use_kv_cache and args.coreml_enable_state,
args.coreml_preserve_sdpa,
args.coreml_ios,
args.embedding_quantize,
args.pt2e_quantize,
args.coreml_quantize,
Expand Down Expand Up @@ -810,7 +816,8 @@ def _get_source_transforms( # noqa
transforms.append(replace_causal_mask)

elif args.coreml:
if args.coreml_preserve_sdpa:
# iOS 18 introduced fused sdpa op
if args.coreml_ios >= 18:
transforms.append(replace_sdpa_with_coreml_sdpa)
else:
transforms.append(replace_sdpa_with_simple_sdpa)
Expand Down
66 changes: 40 additions & 26 deletions extension/llm/export/partitioner_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ def get_mps_partitioner(use_kv_cache: bool = False):


def get_coreml_partitioner(
enable_state: bool = False,
preserve_sdpa: bool = True,
ios: int = 15,
embedding_quantize: Optional[str] = None,
pt2e_quantize: Optional[str] = None,
coreml_quantize: Optional[str] = None,
Expand All @@ -75,29 +74,42 @@ def get_coreml_partitioner(
"Please install the CoreML backend following https://pytorch.org/executorch/main/build-run-coreml.html"
)

minimum_deployment_target = ct.target.iOS15
# In Core ML, stateful execution is introduced in iOS 18
if enable_state:
minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS18)
# In Core ML, sdpa op is introduced in iOS 18
if preserve_sdpa:
minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS18)
# In Core ML, quantization is introduced in iOS 16
if embedding_quantize is not None or pt2e_quantize is not None:
minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS16)
# In Core ML, 8-bit activation quantization is introduced in iOS 17
if (
embedding_quantize is not None and int(embedding_quantize.split(",")[0]) == 8
) or pt2e_quantize in ("coreml_8a_c8w", "coreml_baseline_8a_c8w"):
minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS17)
# In Core ML, 4-bit weight compression is introduced in iOS 18
if (
(embedding_quantize is not None and int(embedding_quantize.split(",")[0]) == 4)
or pt2e_quantize in ("coreml_c4w", "coreml_8a_c4w", "coreml_baseline_8a_c4w")
or coreml_quantize == "b4w"
):
minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS18)
def _validate_ios_version() -> None:
assert ios in (15, 16, 17, 18)

if embedding_quantize is not None and ios < 18:
raise ValueError(
"In Core ML, per-block quantization is introduced in iOS 18"
)

use_quantization = pt2e_quantize is not None or coreml_quantize is not None
if use_quantization and ios < 16:
raise ValueError("In Core ML, quantization is introduced in iOS 16")

use_8a = (pt2e_quantize is not None and "8a" in pt2e_quantize) or (
coreml_quantize is not None and "8a" in coreml_quantize
)
if use_8a and ios < 17:
raise ValueError(
"In Core ML, 8-bit activation quantization is introduced in iOS 17"
)

use_4w = (pt2e_quantize is not None and "4w" in pt2e_quantize) or (
coreml_quantize is not None and "4w" in coreml_quantize
)
if use_4w and ios < 18:
raise ValueError(
"In Core ML, 4-bit weight compression is introduced in iOS 18"
)

_validate_ios_version()

minimum_deployment_target = {
15: ct.target.iOS15,
16: ct.target.iOS16,
17: ct.target.iOS17,
18: ct.target.iOS18,
}[ios]
op_linear_quantizer_config = None
if coreml_quantize == "b4w":
op_linear_quantizer_config = {
Expand All @@ -107,7 +119,6 @@ def get_coreml_partitioner(
"block_size": 32,
"weight_threshold": 512,
}

compile_specs = CoreMLBackend.generate_compile_specs( # pyre-fixme[16]
minimum_deployment_target=minimum_deployment_target,
compute_precision=ct.precision(ct.precision.FLOAT16.value),
Expand All @@ -116,9 +127,12 @@ def get_coreml_partitioner(
model_type=CoreMLBackend.MODEL_TYPE.MODEL, # pyre-fixme[16]
op_linear_quantizer_config=op_linear_quantizer_config,
)

take_over_mutable_buffer = minimum_deployment_target >= ct.target.iOS18

return CoreMLPartitioner( # pyre-fixme[16]
compile_specs=compile_specs,
take_over_mutable_buffer=enable_state,
take_over_mutable_buffer=take_over_mutable_buffer,
)


Expand Down
Loading