Skip to content

Commit 02450cf

Browse files
authored
Add RMSNorm operator support for PyTorch to CoreML conversion (#2585) (#2592)
* Add RMSNorm operator support for PyTorch to CoreML conversion (#2585) * formatted code * handles FP16 overflow for RMSNorm operation
1 parent 2256748 commit 02450cf

File tree

2 files changed

+267
-0
lines changed

2 files changed

+267
-0
lines changed

coremltools/converters/mil/frontend/torch/ops.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3057,6 +3057,81 @@ def _parse_positional_args(context, node) -> Tuple[Var]:
30573057
else:
30583058
context.add(layer_norm)
30593059

3060+
@register_torch_op
3061+
def rms_norm(context, node):
3062+
# Parse Inputs
3063+
inputs = _get_inputs(context, node, expected=4)
3064+
x = inputs[0]
3065+
normalized_shape = inputs[1]
3066+
weight = inputs[2]
3067+
eps = inputs[3]
3068+
axes = list(range(-len(normalized_shape.val), 0))
3069+
# Store epsilon value to ensure ZeroDivisionError doesn't occur
3070+
# while computing RMSNorm
3071+
eps_val = eps.val if eps is not None else 1e-5
3072+
3073+
# RMS Normalization Formula:
3074+
# RMS(x) = sqrt(E[x^2] + epsilon)
3075+
# out = gamma * x / RMS(x)
3076+
# For more info, check out: `<https://arxiv.org/pdf/1910.07467.pdf>`
3077+
3078+
# Note: Apple Neural Engine (ANE) does not have native RMSNorm support
3079+
# and computing x^2 directly can cause FP16 overflow for
3080+
# large activation values (>256).
3081+
#
3082+
# To ensure ANE compatibility and prevent overflow,
3083+
# we scale the input by its maximum
3084+
# absolute value before computing RMS, then scale back the result.
3085+
# Reference: https://x.com/anemll/status/1942432672007192928
3086+
#
3087+
# Advantages:
3088+
# - Prevents FP16 overflow on ANE.
3089+
# - Maintains ANE placement (avoiding CPU/GPU fallback).
3090+
#
3091+
# Trade-offs:
3092+
# - May introduce slight numerical differences compared
3093+
# to the standard operation due to the division
3094+
# and rescaling operations.
3095+
# - Maximum relative error is typically < 0.1% in practice.
3096+
#
3097+
# Note: For applications requiring exact PyTorch parity,
3098+
# consider using CPU/GPU compute units.
3099+
3100+
max_val_tensor = mb.reduce_max(
3101+
x=mb.abs(x=x, name=node.name + "_abs"),
3102+
axes=axes,
3103+
keep_dims=True,
3104+
name=node.name + "_max_val"
3105+
)
3106+
x_scaled = mb.real_div(x=x, y=max_val_tensor, name=node.name + "_scale")
3107+
x_scale_squared = mb.square(x=x_scaled, name=node.name + "_square")
3108+
mean_squared = mb.reduce_mean(
3109+
x=x_scale_squared,
3110+
axes=axes,
3111+
keep_dims=True,
3112+
name=node.name + "_mean_squared"
3113+
)
3114+
mean_plus_eps = mb.add(
3115+
x=mean_squared,
3116+
y=eps_val,
3117+
name=node.name + "_add_eps"
3118+
)
3119+
rms = mb.sqrt(x=mean_plus_eps, name=node.name + "_rms")
3120+
rms_scaled = mb.mul(
3121+
x=rms,
3122+
y=max_val_tensor,
3123+
name=node.name + "_rms_scaled"
3124+
)
3125+
normalized = mb.real_div(x=x, y=rms_scaled, name=node.name + "_normalized")
3126+
3127+
# Apply weight if provided
3128+
if weight is not None:
3129+
output = mb.mul(x=normalized, y=weight, name=node.name)
3130+
else:
3131+
output = normalized
3132+
3133+
context.add(output, node.name)
3134+
30603135

30613136
@register_torch_op
30623137
def numtotensor(context, node):
Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
# Copyright (c) 2020, Apple Inc. All rights reserved.
2+
#
3+
# Use of this source code is governed by a BSD-3-clause license that can be
4+
# found in the LICENSE.txt file or at
5+
# https://opensource.org/licenses/BSD-3-Clause
6+
#
7+
# Test suite for RMSNorm feature (Issue #2585)
8+
9+
import pytest
10+
from coremltools._deps import _HAS_TORCH
11+
import numpy as np
12+
13+
14+
# Check if pytorch module is installed
15+
# Also, check if pytorch and coremltools' versions are compatible for this test
16+
if _HAS_TORCH:
17+
import torch
18+
import coremltools as ct
19+
20+
# get package versions
21+
torch_major = int(torch.__version__.split('.')[0])
22+
ct_version_parts = ct.__version__.split('.')
23+
ct_major = int(ct_version_parts[0])
24+
25+
# Run only on PyTorch 2.x and coremltools >= 8.x
26+
_TORCH_COMPATIBLE = torch_major >= 2
27+
_CT_COMPATIBLE = ct_major >= 8
28+
_VERSIONS_COMPATIBLE = _TORCH_COMPATIBLE and _CT_COMPATIBLE
29+
else:
30+
_VERSIONS_COMPATIBLE = False
31+
32+
33+
@pytest.mark.skipif(not _HAS_TORCH, reason="PyTorch not found")
34+
@pytest.mark.skipif(not _VERSIONS_COMPATIBLE, reason="Incompatible versions")
35+
class TestRMSNorm:
36+
"""
37+
Test RMSNorm conversion from PyTorch to CoreML
38+
"""
39+
40+
@staticmethod
41+
@pytest.mark.parametrize(
42+
"input_shape, normalized_shape, elementwise_affine, eps, test_name",
43+
[
44+
# Standard tests
45+
((2, 10, 768), 768, True, 1e-5, "standard_3d"),
46+
((32, 128, 1024), 1024, True, 1e-5, "large_batch"),
47+
((5, 512), 512, True, 1e-5, "2d_input"),
48+
((1, 1, 256), 256, True, 1e-5, "singleton_dims"),
49+
50+
# Without learnable parameters
51+
((10, 512), 512, False, 1e-5, "no_weight"),
52+
((2, 4, 512), 512, False, 1e-5, "no_weight_3d"),
53+
54+
# Different epsilon values
55+
((8, 256), 256, True, 1e-8, "small_epsilon"),
56+
((8, 256), 256, True, 1e-3, "large_epsilon"),
57+
58+
# Multiple axes normalization
59+
((4, 8, 16, 32), (16, 32), True, 1e-5, "multi_axis"),
60+
]
61+
)
62+
def test_rms_norm_conversion(
63+
input_shape,
64+
normalized_shape,
65+
elementwise_affine,
66+
eps,
67+
test_name
68+
):
69+
"""
70+
Test RMSNorm conversion with various configurations
71+
"""
72+
class TestModel(torch.nn.Module):
73+
def __init__(self):
74+
super().__init__()
75+
self.norm = torch.nn.RMSNorm(
76+
normalized_shape,
77+
eps=eps,
78+
elementwise_affine=elementwise_affine
79+
)
80+
81+
def forward(self, x):
82+
return self.norm(x)
83+
84+
model = TestModel()
85+
model.eval()
86+
87+
example = torch.randn(input_shape)
88+
torch_out = model(example)
89+
traced = torch.jit.trace(model, example)
90+
mlmodel = ct.convert(
91+
traced,
92+
inputs=[ct.TensorType(
93+
shape=input_shape,
94+
dtype=np.float32,
95+
name="input"
96+
)],
97+
outputs=[ct.TensorType(name="output", dtype=np.float32)],
98+
convert_to="mlprogram"
99+
)
100+
101+
result = mlmodel.predict({"input": example.numpy()})
102+
coreml_out = result["output"]
103+
104+
# Compare outputs
105+
np.testing.assert_allclose(
106+
torch_out.detach().numpy(),
107+
coreml_out,
108+
rtol=1e-2, # 0.01 relative tolerance
109+
atol=1e-3, # 0.001 absolute tolerance
110+
err_msg=f"Test '{test_name}' failed: outputs don't match"
111+
)
112+
113+
# Verify no NaN or Inf are present in any tensor
114+
assert not np.isnan(coreml_out).any(), \
115+
f"Test '{test_name}' produced NaN values"
116+
assert not np.isinf(coreml_out).any(), \
117+
f"Test '{test_name}' produced Inf values"
118+
119+
@staticmethod
120+
def test_edge_cases():
121+
"""
122+
Test edge cases like zero inputs, very small values
123+
"""
124+
class TestModel(torch.nn.Module):
125+
def __init__(self):
126+
super().__init__()
127+
self.norm = torch.nn.RMSNorm(512)
128+
129+
def forward(self, x):
130+
return self.norm(x)
131+
132+
model = TestModel()
133+
model.eval()
134+
135+
# Test with zeros
136+
zeros = torch.zeros(2, 512)
137+
out_zeros = model(zeros)
138+
assert not torch.isnan(out_zeros).any(), \
139+
"RMSNorm produced NaN with zero input"
140+
141+
# Test with very small values
142+
small = torch.full((2, 512), 1e-10)
143+
out_small = model(small)
144+
assert not torch.isinf(out_small).any(), \
145+
"RMSNorm produced Inf with small input"
146+
147+
@staticmethod
148+
def test_dynamic_shapes():
149+
"""
150+
Test RMSNorm with dynamic input shapes
151+
"""
152+
class TestModel(torch.nn.Module):
153+
def __init__(self):
154+
super().__init__()
155+
self.norm = torch.nn.RMSNorm(768)
156+
157+
def forward(self, x):
158+
return self.norm(x)
159+
160+
model = TestModel()
161+
example = torch.randn(1, 10, 768)
162+
traced = torch.jit.trace(model, example)
163+
164+
# Convert with flexible batch and sequence dimensions
165+
mlmodel = ct.convert(
166+
traced,
167+
inputs=[ct.TensorType(
168+
shape=ct.EnumeratedShapes(
169+
shapes=[[1, 10, 768], [2, 20, 768], [4, 50, 768]],
170+
default=[1, 10, 768],
171+
),
172+
dtype=np.float32,
173+
name="input"
174+
)],
175+
outputs=[ct.TensorType(name="output", dtype=np.float32)],
176+
convert_to="mlprogram"
177+
)
178+
179+
# Test different shapes
180+
for shape in [(1, 10, 768), (2, 20, 768), (4, 50, 768)]:
181+
test_input = torch.randn(shape)
182+
torch_out = model(test_input)
183+
coreml_out = mlmodel.predict({
184+
"input": test_input.numpy()
185+
})["output"]
186+
187+
np.testing.assert_allclose(
188+
torch_out.detach().numpy(),
189+
coreml_out,
190+
rtol=1e-2,
191+
atol=1e-3
192+
)

0 commit comments

Comments
 (0)