diff --git a/examples/apple/coreml/llama/export.py b/examples/apple/coreml/llama/export.py index 58bc0859c79..cc9eb9f02ee 100644 --- a/examples/apple/coreml/llama/export.py +++ b/examples/apple/coreml/llama/export.py @@ -3,7 +3,6 @@ # pyre-strict import argparse -import json import sys @@ -20,10 +19,11 @@ from executorch.exir.passes import MemoryPlanningPass from executorch.exir.passes.quant_fusion_pass import QuantFusionPass from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass -from executorch.extension.export_util.utils import export_to_edge, save_pte_program +from executorch.exir.program._program import to_edge_with_preserved_ops +from executorch.extension.export_util.utils import save_pte_program sys.path.insert(0, ".") -from llama_transformer import InputManager, ModelArgs, Transformer +from llama_transformer import InputManager, load_model class SplitLinearModule(torch.nn.Module): @@ -141,42 +141,23 @@ def main() -> None: default=8, help="Maximum number of splits to divide linear layers", ) + parser.add_argument( + "--dtype", + type=str, + default="fp16", + ) export_args = parser.parse_args() - params_path = export_args.params - checkpoint_path = export_args.checkpoint - - # Load model args - with open(params_path, "r") as f: - params = json.loads(f.read()) - - args = ModelArgs( - max_seq_len=export_args.max_seq_length, - generate_full_logits=False, + model = load_model( + export_args.checkpoint, + export_args.params, + max_seq_length=export_args.max_seq_length, use_cache_list=export_args.use_cache_list, - **params, - ) - - with torch.device("meta"): - model = Transformer(args) - - checkpoint = torch.load( - checkpoint_path, map_location="cpu", mmap=True, weights_only=True ) - if "model" in checkpoint: - checkpoint = checkpoint["model"] - missing, unexpected = model.load_state_dict( - checkpoint, - strict=False, - assign=True, - ) - print("Missing keys: ", missing) - print("Unexpected keys: ", unexpected) - - float_dtype = torch.float16 # dtype for model/inputs - model.eval() - model.to(float_dtype) + float_dtype = {"fp16": torch.float16, "fp32": torch.float32}[ + export_args.dtype + ] # dtype for model/inputs if export_args.embedding_quantize: bitwidth, group_size = export_args.embedding_quantize.split(",") @@ -197,7 +178,8 @@ def main() -> None: model, export_args.target_split_size, export_args.max_splits ) - model = model.to(float_dtype) + model.eval() + model.to(float_dtype) op_linear_quantizer_config = None if export_args.coreml_quantize == "b4w": @@ -217,7 +199,10 @@ def main() -> None: compile_specs = CoreMLBackend.generate_compile_specs( # pyre-fixme[16] minimum_deployment_target=ct.target.iOS18, - compute_precision=ct.precision(ct.precision.FLOAT16.value), + compute_precision={ + torch.float16: ct.precision.FLOAT16, + torch.float32: ct.precision.FLOAT32, + }[float_dtype], compute_unit=ct.ComputeUnit.CPU_AND_NE, model_type=CoreMLBackend.MODEL_TYPE.MODEL, # pyre-fixme[16] op_linear_quantizer_config=op_linear_quantizer_config, @@ -232,11 +217,11 @@ def main() -> None: ) input_manager = InputManager( - n_layers=args.n_layers, - max_batch_size=args.max_batch_size, - n_kv_heads=args.n_kv_heads, - max_seq_length=args.max_seq_len, - head_dim=args.head_dim, + n_layers=model.params.n_layers, + max_batch_size=model.params.max_batch_size, + n_kv_heads=model.params.n_kv_heads, + max_seq_length=model.params.max_seq_len, + head_dim=model.params.head_dim, use_cache_list=export_args.use_cache_list, seq_length=export_args.seq_length, dtype=float_dtype, @@ 
-245,10 +230,20 @@ def main() -> None: ) example_inputs = input_manager.get_inputs(tokens=[0]) - edge_manager = export_to_edge( + ep = torch.export.export( model, example_inputs, - edge_compile_config=EdgeCompileConfig( + ) + print("Exported program") + print(ep) + + edge_manager = to_edge_with_preserved_ops( + ep, + preserve_ops=[ + torch.ops.aten.scaled_dot_product_attention.default, + torch.ops.aten.linalg_vector_norm.default, + ], + compile_config=EdgeCompileConfig( _check_ir_validity=False, _skip_type_promotion=(float_dtype == torch.float16), _skip_dim_order=True, diff --git a/examples/apple/coreml/llama/llama_transformer.py b/examples/apple/coreml/llama/llama_transformer.py index 5788bcd5e5a..2ce4c1d2b5b 100644 --- a/examples/apple/coreml/llama/llama_transformer.py +++ b/examples/apple/coreml/llama/llama_transformer.py @@ -13,8 +13,6 @@ import torch import torch.nn.functional as F -from executorch.examples.models.llama.llama_transformer import RMSNorm - from executorch.examples.models.llama.rope import ( hf_apply_rotary_emb, hf_precompute_freqs_cis, @@ -25,29 +23,6 @@ from torch import nn -# These are just to prevent to_edge from decomposing SDPA -# A better method is to use the to_edge_transform_and_lower API for CoreML -# and not decompose SDPA -@torch.library.custom_op("coreml::sdpa", mutates_args=()) -def sdpa( - q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: torch.Tensor -) -> torch.Tensor: - """Same as F.scaled_dot_product_attention, but with custom op to avoid lowering during dialect conversion.""" - return torch.ops.aten.scaled_dot_product_attention.default( - q, k, v, attn_mask=attn_mask - ) - - -@torch.library.register_fake("coreml::sdpa") -def _( - q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: torch.Tensor -) -> torch.Tensor: - """Fake implementation with the right output shape, which is required for torch.compile/export/fx tracing.""" - expected_shape = list(q.shape) - expected_shape[-1] = v.shape[-1] - return q.new_empty(expected_shape) - - def find_multiple(n: int, k: int) -> int: if n % k == 0: return n @@ -121,6 +96,63 @@ def __post_init__(self): self.head_dim = self.dim // self.n_heads +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + super().__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + """ + Apply the RMSNorm normalization to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized tensor. 
+
+        """
+        # CoreML ignores casts to FP32, so the existing implementation of RMSNorm was not stable
+        # We instead use (x * sqrt(n)) / norm(x, dim=-1)
+        # Using torch.linalg.vector_norm and preserving this op in CoreML improves stability
+        # Note that we ignore eps; it could be added by using torch.norm(torch.concat(x, sqrt(n*eps))) in the denominator
+        # In the future, we want to add CoreML support for the functional RMSNorm op
+        # We have yet to do large-scale evaluations on the numeric stability of this solution, but note that
+        # it appears better than what exists currently (removing FP32 casts and using FP16)
+        rms_norm_eps0 = (
+            x * torch.sqrt(torch.tensor(self.dim, dtype=x.dtype))
+        ) / torch.linalg.vector_norm(x, dim=-1, keepdim=True)
+        return rms_norm_eps0
+
+    def forward(self, x):
+        """
+        Forward pass through the RMSNorm layer.
+
+        Args:
+            x (torch.Tensor): The input tensor.
+
+        Returns:
+            torch.Tensor: The output tensor after applying RMSNorm.
+
+        """
+        output = self._norm(x)
+        return output * self.weight
+
+
 class Rope(torch.nn.Module):
     def __init__(self, params: ModelArgs):
         super().__init__()
@@ -304,12 +336,11 @@ def forward(
             k = k.repeat_interleave(self.n_rep, dim=1)
             v = v.repeat_interleave(self.n_rep, dim=1)

-        output = torch.ops.coreml.sdpa(q, k, v, attn_mask)
-
+        output = torch.ops.aten.scaled_dot_product_attention.default(
+            q, k, v, attn_mask=attn_mask
+        )
         output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
-
         output = self.wo(output)
-
         return output, new_k, new_v


@@ -413,6 +444,39 @@ def forward(
         return logits, k_out, v_out


+def load_model(checkpoint_path, params_path, max_seq_length, use_cache_list):
+    import json
+
+    with open(params_path, "r") as f:
+        params = json.loads(f.read())
+
+    args = ModelArgs(
+        max_seq_len=max_seq_length,
+        generate_full_logits=False,
+        use_cache_list=use_cache_list,
+        **params,
+    )
+
+    with torch.device("meta"):
+        model = Transformer(args)
+
+    checkpoint = torch.load(
+        checkpoint_path, map_location="cpu", mmap=True, weights_only=True
+    )
+    if "model" in checkpoint:
+        checkpoint = checkpoint["model"]
+
+    missing, unexpected = model.load_state_dict(
+        checkpoint,
+        strict=False,
+        assign=True,
+    )
+    print("Missing keys: ", missing)
+    print("Unexpected keys: ", unexpected)
+
+    return model
+
+
 class InputManager:
     def __init__(
         self,
diff --git a/examples/apple/coreml/llama/readme.md b/examples/apple/coreml/llama/readme.md
index 353f0b56307..a9efedf6bbe 100644
--- a/examples/apple/coreml/llama/readme.md
+++ b/examples/apple/coreml/llama/readme.md
@@ -4,7 +4,7 @@ This directory contains ANE-friendly Llama models.

 Export model with:
 ```
-python export.py -n /path/to/output/model.pte -p /path/to/params.json -c /path/to/model.pth --seq_length 64 --max_seq_length 1024 --coreml-quantize c4w
+python export.py -n /path/to/output/model.pte -p /path/to/params.json -c /path/to/model.pth --seq_length 64 --max_seq_length 1024 --coreml-quantize c4w --dtype fp16
 ```

 (Note the script should be run from the executorch/examples/apple/coreml/llama directory.)

@@ -17,6 +17,13 @@
 Run model with:
 ```
 python run.py -m /path/to/model.pte -t /path/to/tokenizer.model --prompt "Once upon a time,"
 ```

+The runner can also be used to run the model in eager mode (pass --use_eager) to compare eager and CoreML numerics. In this case, you must specify:
+* --params
+* --checkpoint
+* --dtype
+* --max_seq_length
+* --seq_length
+
 (Note the script should be run from the executorch/examples/apple/coreml/llama directory.)
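
The norm-based rewrite in `RMSNorm._norm` above is algebraically equivalent to the conventional mean-square RMSNorm when eps is ignored, since x / sqrt(mean(x^2)) = (x * sqrt(n)) / ||x||_2. A minimal sketch (not part of the diff; shapes and names are illustrative only) that checks the two formulations agree numerically:

```
# Sketch only: compare the norm-based RMSNorm used in llama_transformer.py
# against the conventional mean-square formulation (eps omitted in both).
import torch

dim = 64
x = torch.randn(2, 8, dim, dtype=torch.float32)

# Rewrite used above: (x * sqrt(n)) / ||x||_2, via torch.linalg.vector_norm
norm_based = (
    x * torch.sqrt(torch.tensor(dim, dtype=x.dtype))
) / torch.linalg.vector_norm(x, dim=-1, keepdim=True)

# Conventional RMSNorm with eps omitted: x / sqrt(mean(x^2))
mean_square = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True))

assert torch.allclose(norm_based, mean_square, atol=1e-6)
```
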
diff --git a/examples/apple/coreml/llama/run.py b/examples/apple/coreml/llama/run.py index 65026e1f6bc..501aaee07ed 100644 --- a/examples/apple/coreml/llama/run.py +++ b/examples/apple/coreml/llama/run.py @@ -11,7 +11,7 @@ sys.path.insert(0, ".") from executorch.examples.models.llama.runner.generation import next_token from executorch.examples.models.llama.tokenizer import tiktoken -from llama_transformer import InputManager +from llama_transformer import InputManager, load_model class Tokenizer: @@ -71,28 +71,90 @@ def main() -> None: type=float, default=0.9, ) + parser.add_argument( + "--use_eager", + action="store_true", + ) + parser.add_argument( + "-p", + "--params", + type=str, + default=None, + ) + parser.add_argument( + "-c", + "--checkpoint", + type=str, + default=None, + ) + parser.add_argument("--dtype", type=str, choices=["fp16", "fp32"], default=None) + parser.add_argument( + "--seq_length", + type=int, + default=None, + ) + parser.add_argument( + "--max_seq_length", + type=int, + default=None, + ) + parser.add_argument( + "--cache_size", + type=int, + default=None, + ) args = parser.parse_args() tokenizer = Tokenizer(args.tokenizer) runtime = Runtime.get() - program = runtime.load_program(args.model) - method = program.load_method("forward") - - metadata = method.metadata - print("Method metadata: ", metadata, "\n\n") - - assert ( - metadata.num_inputs() == 6 - ), "Do not export with --use_cache_list for use in pybindings" - # k_cache input - n_layers, max_batch_size, n_kv_heads, cache_size, head_dim = ( - metadata.input_tensor_meta(3).sizes() - ) - - # mask input - seq_length, max_seq_length = metadata.input_tensor_meta(5).sizes() + if args.use_eager: + assert args.params is not None + assert args.checkpoint is not None + assert args.dtype is not None + assert args.max_seq_length is not None + assert args.seq_length is not None + + max_seq_length = args.max_seq_length + seq_length = args.seq_length + model = load_model( + args.checkpoint, + args.params, + max_seq_length=max_seq_length, + use_cache_list=False, + ) + n_layers = model.params.n_layers + max_batch_size = model.params.max_batch_size + n_kv_heads = model.params.n_kv_heads + head_dim = model.params.head_dim + cache_size = args.cache_size + + float_dtype = {"fp16": torch.float16, "fp32": torch.float32}[ + args.dtype + ] # dtype for model/inputs + model.eval() + model.to(float_dtype) + else: + program = runtime.load_program(args.model) + method = program.load_method("forward") + + metadata = method.metadata + print("Method metadata: ", metadata, "\n\n") + + assert ( + metadata.num_inputs() == 6 + ), "Do not export with --use_cache_list for use in pybindings" + # k_cache input + n_layers, max_batch_size, n_kv_heads, cache_size, head_dim = ( + metadata.input_tensor_meta(3).sizes() + ) + float_dtype = {5: torch.float16, 6: torch.float32}[ + metadata.input_tensor_meta(3).dtype() + ] + + # mask input + seq_length, max_seq_length = metadata.input_tensor_meta(5).sizes() input_manager = InputManager( n_layers=n_layers, @@ -102,7 +164,7 @@ def main() -> None: head_dim=head_dim, use_cache_list=False, seq_length=seq_length, - dtype=torch.float16, + dtype=float_dtype, minus_infinity=-30000.0, cache_size=cache_size, ) @@ -117,7 +179,11 @@ def main() -> None: tokens ) processed_tokens = len(tokens) - len(remaining_tokens) - logits, k, v = method.execute(inputs) + if args.use_eager: + logits, k, v = model(*inputs) + else: + logits, k, v = method.execute(inputs) + input_manager.update( input_length=processed_tokens, new_k_caches=k, 
new_v_caches=v )
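
For context on the export.py change: the script now calls `torch.export.export` and then `to_edge_with_preserved_ops`, so that `scaled_dot_product_attention` and `linalg_vector_norm` are not decomposed during edge conversion (this replaces the previous `coreml::sdpa` custom-op workaround). A minimal sketch of that pattern, assuming an ExecuTorch install that exposes `to_edge_with_preserved_ops` and `EdgeCompileConfig`; the toy module and shapes are illustrative only, not taken from the diff:

```
# Sketch only: export a toy attention module and keep SDPA as a single op
# through edge conversion, mirroring the pattern used in export.py.
import torch
from executorch.exir import EdgeCompileConfig
from executorch.exir.program._program import to_edge_with_preserved_ops


class TinyAttention(torch.nn.Module):
    def forward(self, q, k, v, mask):
        return torch.ops.aten.scaled_dot_product_attention.default(
            q, k, v, attn_mask=mask
        )


q, k, v = (torch.randn(1, 4, 8, 16) for _ in range(3))
mask = torch.zeros(8, 8)

ep = torch.export.export(TinyAttention(), (q, k, v, mask))
edge = to_edge_with_preserved_ops(
    ep,
    preserve_ops=[torch.ops.aten.scaled_dot_product_attention.default],
    compile_config=EdgeCompileConfig(_check_ir_validity=False),
)
# SDPA should still appear as a single aten op in the edge graph;
# export.py then lowers the edge program to CoreML and serializes a .pte.
print(edge.exported_program())
```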