64 changes: 24 additions & 40 deletions examples/apple/coreml/llama/export.py
@@ -3,7 +3,6 @@
# pyre-strict

import argparse
import json

import sys

@@ -23,7 +22,7 @@
from executorch.extension.export_util.utils import export_to_edge, save_pte_program

sys.path.insert(0, ".")
from llama_transformer import InputManager, ModelArgs, Transformer
from llama_transformer import InputManager, load_model


class SplitLinearModule(torch.nn.Module):
@@ -141,42 +140,23 @@ def main() -> None:
default=8,
help="Maximum number of splits to divide linear layers",
)
parser.add_argument(
"--dtype",
type=str,
default="fp16",
)

export_args = parser.parse_args()
params_path = export_args.params
checkpoint_path = export_args.checkpoint

# Load model args
with open(params_path, "r") as f:
params = json.loads(f.read())

args = ModelArgs(
max_seq_len=export_args.max_seq_length,
generate_full_logits=False,
model = load_model(
export_args.checkpoint,
export_args.params,
max_seq_length=export_args.max_seq_length,
use_cache_list=export_args.use_cache_list,
**params,
)

with torch.device("meta"):
model = Transformer(args)

checkpoint = torch.load(
checkpoint_path, map_location="cpu", mmap=True, weights_only=True
)
if "model" in checkpoint:
checkpoint = checkpoint["model"]

missing, unexpected = model.load_state_dict(
checkpoint,
strict=False,
assign=True,
)
print("Missing keys: ", missing)
print("Unexpected keys: ", unexpected)

float_dtype = torch.float16 # dtype for model/inputs
model.eval()
model.to(float_dtype)
float_dtype = {"fp16": torch.float16, "fp32": torch.float32}[
export_args.dtype
] # dtype for model/inputs

if export_args.embedding_quantize:
bitwidth, group_size = export_args.embedding_quantize.split(",")
@@ -197,7 +177,8 @@ def main() -> None:
model, export_args.target_split_size, export_args.max_splits
)

model = model.to(float_dtype)
model.eval()
model.to(float_dtype)

op_linear_quantizer_config = None
if export_args.coreml_quantize == "b4w":
@@ -217,7 +198,10 @@ def main() -> None:

compile_specs = CoreMLBackend.generate_compile_specs( # pyre-fixme[16]
minimum_deployment_target=ct.target.iOS18,
compute_precision=ct.precision(ct.precision.FLOAT16.value),
compute_precision={
torch.float16: ct.precision.FLOAT16,
torch.float32: ct.precision.FLOAT32,
}[float_dtype],
compute_unit=ct.ComputeUnit.CPU_AND_NE,
model_type=CoreMLBackend.MODEL_TYPE.MODEL, # pyre-fixme[16]
op_linear_quantizer_config=op_linear_quantizer_config,
@@ -232,11 +216,11 @@
)

input_manager = InputManager(
n_layers=args.n_layers,
max_batch_size=args.max_batch_size,
n_kv_heads=args.n_kv_heads,
max_seq_length=args.max_seq_len,
head_dim=args.head_dim,
n_layers=model.params.n_layers,
max_batch_size=model.params.max_batch_size,
n_kv_heads=model.params.n_kv_heads,
max_seq_length=model.params.max_seq_len,
head_dim=model.params.head_dim,
use_cache_list=export_args.use_cache_list,
seq_length=export_args.seq_length,
dtype=float_dtype,
88 changes: 83 additions & 5 deletions examples/apple/coreml/llama/llama_transformer.py
@@ -13,8 +13,6 @@
import torch
import torch.nn.functional as F

from executorch.examples.models.llama.llama_transformer import RMSNorm

from executorch.examples.models.llama.rope import (
hf_apply_rotary_emb,
hf_precompute_freqs_cis,
@@ -121,6 +119,56 @@ def __post_init__(self):
self.head_dim = self.dim // self.n_heads


class RMSNorm(torch.nn.Module):
Contributor:
Does CoreML support RMSNorm op? It will be a lot easier if they do.

Collaborator @YifanShenSZ (Feb 24, 2025):
Something existing in Core ML is the translation for torch.norm, which uses the Core ML fused reduce_l2_norm kernel. That is to say, we may compute RMS norm by something like x / torch.norm(x).

Contributor Author @metascroy (Feb 25, 2025):
I think this is a slightly different op when eps > 0, although I'm not sure how much it matters in practice. RMSNorm would actually be something like x / torch.norm([x/sqrt(n), sqrt(eps)]).
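
A quick numeric check of that identity (the helper functions below are ad hoc, not from this PR):

```
import torch

def rms_norm_reference(x, eps=1e-6):
    # Standard (unweighted) RMSNorm: x * rsqrt(mean(x^2) + eps)
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

def rms_norm_via_l2_norm(x, eps=1e-6):
    # x / ||concat(x / sqrt(n), sqrt(eps))||_2 -- the form suggested above
    n = x.shape[-1]
    padded = torch.cat([x / n**0.5, torch.full_like(x[..., :1], eps**0.5)], dim=-1)
    return x / padded.norm(dim=-1, keepdim=True)

x = torch.randn(2, 8, 64, dtype=torch.float64)
assert torch.allclose(rms_norm_reference(x), rms_norm_via_l2_norm(x))
```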

Contributor Author:
I'll update to use norm, and then maybe we can work on a longer-term solution of supporting RMSNorm in CoreML, @YifanShenSZ?

Contributor Author @metascroy (Feb 24, 2025):
In sync with the CoreML team, we might try using https://pytorch.org/docs/stable/generated/torch.nn.functional.rms_norm.html and then write a CoreML op definition for it here: https://github.com/apple/coremltools/blob/main/coremltools/converters/mil/frontend/torch/ops.py
@YifanShenSZ mentioned they have a fused norm op that could be used.

Collaborator @YifanShenSZ (Feb 24, 2025):
Yes, there are two possibilities:

  1. (Simpler) As mentioned above, we already have a torch.norm translation using the Core ML fused reduce_l2_norm, so we may have an RMS norm torch source like x / torch.norm(x).
  2. (Better) I think we can further fuse to Core ML l2_norm, which may directly correspond to RMS norm (with some restrictions, though, e.g. x must have rank >= 3). We will need to add the translation function in CoreMLTools.

def __init__(self, dim: int, eps: float = 1e-6):
"""
Initialize the RMSNorm normalization layer.

Args:
dim (int): The dimension of the input tensor.
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

Attributes:
eps (float): A small value added to the denominator for numerical stability.
weight (nn.Parameter): Learnable scaling parameter.

"""
super().__init__()
self.dim = dim
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))

def _norm(self, x):
"""
Apply the RMSNorm normalization to the input tensor.

Args:
x (torch.Tensor): The input tensor.

Returns:
torch.Tensor: The normalized tensor.

"""
x_max, _ = torch.abs(x).max(-1, keepdim=True)
x = x / x_max # This makes the op more stable in FP16
Contributor:
I'll leave review for somebody who is better at math, but I'll just note that it is not at all obvious to me that this does not change the result of the operation.

Contributor Author:
It shouldn't, because both the numerator and denominator are divided by the same thing (x_max). Because the denominator is under a square root, we divide by x_max**2 there.
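
To make the equivalence explicit: writing m = max_i |x_i| and y = x / m, the rescaled expression reduces to the original,

$$
y \cdot \operatorname{rsqrt}\!\left(\operatorname{mean}(y^2) + \frac{\epsilon}{m^2}\right)
= \frac{x}{m} \cdot \frac{m}{\sqrt{\operatorname{mean}(x^2) + \epsilon}}
= x \cdot \operatorname{rsqrt}\!\left(\operatorname{mean}(x^2) + \epsilon\right),
$$

so the rewrite is algebraically exact; any difference comes only from floating-point rounding and overflow behavior.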

Contributor:
I think the function produces equal results before and after. But would it be a concern that, for very small values of x_max, eps = self.eps / (x_max * x_max) could overflow? Should we dynamically use torch.finfo(x.dtype).eps for different dtypes?

Contributor:
I don't fully follow why this rewrite should hold better for fp16. If you are normalizing by the max value, then I am presuming that rsqrt is suffering from precision loss in fp16? It is not at all clear.
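
A small sketch of the fp16 motivation (illustration only, assuming a PyTorch build where these ops support fp16 on your device): for large activations, (x * x).mean(...) overflows fp16 before rsqrt is even applied, while the rescaled form keeps the intermediates near 1.

```
import torch

# 300^2 = 90000 exceeds the fp16 max (~65504), so x*x overflows to inf.
x = torch.full((1, 8), 300.0, dtype=torch.float16)

# Naive RMSNorm: the squared intermediate is inf, so the output collapses to 0.
naive = x * torch.rsqrt((x * x).mean(-1, keepdim=True) + 1e-6)

# Rescaled form from this PR: divide by max|x| first and fold it into eps.
x_max, _ = torch.abs(x).max(-1, keepdim=True)
y = x / x_max
rescaled = y * torch.rsqrt((y * y).mean(-1, keepdim=True) + 1e-6 / (x_max * x_max))

# fp32 reference.
x32 = x.float()
ref = x32 * torch.rsqrt((x32 * x32).mean(-1, keepdim=True) + 1e-6)

print(naive)     # zeros (overflow)
print(rescaled)  # ~1.0, matching the reference
print(ref)       # ~1.0
```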

eps = self.eps / (x_max * x_max)
return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + eps)

def forward(self, x):
"""
Forward pass through the RMSNorm layer.

Args:
x (torch.Tensor): The input tensor.

Returns:
torch.Tensor: The output tensor after applying RMSNorm.

"""
output = self._norm(x)
return output * self.weight


class Rope(torch.nn.Module):
def __init__(self, params: ModelArgs):
super().__init__()
@@ -305,11 +353,8 @@ def forward(
v = v.repeat_interleave(self.n_rep, dim=1)

output = torch.ops.coreml.sdpa(q, k, v, attn_mask)

output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

output = self.wo(output)

return output, new_k, new_v


@@ -413,6 +458,39 @@ def forward(
return logits, k_out, v_out


def load_model(checkpoint_path, params_path, max_seq_length, use_cache_list):
import json

with open(params_path, "r") as f:
params = json.loads(f.read())

args = ModelArgs(
max_seq_len=max_seq_length,
generate_full_logits=False,
use_cache_list=use_cache_list,
**params,
)

with torch.device("meta"):
model = Transformer(args)

checkpoint = torch.load(
checkpoint_path, map_location="cpu", mmap=True, weights_only=True
)
if "model" in checkpoint:
checkpoint = checkpoint["model"]

missing, unexpected = model.load_state_dict(
checkpoint,
strict=False,
assign=True,
)
print("Missing keys: ", missing)
print("Unexpected keys: ", unexpected)

return model


class InputManager:
def __init__(
self,
8 changes: 7 additions & 1 deletion examples/apple/coreml/llama/readme.md
@@ -4,7 +4,7 @@ This directory contains ANE-friendly Llama models.

Export model with:
```
python export.py -n /path/to/output/model.pte -p /path/to/params.json -c /path/to/model.pth --seq_length 64 --max_seq_length 1024 --coreml-quantize c4w
python export.py -n /path/to/output/model.pte -p /path/to/params.json -c /path/to/model.pth --seq_length 64 --max_seq_length 1024 --coreml-quantize c4w --dtype fp16
```

(Note the script should be run from the executorch/examples/apple/coreml/llama directory.)
@@ -17,6 +17,12 @@ Run model with:
python run.py -m /path/to/model.pte -t /path/to/tokenizer.model --prompt "Once upon a time,"
```

The runner can also be used to run an eager model to compare with CoreML numerics (--use_eager). In this case, you must specify the following (an example invocation is shown after this list):
* --checkpoint
* --dtype
* --max_seq_length
* --seq_length
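
For example, a hypothetical invocation (paths are placeholders; adjust the lengths to match your export):
```
python run.py -m /path/to/model.pte -t /path/to/tokenizer.model --prompt "Once upon a time," --use_eager -c /path/to/model.pth -p /path/to/params.json --dtype fp16 --max_seq_length 1024 --seq_length 64
```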

(Note the script should be run from the executorch/examples/apple/coreml/llama directory.)


104 changes: 85 additions & 19 deletions examples/apple/coreml/llama/run.py
@@ -11,7 +11,7 @@
sys.path.insert(0, ".")
from executorch.examples.models.llama.runner.generation import next_token
from executorch.examples.models.llama.tokenizer import tiktoken
from llama_transformer import InputManager
from llama_transformer import InputManager, load_model


class Tokenizer:
@@ -71,28 +71,90 @@ def main() -> None:
type=float,
default=0.9,
)
parser.add_argument(
"--use_eager",
action="store_true",
)
parser.add_argument(
"-p",
"--params",
type=str,
default=None,
)
parser.add_argument(
"-c",
"--checkpoint",
type=str,
default=None,
)
parser.add_argument("--dtype", type=str, choices=["fp16", "fp32"], default=None)
parser.add_argument(
"--seq_length",
type=int,
default=None,
)
parser.add_argument(
"--max_seq_length",
type=int,
default=None,
)
parser.add_argument(
"--cache_size",
type=int,
default=None,
)

args = parser.parse_args()

tokenizer = Tokenizer(args.tokenizer)

runtime = Runtime.get()
program = runtime.load_program(args.model)
method = program.load_method("forward")

metadata = method.metadata
print("Method metadata: ", metadata, "\n\n")

assert (
metadata.num_inputs() == 6
), "Do not export with --use_cache_list for use in pybindings"
# k_cache input
n_layers, max_batch_size, n_kv_heads, cache_size, head_dim = (
metadata.input_tensor_meta(3).sizes()
)

# mask input
seq_length, max_seq_length = metadata.input_tensor_meta(5).sizes()
if args.use_eager:
assert args.params is not None
assert args.checkpoint is not None
assert args.dtype is not None
assert args.max_seq_length is not None
assert args.seq_length is not None

max_seq_length = args.max_seq_length
seq_length = args.seq_length
model = load_model(
args.checkpoint,
args.params,
max_seq_length=max_seq_length,
use_cache_list=False,
)
n_layers = model.params.n_layers
max_batch_size = model.params.max_batch_size
n_kv_heads = model.params.n_kv_heads
head_dim = model.params.head_dim
cache_size = args.cache_size

float_dtype = {"fp16": torch.float16, "fp32": torch.float32}[
args.dtype
] # dtype for model/inputs
model.eval()
model.to(float_dtype)
else:
program = runtime.load_program(args.model)
method = program.load_method("forward")

metadata = method.metadata
print("Method metadata: ", metadata, "\n\n")

assert (
metadata.num_inputs() == 6
), "Do not export with --use_cache_list for use in pybindings"
# k_cache input
n_layers, max_batch_size, n_kv_heads, cache_size, head_dim = (
metadata.input_tensor_meta(3).sizes()
)
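# Tensor dtype codes follow the PyTorch/ExecuTorch ScalarType enum: 5 = Half (fp16), 6 = Float (fp32).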
float_dtype = {5: torch.float16, 6: torch.float32}[
metadata.input_tensor_meta(3).dtype()
]

# mask input
seq_length, max_seq_length = metadata.input_tensor_meta(5).sizes()

input_manager = InputManager(
n_layers=n_layers,
@@ -102,7 +164,7 @@ def main() -> None:
head_dim=head_dim,
use_cache_list=False,
seq_length=seq_length,
dtype=torch.float16,
dtype=float_dtype,
minus_infinity=-30000.0,
cache_size=cache_size,
)
@@ -117,7 +179,11 @@ def main() -> None:
tokens
)
processed_tokens = len(tokens) - len(remaining_tokens)
logits, k, v = method.execute(inputs)
if args.use_eager:
logits, k, v = model(*inputs)
else:
logits, k, v = method.execute(inputs)

input_manager.update(
input_length=processed_tokens, new_k_caches=k, new_v_caches=v
)