Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions examples/models/llama2/export_llama_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@
from .source_transformation.rope import materialze_broadcast_of_rope_freq_cis
from .source_transformation.sdpa import (
replace_causal_mask,
replace_kv_cache_with_coreml_kv_cache,
replace_kv_cache_with_simple_kv_cache,
replace_sdpa_with_coreml_sdpa,
replace_sdpa_with_custom_op,
replace_sdpa_with_flex_sdpa,
replace_sdpa_with_simple_sdpa,
Expand Down Expand Up @@ -304,6 +306,11 @@ def build_args_parser() -> argparse.ArgumentParser:
action="store_true",
help="This option is only for coreml, and is only supported for MacOS15+/iOS18+",
)
parser.add_argument(
"--coreml-preserve-sdpa",
action="store_true",
help="This option is only for coreml: Preserve sdpa in torch edge program to use coreml iOS18.sdpa op",
)
parser.add_argument(
"--coreml-quantize",
default=None,
Expand Down Expand Up @@ -527,6 +534,7 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901
if args.coreml:
coreml_partitioner = get_coreml_partitioner(
args.use_kv_cache and args.coreml_enable_state,
args.coreml_preserve_sdpa,
args.embedding_quantize,
args.pt2e_quantize,
args.coreml_quantize,
Expand Down Expand Up @@ -742,7 +750,7 @@ def _load_llama_model(
)


def _get_source_transforms(
def _get_source_transforms( # noqa
modelname: str, dtype_override: Optional[DType], args
) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
transforms = []
Expand Down Expand Up @@ -795,10 +803,17 @@ def _get_source_transforms(
transforms.append(get_model_with_r1_r2(args.optimized_rotation_path))
transforms.append(convert_linear_to_conv2d)

elif args.coreml or args.mps:
# Currently qnn/coreml/mps doesn't support sdpa op, use the simpler decomposition
elif args.mps:
# Currently mps doesn't support sdpa op, use the simpler decomposition
# to get free perf gain.
transforms.append(replace_sdpa_with_simple_sdpa)
transforms.append(replace_causal_mask)

elif args.coreml:
if args.coreml_preserve_sdpa:
transforms.append(replace_sdpa_with_coreml_sdpa)
else:
transforms.append(replace_sdpa_with_simple_sdpa)
transforms.append(replace_kv_cache_with_coreml_kv_cache)

return transforms
130 changes: 130 additions & 0 deletions examples/models/llama2/source_transformation/sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,136 @@ def replace_sdpa_with_flex_sdpa(module: torch.nn.Module):
return module


@torch.library.custom_op("coreml::sdpa", mutates_args=())
def sdpa(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: torch.Tensor
) -> torch.Tensor:
"""Same as F.scaled_dot_product_attention, but with custom op to avoid lowering during dialect conversion."""
return torch.ops.aten.scaled_dot_product_attention.default(
q, k, v, attn_mask=attn_mask
)


@torch.library.register_fake("coreml::sdpa")
def _(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: torch.Tensor
) -> torch.Tensor:
    """Fake (shape-only) implementation of ``coreml::sdpa``.

    Needed for torch.compile/export/fx tracing. No attention is computed:
    the result is an uninitialized tensor shaped like ``q`` except that its
    last dimension is taken from ``v``.
    """
    # Leading dims come from the query, the final dim from the value tensor.
    output_shape = (*q.shape[:-1], v.shape[-1])
    return q.new_empty(output_shape)


class SDPACoreML(torch.nn.Module):
    """Attention module that computes SDPA through ``torch.ops.coreml.sdpa``.

    Similar to SDPASimple, except that the attention itself is a single call
    to the coreml custom op.
    """

    def __init__(
        self,
        kv_cache: KVCache,
        dim: int,
        head_dim: int,
        n_rep: int,
    ):
        """Store the KV cache and the attention dimensions.

        Args:
            kv_cache: Cache module whose ``update`` is called in ``forward``.
            dim: Size of the last dimension of the tensor returned by ``forward``.
            head_dim: Stored on the module; not read by ``forward``.
            n_rep: Repeat factor applied to keys/values along dim 1 when > 1.
        """
        super().__init__()
        self.n_rep = n_rep
        self.head_dim = head_dim
        self.dim = dim
        self.kv_cache = kv_cache

    def forward(
        self,
        input_pos: torch.Tensor,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        bsz,
        seqlen,
        mask,
    ):
        """Update the KV cache and run attention via the coreml custom op.

        Returns:
            The attention output viewed as ``(bsz, seqlen, self.dim)``.
        """
        # Swap dims 1 and 2 -> (bs, n_local_heads, seqlen, head_dim)
        q, k, v = (tensor.transpose(1, 2) for tensor in (q, k, v))

        cached_k, cached_v = self.kv_cache.update(input_pos, k, v)
        # Select the mask rows for the current positions, with two leading
        # singleton dims added.
        attn_mask = mask[None, None, input_pos]

        if self.n_rep > 1:
            # Repeat each key/value head n_rep times along the head dim.
            cached_k = cached_k.repeat_interleave(self.n_rep, dim=1)
            cached_v = cached_v.repeat_interleave(self.n_rep, dim=1)

        attn_out = torch.ops.coreml.sdpa(q, cached_k, cached_v, attn_mask)

        attn_out = attn_out.transpose(1, 2).contiguous()
        return attn_out.view(bsz, seqlen, self.dim)


def replace_sdpa_with_coreml_sdpa(module: torch.nn.Module):
    """Recursively swap every ``SDPA`` submodule for an ``SDPACoreML``.

    Each replacement is built from the original's ``kv_cache``, ``dim``,
    ``head_dim`` and ``n_rep``. Children that are not ``SDPA`` are searched
    recursively. ``module`` is modified in place and returned.
    """
    for child_name, submodule in module.named_children():
        if not isinstance(submodule, SDPA):
            # Not an SDPA itself: keep looking further down the tree.
            replace_sdpa_with_coreml_sdpa(submodule)
            continue
        coreml_sdpa = SDPACoreML(
            submodule.kv_cache, submodule.dim, submodule.head_dim, submodule.n_rep
        )
        setattr(module, child_name, coreml_sdpa)
    return module


class KVCacheCoreML(torch.nn.Module):
    """
    KV cache that writes through torch.ops.aten.index_put_ rather than
    k_out[:, :, input_pos] = k_val, because index_put_ can directly
    translate to CoreML iOS18.slice_update
    """

    def __init__(
        self,
        max_batch_size: int,
        max_seq_length: int,
        n_heads: int,
        head_dim: int,
        dtype=torch.float32,
    ):
        """Allocate zero-filled CPU key and value caches.

        Both buffers have shape
        ``(max_batch_size, n_heads, max_seq_length, head_dim)``.
        """
        super().__init__()
        self.max_batch_size = max_batch_size
        self.max_seq_length = max_seq_length
        self.n_heads = n_heads
        self.head_dim = head_dim

        shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        for buffer_name in ("k_cache", "v_cache"):
            self.register_buffer(
                buffer_name, torch.zeros(shape, dtype=dtype, device="cpu")
            )

    def update(
        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Write ``k_val``/``v_val`` into the caches at ``input_pos``.

        The write is in place along dim 2 of each cache.

        Returns:
            The results of the two ``index_put_`` calls on ``k_cache`` and
            ``v_cache``, in that order.
        """
        # None leaves dims 0 and 1 unindexed; input_pos selects along dim 2.
        indices = [None, None, input_pos]
        k_out = torch.ops.aten.index_put_(self.k_cache, indices, k_val)
        v_out = torch.ops.aten.index_put_(self.v_cache, indices, v_val)
        return k_out, v_out


def replace_kv_cache_with_coreml_kv_cache(module: torch.nn.Module):
    """Recursively swap every ``KVCache`` submodule for a ``KVCacheCoreML``.

    Each replacement is built from the original's ``max_batch_size``,
    ``max_seq_length``, ``n_heads`` and ``head_dim``, plus the dtype of its
    ``k_cache``. Children that are not ``KVCache`` are searched recursively.
    ``module`` is modified in place and returned.
    """
    for child_name, submodule in module.named_children():
        if not isinstance(submodule, KVCache):
            # Not a KVCache itself: keep looking further down the tree.
            replace_kv_cache_with_coreml_kv_cache(submodule)
            continue
        coreml_cache = KVCacheCoreML(
            submodule.max_batch_size,
            submodule.max_seq_length,
            submodule.n_heads,
            submodule.head_dim,
            submodule.k_cache.dtype,
        )
        setattr(module, child_name, coreml_cache)
    return module


class KVCacheSimple(torch.nn.Module):
def __init__(
self,
Expand Down
4 changes: 4 additions & 0 deletions extension/llm/export/partitioner_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def get_mps_partitioner(use_kv_cache: bool = False):

def get_coreml_partitioner(
enable_state: bool = False,
preserve_sdpa: bool = True,
embedding_quantize: Optional[str] = None,
pt2e_quantize: Optional[str] = None,
coreml_quantize: Optional[str] = None,
Expand All @@ -78,6 +79,9 @@ def get_coreml_partitioner(
# In Core ML, stateful execution is introduced in iOS 18
if enable_state:
minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS18)
# In Core ML, sdpa op is introduced in iOS 18
if preserve_sdpa:
minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS18)
# In Core ML, quantization is introduced in iOS 16
if embedding_quantize is not None or pt2e_quantize is not None:
minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS16)
Expand Down
Loading