23 | 23 | from torch import nn |
24 | 24 |
25 | 25 |
26 | | -# These are just to prevent to_edge from decomposing SDPA |
27 | | -# A better method is to use the to_edge_transform_and_lower API for CoreML |
28 | | -# and not decompose SDPA |
29 | | -@torch.library.custom_op("coreml::sdpa", mutates_args=()) |
30 | | -def sdpa( |
31 | | - q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: torch.Tensor |
32 | | -) -> torch.Tensor: |
33 | | - """Same as F.scaled_dot_product_attention, but with custom op to avoid lowering during dialect conversion.""" |
34 | | - return torch.ops.aten.scaled_dot_product_attention.default( |
35 | | - q, k, v, attn_mask=attn_mask |
36 | | - ) |
37 | | - |
38 | | - |
39 | | -@torch.library.register_fake("coreml::sdpa") |
40 | | -def _( |
41 | | - q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: torch.Tensor |
42 | | -) -> torch.Tensor: |
43 | | - """Fake implementation with the right output shape, which is required for torch.compile/export/fx tracing.""" |
44 | | - expected_shape = list(q.shape) |
45 | | - expected_shape[-1] = v.shape[-1] |
46 | | - return q.new_empty(expected_shape) |
47 | | - |
48 | | - |
49 | 26 | def find_multiple(n: int, k: int) -> int: |
50 | 27 | if n % k == 0: |
51 | 28 | return n |
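The comment removed above points at the `to_edge_transform_and_lower` API as the preferred way to keep SDPA from being decomposed, which is why the custom `coreml::sdpa` wrapper is no longer needed. A minimal sketch of that flow follows; the import paths and arguments are assumptions based on the ExecuTorch APIs, not taken from this diff.

```python
# Not part of this commit: a sketch of the to_edge_transform_and_lower flow that
# keeps aten SDPA intact for the CoreML partitioner instead of guarding it
# behind a custom op. Import paths and arguments are assumed, not from the diff.
import torch
from executorch.exir import to_edge_transform_and_lower
from executorch.backends.apple.coreml.partition import CoreMLPartitioner


def export_for_coreml(model: torch.nn.Module, example_inputs: tuple):
    # Export to an ATen-level graph; SDPA stays as a single aten op here.
    exported = torch.export.export(model, example_inputs)
    # Partition and lower in one step so ops claimed by the CoreML backend
    # (including SDPA) are not decomposed during edge conversion.
    edge = to_edge_transform_and_lower(exported, partitioner=[CoreMLPartitioner()])
    return edge.to_executorch()
```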
@@ -149,10 +126,15 @@ def _norm(self, x): |
149 | 126 | torch.Tensor: The normalized tensor. |
150 | 127 |
151 | 128 | """ |
152 | | - x_max, _ = torch.abs(x).max(-1, keepdim=True) |
153 | | - x = x / x_max # This makes the op more stable in FP16 |
154 | | - eps = self.eps / (x_max * x_max) |
155 | | - return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + eps) |
| 129 | + # CoreML ignores casts to FP32, so the existing RMSNorm implementation was not numerically stable |
| 130 | + # We instead use (x * sqrt(n)) / norm(x, dim=-1) |
| 131 | + # Using torch.norm and preserving this op in CoreML improves stability |
| 132 | + # Note: we ignore eps here, but it could be added by using torch.norm(torch.concat(x, sqrt(n*eps))) in the denominator |
| 133 | + # In the future, we want to add CoreML support for the functional RMSNorm op |
| 134 | + rms_norm_eps0 = ( |
| 135 | + x * torch.sqrt(torch.tensor(self.dim, dtype=x.dtype)) |
| 136 | + ) / torch.linalg.vector_norm(x, dim=-1, keepdim=True) |
| 137 | + return rms_norm_eps0 |
156 | 138 |
157 | 139 | def forward(self, x): |
158 | 140 | """ |
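The RMSNorm reformulation above relies on the identity mean(x^2, dim=-1) = ||x||^2 / n, so with eps = 0, x / sqrt(mean(x^2)) equals x * sqrt(n) / ||x||_2. A quick standalone sanity check of that equivalence (not part of this diff; the hidden size n below is arbitrary):

```python
import torch

# Standalone check (not from this diff): with eps == 0, the textbook RMSNorm
# x / sqrt(mean(x^2, dim=-1)) matches x * sqrt(n) / ||x||_2, since
# mean(x^2) = ||x||^2 / n along the last dimension.
n = 64
x = torch.randn(2, 8, n)
textbook = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True))
via_norm = (
    x * torch.sqrt(torch.tensor(n, dtype=x.dtype))
) / torch.linalg.vector_norm(x, dim=-1, keepdim=True)
torch.testing.assert_close(textbook, via_norm)
```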
@@ -352,7 +334,9 @@ def forward( |
352 | 334 | k = k.repeat_interleave(self.n_rep, dim=1) |
353 | 335 | v = v.repeat_interleave(self.n_rep, dim=1) |
354 | 336 |
355 | | - output = torch.ops.coreml.sdpa(q, k, v, attn_mask) |
| 337 | + output = torch.ops.aten.scaled_dot_product_attention.default( |
| 338 | + q, k, v, attn_mask=attn_mask |
| 339 | + ) |
356 | 340 | output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) |
357 | 341 | output = self.wo(output) |
358 | 342 | return output, new_k, new_v |
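For reference, a small shape check of the grouped-query attention path above: each KV head is tiled `n_rep` times so the aten SDPA call sees matching head counts. The dimensions below are illustrative, not taken from the model config.

```python
import torch

# Illustrative shapes (not from the model config): n_rep = n_heads // n_kv_heads,
# and repeat_interleave tiles each KV head so SDPA sees matching head counts.
bsz, seqlen, head_dim = 1, 16, 64
n_heads, n_kv_heads = 8, 2
n_rep = n_heads // n_kv_heads
q = torch.randn(bsz, n_heads, seqlen, head_dim)
k = torch.randn(bsz, n_kv_heads, seqlen, head_dim)
v = torch.randn(bsz, n_kv_heads, seqlen, head_dim)
k = k.repeat_interleave(n_rep, dim=1)
v = v.repeat_interleave(n_rep, dim=1)
out = torch.ops.aten.scaled_dot_product_attention.default(q, k, v)
assert out.shape == (bsz, n_heads, seqlen, head_dim)
```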