Commit 7b40ed8
authored
Allow option to choose between fused LCE and CE loss (#704)
## Summary
More context in #703
This adds a kwarg to the forward method of models that allows users to
choose if they want to use fused LCE or if they want to materialize
logits and use regular CE loss. Currently this decision is made
implicitly using model state.
## Testing Done
Ran this on a downstream transformers pipeline and measured peak memory
consumption. Reduced by as much as 80% on some runs as confirmed, with
no effect on quality. Still running repo tests, but shouldnt expect any
surprises.
<!--- This is a required section; please describe how this change was
tested. --->
<!--
Replace BLANK with your device type. For example, A100-80G-PCIe
Complete the following tasks before sending your PR, and replace `[ ]`
with
`[x]` to indicate you have done them.
-->
- Hardware Type: H100
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence1 parent 8807a54 commit 7b40ed8
File tree
19 files changed
+225
-72
lines changed- src/liger_kernel/transformers/model
- test/convergence
- bf16
- fp32
19 files changed
+225
-72
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
137 | 137 | | |
138 | 138 | | |
139 | 139 | | |
| 140 | + | |
140 | 141 | | |
141 | 142 | | |
142 | 143 | | |
| |||
199 | 200 | | |
200 | 201 | | |
201 | 202 | | |
202 | | - | |
203 | | - | |
| 203 | + | |
| 204 | + | |
| 205 | + | |
| 206 | + | |
| 207 | + | |
| 208 | + | |
| 209 | + | |
| 210 | + | |
| 211 | + | |
204 | 212 | | |
205 | 213 | | |
206 | 214 | | |
| |||
209 | 217 | | |
210 | 218 | | |
211 | 219 | | |
212 | | - | |
| 220 | + | |
213 | 221 | | |
214 | 222 | | |
215 | 223 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
146 | 146 | | |
147 | 147 | | |
148 | 148 | | |
| 149 | + | |
149 | 150 | | |
150 | 151 | | |
151 | 152 | | |
| |||
213 | 214 | | |
214 | 215 | | |
215 | 216 | | |
216 | | - | |
217 | | - | |
| 217 | + | |
| 218 | + | |
| 219 | + | |
| 220 | + | |
| 221 | + | |
| 222 | + | |
| 223 | + | |
| 224 | + | |
| 225 | + | |
218 | 226 | | |
219 | 227 | | |
220 | 228 | | |
| |||
225 | 233 | | |
226 | 234 | | |
227 | 235 | | |
228 | | - | |
| 236 | + | |
229 | 237 | | |
230 | 238 | | |
231 | 239 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
35 | 35 | | |
36 | 36 | | |
37 | 37 | | |
| 38 | + | |
38 | 39 | | |
39 | 40 | | |
40 | 41 | | |
| |||
101 | 102 | | |
102 | 103 | | |
103 | 104 | | |
104 | | - | |
| 105 | + | |
| 106 | + | |
| 107 | + | |
| 108 | + | |
| 109 | + | |
105 | 110 | | |
106 | 111 | | |
107 | 112 | | |
| |||
151 | 156 | | |
152 | 157 | | |
153 | 158 | | |
| 159 | + | |
154 | 160 | | |
155 | 161 | | |
156 | 162 | | |
| |||
272 | 278 | | |
273 | 279 | | |
274 | 280 | | |
275 | | - | |
| 281 | + | |
| 282 | + | |
| 283 | + | |
| 284 | + | |
| 285 | + | |
| 286 | + | |
| 287 | + | |
276 | 288 | | |
277 | 289 | | |
278 | 290 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
26 | 26 | | |
27 | 27 | | |
28 | 28 | | |
| 29 | + | |
29 | 30 | | |
30 | 31 | | |
31 | 32 | | |
| |||
89 | 90 | | |
90 | 91 | | |
91 | 92 | | |
92 | | - | |
93 | | - | |
| 93 | + | |
| 94 | + | |
| 95 | + | |
| 96 | + | |
| 97 | + | |
| 98 | + | |
| 99 | + | |
| 100 | + | |
| 101 | + | |
94 | 102 | | |
95 | 103 | | |
96 | 104 | | |
| |||
100 | 108 | | |
101 | 109 | | |
102 | 110 | | |
103 | | - | |
| 111 | + | |
104 | 112 | | |
105 | 113 | | |
106 | 114 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
151 | 151 | | |
152 | 152 | | |
153 | 153 | | |
| 154 | + | |
154 | 155 | | |
155 | 156 | | |
156 | 157 | | |
| |||
218 | 219 | | |
219 | 220 | | |
220 | 221 | | |
221 | | - | |
| 222 | + | |
| 223 | + | |
| 224 | + | |
| 225 | + | |
| 226 | + | |
| 227 | + | |
| 228 | + | |
| 229 | + | |
222 | 230 | | |
223 | 231 | | |
224 | 232 | | |
| |||
228 | 236 | | |
229 | 237 | | |
230 | 238 | | |
231 | | - | |
| 239 | + | |
232 | 240 | | |
233 | 241 | | |
234 | 242 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
223 | 223 | | |
224 | 224 | | |
225 | 225 | | |
| 226 | + | |
226 | 227 | | |
227 | 228 | | |
228 | 229 | | |
| |||
325 | 326 | | |
326 | 327 | | |
327 | 328 | | |
328 | | - | |
| 329 | + | |
| 330 | + | |
| 331 | + | |
| 332 | + | |
329 | 333 | | |
330 | 334 | | |
331 | 335 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
27 | 27 | | |
28 | 28 | | |
29 | 29 | | |
| 30 | + | |
30 | 31 | | |
31 | 32 | | |
32 | 33 | | |
| |||
93 | 94 | | |
94 | 95 | | |
95 | 96 | | |
96 | | - | |
| 97 | + | |
| 98 | + | |
| 99 | + | |
| 100 | + | |
| 101 | + | |
| 102 | + | |
| 103 | + | |
97 | 104 | | |
98 | 105 | | |
99 | 106 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
156 | 156 | | |
157 | 157 | | |
158 | 158 | | |
| 159 | + | |
159 | 160 | | |
160 | 161 | | |
161 | 162 | | |
| |||
224 | 225 | | |
225 | 226 | | |
226 | 227 | | |
227 | | - | |
228 | | - | |
| 228 | + | |
| 229 | + | |
| 230 | + | |
| 231 | + | |
| 232 | + | |
| 233 | + | |
| 234 | + | |
| 235 | + | |
| 236 | + | |
229 | 237 | | |
230 | 238 | | |
231 | 239 | | |
| |||
235 | 243 | | |
236 | 244 | | |
237 | 245 | | |
238 | | - | |
| 246 | + | |
239 | 247 | | |
240 | 248 | | |
241 | 249 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
147 | 147 | | |
148 | 148 | | |
149 | 149 | | |
| 150 | + | |
150 | 151 | | |
151 | 152 | | |
152 | 153 | | |
| |||
215 | 216 | | |
216 | 217 | | |
217 | 218 | | |
218 | | - | |
219 | | - | |
| 219 | + | |
| 220 | + | |
| 221 | + | |
| 222 | + | |
| 223 | + | |
| 224 | + | |
| 225 | + | |
| 226 | + | |
| 227 | + | |
220 | 228 | | |
221 | 229 | | |
222 | 230 | | |
| |||
226 | 234 | | |
227 | 235 | | |
228 | 236 | | |
229 | | - | |
| 237 | + | |
230 | 238 | | |
231 | 239 | | |
232 | 240 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
26 | 26 | | |
27 | 27 | | |
28 | 28 | | |
| 29 | + | |
29 | 30 | | |
30 | 31 | | |
31 | 32 | | |
| |||
89 | 90 | | |
90 | 91 | | |
91 | 92 | | |
92 | | - | |
93 | | - | |
| 93 | + | |
| 94 | + | |
| 95 | + | |
| 96 | + | |
| 97 | + | |
| 98 | + | |
| 99 | + | |
| 100 | + | |
| 101 | + | |
94 | 102 | | |
95 | 103 | | |
96 | 104 | | |
| |||
100 | 108 | | |
101 | 109 | | |
102 | 110 | | |
103 | | - | |
| 111 | + | |
104 | 112 | | |
105 | 113 | | |
106 | 114 | | |
| |||
0 commit comments