Add dtype, fix RMS norm for FP16 #8641
Changes from 2 commits
```diff
@@ -13,8 +13,6 @@
 import torch
 import torch.nn.functional as F
 
-from executorch.examples.models.llama.llama_transformer import RMSNorm
-
 from executorch.examples.models.llama.rope import (
     hf_apply_rotary_emb,
     hf_precompute_freqs_cis,
```
```diff
@@ -121,6 +119,56 @@ def __post_init__(self):
         self.head_dim = self.dim // self.n_heads
 
 
+class RMSNorm(torch.nn.Module):
```
**Contributor (Author):** In sync with the CoreML team, we might try using https://pytorch.org/docs/stable/generated/torch.nn.functional.rms_norm.html and then write a CoreML op definition for it here: https://github.com/apple/coremltools/blob/main/coremltools/converters/mil/frontend/torch/ops.py. @YifanShenSZ mentioned they have a fused norm op that could be used.
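A rough sketch of what delegating to the built-in op could look like (this is the stock PyTorch API, available in recent releases; wiring it into this module and lowering it to CoreML is the open question):

```python
# Sketch only: torch.nn.functional.rms_norm computes
# x * rsqrt(mean(x^2) + eps), scaled by an optional weight.
import torch
import torch.nn.functional as F

dim = 4096
weight = torch.ones(dim)
x = torch.randn(1, 8, dim)
out = F.rms_norm(x, normalized_shape=(dim,), weight=weight, eps=1e-6)
```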
**Collaborator:** Yes, there are 2 possibilities.
```diff
+    def __init__(self, dim: int, eps: float = 1e-6):
+        """
+        Initialize the RMSNorm normalization layer.
+
+        Args:
+            dim (int): The dimension of the input tensor.
+            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+
+        Attributes:
+            eps (float): A small value added to the denominator for numerical stability.
+            weight (nn.Parameter): Learnable scaling parameter.
+
+        """
+        super().__init__()
+        self.dim = dim
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x):
+        """
+        Apply the RMSNorm normalization to the input tensor.
+
+        Args:
+            x (torch.Tensor): The input tensor.
+
+        Returns:
+            torch.Tensor: The normalized tensor.
+
+        """
+        x_max, _ = torch.abs(x).max(-1, keepdim=True)
+        x = x / x_max  # This makes the op more stable in FP16
+        eps = self.eps / (x_max * x_max)
+        return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + eps)
+
+    def forward(self, x):
+        """
+        Forward pass through the RMSNorm layer.
+
+        Args:
+            x (torch.Tensor): The input tensor.
+
+        Returns:
+            torch.Tensor: The output tensor after applying RMSNorm.
+
+        """
+        output = self._norm(x)
+        return output * self.weight
+
+
 class Rope(torch.nn.Module):
     def __init__(self, params: ModelArgs):
         super().__init__()
```
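Why the `_norm` rescaling is exact: with `y = x / x_max`, `mean(y**2) = mean(x**2) / x_max**2`, so `y * rsqrt(mean(y**2) + eps / x_max**2) = x * rsqrt(mean(x**2) + eps)` in real arithmetic. The win is that `|y| <= 1`, so the squared terms can never exceed FP16's ~65504 ceiling. A minimal check of the idea (not from the PR; FP16 coverage for these ops on CPU may require a recent PyTorch):

```python
# Illustrative comparison of naive vs. max-rescaled RMSNorm in FP16.
import torch

def rms_naive(x, eps=1e-6):
    return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + eps)

def rms_stable(x, eps=1e-6):
    x_max, _ = torch.abs(x).max(-1, keepdim=True)
    y = x / x_max                # |y| <= 1, so y*y cannot overflow
    eps = eps / (x_max * x_max)  # rescaled eps; may underflow to 0 in FP16,
                                 # which is harmless since mean(y*y) >= 1/n
    return y * torch.rsqrt((y * y).mean(-1, keepdim=True) + eps)

x = torch.full((1, 4), 300.0, dtype=torch.float16)  # 300^2 = 90000 > FP16 max
print(rms_naive(x))   # mean(x*x) saturates to inf, output collapses to zeros
print(rms_stable(x))  # finite, ~1.0 per element, as expected
```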
```diff
@@ -305,11 +353,8 @@ def forward(
         v = v.repeat_interleave(self.n_rep, dim=1)
 
         output = torch.ops.coreml.sdpa(q, k, v, attn_mask)
 
         output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
 
         output = self.wo(output)
 
         return output, new_k, new_v
```
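For reference, `torch.ops.coreml.sdpa` is a custom op lowered to CoreML; a plain-PyTorch stand-in for its presumed semantics (an assumption, not taken from the PR) would be:

```python
# Assumed reference semantics for torch.ops.coreml.sdpa (illustrative only).
import torch.nn.functional as F

def sdpa_reference(q, k, v, attn_mask):
    # q, k, v: (batch, n_heads, seq, head_dim); attn_mask broadcastable
    return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
```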
```diff
@@ -413,6 +458,39 @@ def forward(
         return logits, k_out, v_out
 
 
+def load_model(checkpoint_path, params_path, max_seq_length, use_cache_list):
+    import json
+
+    with open(params_path, "r") as f:
+        params = json.loads(f.read())
+
+    args = ModelArgs(
+        max_seq_len=max_seq_length,
+        generate_full_logits=False,
+        use_cache_list=use_cache_list,
+        **params,
+    )
+
+    with torch.device("meta"):
+        model = Transformer(args)
+
+    checkpoint = torch.load(
+        checkpoint_path, map_location="cpu", mmap=True, weights_only=True
+    )
+    if "model" in checkpoint:
+        checkpoint = checkpoint["model"]
+
+    missing, unexpected = model.load_state_dict(
+        checkpoint,
+        strict=False,
+        assign=True,
+    )
+    print("Missing keys: ", missing)
+    print("Unexpected keys: ", unexpected)
+
+    return model
+
+
 class InputManager:
     def __init__(
         self,
```
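Note the loading pattern: constructing the `Transformer` on the meta device allocates no storage, and `load_state_dict(..., assign=True)` then adopts the mmap'd checkpoint tensors directly, so the weights are materialized only once. Hypothetical usage (file names are placeholders, not from the PR):

```python
# Hypothetical call; paths are illustrative placeholders.
model = load_model(
    checkpoint_path="consolidated.00.pth",
    params_path="params.json",
    max_seq_length=2048,
    use_cache_list=True,
)
```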
**Comment:** Does CoreML support an RMSNorm op? It will be a lot easier if they do.
**Comment:** I don't see it here: https://github.com/apple/coremltools/blob/main/coremltools/converters/mil/frontend/torch/ops.py
**Comment:** Something existing in Core ML is the translation for torch.norm, which uses the Core ML fused reduce_l2_norm kernel. That is to say, we may compute RMS norm by something like the sketch below.
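One way that idea could look (an illustration of the suggestion, not the commenter's exact code):

```python
# Illustrative sketch: RMSNorm (with eps = 0) via an L2 norm, which Core ML
# can lower to its fused reduce_l2_norm kernel.
import torch

def rms_via_l2(x):
    n = x.shape[-1]
    # ||x||_2 / sqrt(n) == sqrt(mean(x^2)), so this matches RMSNorm at eps = 0
    return x * (n ** 0.5) / torch.linalg.norm(x, dim=-1, keepdim=True)
```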
**Comment:** I think this is a slightly different op when eps > 0, although I'm not sure how much it matters in practice. RMSNorm would actually be something like `x / torch.norm([x/sqrt(n), sqrt(eps)])`.
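A quick FP32 check (illustrative, not from the thread) that appending `sqrt(eps)` to the scaled vector before the L2 norm reproduces the eps term exactly, since `||[x/sqrt(n), sqrt(eps)]||^2 = mean(x^2) + eps`:

```python
import torch

def rms_ref(x, eps):
    return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + eps)

def rms_l2_eps(x, eps):
    n = x.shape[-1]
    # Append sqrt(eps) so the squared L2 norm becomes mean(x^2) + eps
    padded = torch.cat(
        [x / n ** 0.5, torch.full((*x.shape[:-1], 1), eps ** 0.5)], dim=-1
    )
    return x / torch.linalg.norm(padded, dim=-1, keepdim=True)

x = torch.randn(2, 8)
print(torch.allclose(rms_ref(x, 1e-6), rms_l2_eps(x, 1e-6)))  # True
```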
**Comment:** I'll update to use norm, and then maybe we can work on a longer-term solution of supporting rmsnorm in CoreML @YifanShenSZ?