Commit 5b53f8c
Adding RMSNorm support to arbitrary x and normalized_dim shapes (pytorch#9966)
Summary:
In D72014553, we were adding initial support for RMS norm for an input in 3 or 4 dimensions, and a weight of dimension 1 (same size as x[:-1])
In this diff, we allow for:
- input of arbitrary shape
- shape broadcasting of w (w must have dim <= 1)
Differential Revision: D724841961 parent 8d80185 commit 5b53f8c
1 file changed
+4
-2
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
139 | 139 | | |
140 | 140 | | |
141 | 141 | | |
142 | | - | |
| 142 | + | |
| 143 | + | |
| 144 | + | |
143 | 145 | | |
144 | 146 | | |
145 | 147 | | |
| |||
212 | 214 | | |
213 | 215 | | |
214 | 216 | | |
215 | | - | |
| 217 | + | |
216 | 218 | | |
217 | 219 | | |
218 | 220 | | |
| |||
0 commit comments