Commit 99ba87d
Adding RMSNorm support to arbitrary x and normalized_dim shapes
Summary:
In D72014553, we were adding initial support for RMS norm for an input in 3 or 4 dimensions, and a weight of dimension 1 (same size as x[:-1])
In this diff, we allow for:
- input of arbitrary shape
- shape broadcasting of w (w must have dim <= 1)
Differential Revision: D724841961 parent 1facfa9 commit 99ba87d
1 file changed
+2
-2
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
139 | 139 | | |
140 | 140 | | |
141 | 141 | | |
142 | | - | |
| 142 | + | |
143 | 143 | | |
144 | 144 | | |
145 | 145 | | |
| |||
212 | 212 | | |
213 | 213 | | |
214 | 214 | | |
215 | | - | |
| 215 | + | |
216 | 216 | | |
217 | 217 | | |
218 | 218 | | |
| |||
0 commit comments