21 changes: 19 additions & 2 deletions examples/models/llama2/export_llama_lib.py
@@ -58,6 +58,7 @@
from .source_transformation.sdpa import (
replace_causal_mask,
replace_kv_cache_with_simple_kv_cache,
replace_sdpa_with_coreml_sdpa,
replace_sdpa_with_custom_op,
replace_sdpa_with_flex_sdpa,
replace_sdpa_with_simple_sdpa,
@@ -304,6 +305,11 @@ def build_args_parser() -> argparse.ArgumentParser:
action="store_true",
help="This option is only for coreml, and is only supported for MacOS15+/iOS18+",
)
parser.add_argument(
"--coreml-preserve-sdpa",
action="store_true",
help="This option is only for coreml: Preserve sdpa in torch edge program to use coreml iOS18.sdpa op",
)
parser.add_argument(
"--coreml-quantize",
default=None,
@@ -527,6 +533,7 @@ def _export_llama(modelname, args) -> LLMEdgeManager:  # noqa: C901
if args.coreml:
coreml_partitioner = get_coreml_partitioner(
args.use_kv_cache and args.coreml_enable_state,
args.coreml_preserve_sdpa,
args.embedding_quantize,
args.pt2e_quantize,
args.coreml_quantize,
@@ -742,6 +749,7 @@ def _load_llama_model(
)


# pyre-ignore: '_get_source_transforms' is too complex (14)
def _get_source_transforms(
modelname: str, dtype_override: Optional[DType], args
) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
@@ -795,10 +803,19 @@ def _get_source_transforms(
transforms.append(get_model_with_r1_r2(args.optimized_rotation_path))
transforms.append(convert_linear_to_conv2d)

elif args.coreml or args.mps:
# Currently qnn/coreml/mps doesn't support sdpa op, use the simpler decomposition
elif args.mps:
# Currently mps doesn't support sdpa op, use the simpler decomposition
# to get free perf gain.
transforms.append(replace_sdpa_with_simple_sdpa)
transforms.append(replace_causal_mask)

elif args.coreml:
# TODO: We might want to explore simple KV cache,
# since `k_out[:, :, input_pos] = k_val` decomposition is messy
# and is not easy to cleanly map to iOS18.slice_update
if args.coreml_preserve_sdpa:
transforms.append(replace_sdpa_with_coreml_sdpa)
else:
transforms.append(replace_sdpa_with_simple_sdpa)

return transforms
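
For context (not part of this diff), a minimal sketch of how the new flag is expected to surface through build_args_parser. The import path, and the spelling of any flag other than --coreml-preserve-sdpa, are assumptions based on the surrounding code:

# Hypothetical usage sketch; module path and the --coreml flag spelling are assumptions,
# and this assumes the parser has no required arguments (add checkpoint/params flags as needed).
from executorch.examples.models.llama2.export_llama_lib import build_args_parser

parser = build_args_parser()
args = parser.parse_args(["--coreml", "--coreml-preserve-sdpa"])
# argparse maps dashes to underscores, so the flag lands on args.coreml_preserve_sdpa,
# which _get_source_transforms uses to pick replace_sdpa_with_coreml_sdpa.
assert args.coreml and args.coreml_preserve_sdpa
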
75 changes: 75 additions & 0 deletions examples/models/llama2/source_transformation/sdpa.py
@@ -195,6 +195,81 @@ def replace_sdpa_with_flex_sdpa(module: torch.nn.Module):
return module


@torch.library.custom_op("coreml::sdpa", mutates_args=())
def sdpa(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: torch.Tensor
) -> torch.Tensor:
"""Same as F.scaled_dot_product_attention, but with custom op to avoid lowering during dialect conversion."""
return torch.ops.aten.scaled_dot_product_attention.default(
q, k, v, attn_mask=attn_mask
)


@torch.library.register_fake("coreml::sdpa")
def _(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: torch.Tensor
) -> torch.Tensor:
"""Fake implementation with the right output shape, which is required for torch.compile/export/fx tracing."""
expected_shape = list(q.shape)
expected_shape[-1] = v.shape[-1]
return q.new_empty(expected_shape)


class SDPACoreML(torch.nn.Module):
"""Similar to SDPASimple, but with coreml custom op to do SDPA calculation."""

def __init__(
self,
kv_cache: KVCache,
dim: int,
head_dim: int,
n_rep: int,
):
super().__init__()
self.kv_cache = kv_cache
self.dim = dim
self.head_dim = head_dim
self.n_rep = n_rep

def forward(
self,
input_pos: torch.Tensor,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
bsz,
seqlen,
mask,
):
q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
k = k.transpose(1, 2)
v = v.transpose(1, 2)

k, v = self.kv_cache.update(input_pos, k, v)
attn_mask = mask[None, None, input_pos]

if self.n_rep > 1:
k = k.repeat_interleave(self.n_rep, dim=1)
v = v.repeat_interleave(self.n_rep, dim=1)

y = torch.ops.coreml.sdpa(q, k, v, attn_mask)

return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)


def replace_sdpa_with_coreml_sdpa(module: torch.nn.Module):
for name, child in module.named_children():
if isinstance(child, SDPA):
setattr(
module,
name,
SDPACoreML(child.kv_cache, child.dim, child.head_dim, child.n_rep),
)
else:
replace_sdpa_with_coreml_sdpa(child)
return module


class KVCacheSimple(torch.nn.Module):
def __init__(
self,
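
A standalone sketch (not part of this diff) of what the custom op buys: because coreml::sdpa is registered through torch.library.custom_op with a fake kernel, torch.export keeps it as a single node instead of decomposing it, which is what the Core ML partitioner can later map to the iOS 18 SDPA op. The import path and tensor shapes below are assumptions; importing the module is only needed so the op gets registered.

# Hypothetical usage sketch; shapes and the module import path are assumptions.
import torch
from executorch.examples.models.llama2.source_transformation import sdpa as _sdpa  # registers coreml::sdpa

q = torch.randn(1, 8, 1, 64)      # (bs, n_heads, seqlen, head_dim)
k = torch.randn(1, 8, 128, 64)    # (bs, n_heads, cache_len, head_dim)
v = torch.randn(1, 8, 128, 64)
mask = torch.zeros(1, 128)        # additive attention mask, broadcast over batch/heads

out = torch.ops.coreml.sdpa(q, k, v, mask)
assert out.shape == (1, 8, 1, 64)  # fake kernel and real kernel agree on this shape

class TinyAttn(torch.nn.Module):
    def forward(self, q, k, v, mask):
        return torch.ops.coreml.sdpa(q, k, v, mask)

# The exported graph keeps a single coreml.sdpa call_function node rather than
# a decomposed matmul/softmax sequence.
ep = torch.export.export(TinyAttn(), (q, k, v, mask))
print(ep.graph_module.code)
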
4 changes: 4 additions & 0 deletions extension/llm/export/partitioner_lib.py
@@ -57,6 +57,7 @@ def get_mps_partitioner(use_kv_cache: bool = False):

def get_coreml_partitioner(
enable_state: bool = False,
preserve_sdpa: bool = True,
embedding_quantize: Optional[str] = None,
pt2e_quantize: Optional[str] = None,
coreml_quantize: Optional[str] = None,
@@ -78,6 +79,9 @@ def get_coreml_partitioner(
# In Core ML, stateful execution is introduced in iOS 18
if enable_state:
minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS18)
# In Core ML, sdpa op is introduced in iOS 18
if preserve_sdpa:
minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS18)
# In Core ML, quantization is introduced in iOS 16
if embedding_quantize is not None or pt2e_quantize is not None:
minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS16)
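
For completeness (not part of this diff), a hedged sketch of calling get_coreml_partitioner with the new argument; the import path is an assumption. Note that the function defaults preserve_sdpa to True, while the llama export path passes args.coreml_preserve_sdpa, which defaults to False via the new store_true flag.

# Hypothetical call; keyword usage mirrors the signature added in this diff,
# but the import path is an assumption.
from executorch.extension.llm.export.partitioner_lib import get_coreml_partitioner

partitioner = get_coreml_partitioner(
    enable_state=False,
    preserve_sdpa=True,   # bumps minimum_deployment_target up to iOS 18
    embedding_quantize=None,
    pt2e_quantize=None,
    coreml_quantize=None,
)
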