|
13 | 13 | import torch |
14 | 14 | import torch.nn.functional as F |
15 | 15 |
|
16 | | -from executorch.examples.models.llama.llama_transformer import RMSNorm |
17 | | - |
18 | 16 | from executorch.examples.models.llama.rope import ( |
19 | 17 | hf_apply_rotary_emb, |
20 | 18 | hf_precompute_freqs_cis, |
|
25 | 23 | from torch import nn |
26 | 24 |
|
27 | 25 |
|
28 | | -# These are just to prevent to_edge from decomposing SDPA |
29 | | -# A better method is to use the to_edge_transform_and_lower API for CoreML |
30 | | -# and not decompose SDPA |
31 | | -@torch.library.custom_op("coreml::sdpa", mutates_args=()) |
32 | | -def sdpa( |
33 | | - q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: torch.Tensor |
34 | | -) -> torch.Tensor: |
35 | | - """Same as F.scaled_dot_product_attention, but with custom op to avoid lowering during dialect conversion.""" |
36 | | - return torch.ops.aten.scaled_dot_product_attention.default( |
37 | | - q, k, v, attn_mask=attn_mask |
38 | | - ) |
39 | | - |
40 | | - |
41 | | -@torch.library.register_fake("coreml::sdpa") |
42 | | -def _( |
43 | | - q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: torch.Tensor |
44 | | -) -> torch.Tensor: |
45 | | - """Fake implementation with the right output shape, which is required for torch.compile/export/fx tracing.""" |
46 | | - expected_shape = list(q.shape) |
47 | | - expected_shape[-1] = v.shape[-1] |
48 | | - return q.new_empty(expected_shape) |
49 | | - |
50 | | - |
51 | 26 | def find_multiple(n: int, k: int) -> int: |
52 | 27 | if n % k == 0: |
53 | 28 | return n |
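The comment removed in this hunk points at lowering via the to_edge_transform_and_lower API so that SDPA is kept intact rather than decomposed, instead of wrapping it in a coreml::sdpa custom op. A minimal sketch of that flow, not part of this diff; the CoreMLPartitioner import path and the model / example_inputs placeholders are assumptions:

import torch
from executorch.backends.apple.coreml.partition import CoreMLPartitioner
from executorch.exir import to_edge_transform_and_lower

# `model` and `example_inputs` are placeholders for an eager nn.Module and a tuple of inputs.
exported = torch.export.export(model, example_inputs)
# Lower to the CoreML backend without decomposing SDPA into its constituent ops.
edge_manager = to_edge_transform_and_lower(
    exported,
    partitioner=[CoreMLPartitioner()],
)
executorch_program = edge_manager.to_executorch()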
@@ -121,6 +96,63 @@ def __post_init__(self): |
121 | 96 | self.head_dim = self.dim // self.n_heads |
122 | 97 |
|
123 | 98 |
|
| 99 | +class RMSNorm(torch.nn.Module): |
| 100 | + def __init__(self, dim: int, eps: float = 1e-6): |
| 101 | + """ |
| 102 | + Initialize the RMSNorm normalization layer. |
| 103 | +
|
| 104 | + Args: |
| 105 | + dim (int): The dimension of the input tensor. |
| 106 | + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. |
| 107 | +
|
| 108 | + Attributes: |
| 109 | + eps (float): A small value added to the denominator for numerical stability. |
| 110 | + weight (nn.Parameter): Learnable scaling parameter. |
| 111 | +
|
| 112 | + """ |
| 113 | + super().__init__() |
| 114 | + self.dim = dim |
| 115 | + self.eps = eps |
| 116 | + self.weight = nn.Parameter(torch.ones(dim)) |
| 117 | + |
| 118 | + def _norm(self, x): |
| 119 | + """ |
| 120 | + Apply the RMSNorm normalization to the input tensor. |
| 121 | +
|
| 122 | + Args: |
| 123 | + x (torch.Tensor): The input tensor. |
| 124 | +
|
| 125 | + Returns: |
| 126 | + torch.Tensor: The normalized tensor. |
| 127 | +
|
| 128 | + """ |
| 129 | + # CoreML ignores casts to FP32, so the existing RMSNorm implementation was not numerically stable |
| 130 | + # We instead compute (x * sqrt(n)) / norm(x, dim=-1) |
| 131 | + # Using torch.norm and preserving this op in CoreML improves stability |
| 132 | + # Note that we ignore eps; it could be added by using torch.norm(torch.concat([x, sqrt(n*eps)])) in the denominator |
| 133 | + # In the future, we want to add CoreML support for the functional RMSNorm op |
| 134 | + # We have yet to run large-scale evaluations of the numerical stability of this solution, but it |
| 135 | + # appears better than the current approach (removing FP32 casts and using FP16) |
| 136 | + rms_norm_eps0 = ( |
| 137 | + x * torch.sqrt(torch.tensor(self.dim, dtype=x.dtype)) |
| 138 | + ) / torch.linalg.vector_norm(x, dim=-1, keepdim=True) |
| 139 | + return rms_norm_eps0 |
| 140 | + |
| 141 | + def forward(self, x): |
| 142 | + """ |
| 143 | + Forward pass through the RMSNorm layer. |
| 144 | +
|
| 145 | + Args: |
| 146 | + x (torch.Tensor): The input tensor. |
| 147 | +
|
| 148 | + Returns: |
| 149 | + torch.Tensor: The output tensor after applying RMSNorm. |
| 150 | +
|
| 151 | + """ |
| 152 | + output = self._norm(x) |
| 153 | + return output * self.weight |
| 154 | + |
| 155 | + |
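Since ||x||_2 = sqrt(dim * mean(x^2)), the expression (x * sqrt(dim)) / ||x||_2 used above equals x / sqrt(mean(x^2)), i.e. RMSNorm with eps = 0. A small sanity check of that equivalence, not part of this diff, assuming a PyTorch build that provides torch.nn.functional.rms_norm (2.4+):

import torch
import torch.nn.functional as F

dim = 64
x = torch.randn(2, 8, dim)
# Norm-based formulation used by the RMSNorm class above (eps is ignored).
norm_based = (
    x * torch.sqrt(torch.tensor(dim, dtype=x.dtype))
) / torch.linalg.vector_norm(x, dim=-1, keepdim=True)
# Functional RMSNorm with eps=0 and no weight, as a reference.
reference = F.rms_norm(x, (dim,), eps=0.0)
print(torch.allclose(norm_based, reference, atol=1e-5))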
124 | 156 | class Rope(torch.nn.Module): |
125 | 157 | def __init__(self, params: ModelArgs): |
126 | 158 | super().__init__() |
@@ -304,12 +336,11 @@ def forward( |
304 | 336 | k = k.repeat_interleave(self.n_rep, dim=1) |
305 | 337 | v = v.repeat_interleave(self.n_rep, dim=1) |
306 | 338 |
|
307 | | - output = torch.ops.coreml.sdpa(q, k, v, attn_mask) |
308 | | - |
| 339 | + output = torch.ops.aten.scaled_dot_product_attention.default( |
| 340 | + q, k, v, attn_mask=attn_mask |
| 341 | + ) |
309 | 342 | output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) |
310 | | - |
311 | 343 | output = self.wo(output) |
312 | | - |
313 | 344 | return output, new_k, new_v |
314 | 345 |
|
315 | 346 |
|
@@ -413,6 +444,39 @@ def forward( |
413 | 444 | return logits, k_out, v_out |
414 | 445 |
|
415 | 446 |
|
| 447 | +def load_model(checkpoint_path, params_path, max_seq_length, use_cache_list): |
| 448 | + import json |
| 449 | + |
| 450 | + with open(params_path, "r") as f: |
| 451 | + params = json.loads(f.read()) |
| 452 | + |
| 453 | + args = ModelArgs( |
| 454 | + max_seq_len=max_seq_length, |
| 455 | + generate_full_logits=False, |
| 456 | + use_cache_list=use_cache_list, |
| 457 | + **params, |
| 458 | + ) |
| 459 | + |
| 460 | + with torch.device("meta"): |
| 461 | + model = Transformer(args) |
| 462 | + |
| 463 | + checkpoint = torch.load( |
| 464 | + checkpoint_path, map_location="cpu", mmap=True, weights_only=True |
| 465 | + ) |
| 466 | + if "model" in checkpoint: |
| 467 | + checkpoint = checkpoint["model"] |
| 468 | + |
| 469 | + missing, unexpected = model.load_state_dict( |
| 470 | + checkpoint, |
| 471 | + strict=False, |
| 472 | + assign=True, |
| 473 | + ) |
| 474 | + print("Missing keys: ", missing) |
| 475 | + print("Unexpected keys: ", unexpected) |
| 476 | + |
| 477 | + return model |
| 478 | + |
| 479 | + |
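A usage sketch for load_model; the checkpoint and params paths below are placeholders, not files from this repository:

# Hypothetical paths; substitute a real Llama checkpoint and its params.json.
model = load_model(
    checkpoint_path="/path/to/consolidated.00.pth",
    params_path="/path/to/params.json",
    max_seq_length=128,
    use_cache_list=True,
)
model.eval()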
416 | 480 | class InputManager: |
417 | 481 | def __init__( |
418 | 482 | self, |
|