Commit 222a6b6
Use logits_to_keep logic for training runs (#696)
## Summary
<!--- This is a required section; please describe the main purpose of
this proposed code change. --->
This PR fixes #694.
It adds a change to slice the outputs according to `logits_to_keep`
before calculating the loss during the training.
<!---
## Details
This is an optional section; is there anything specific that reviewers
should be aware of?
--->
## Testing Done
<!--- This is a required section; please describe how this change was
tested. --->
<!--
Replace BLANK with your device type. For example, A100-80G-PCIe
Complete the following tasks before sending your PR, and replace `[ ]`
with
`[x]` to indicate you have done them.
-->
- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
Co-authored-by: Shivam Sahni <shivam15800@gmail.com>1 parent 7c7edeb commit 222a6b6
File tree
10 files changed
+50
-30
lines changed- src/liger_kernel/transformers/model
10 files changed
+50
-30
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
200 | 200 | | |
201 | 201 | | |
202 | 202 | | |
| 203 | + | |
| 204 | + | |
| 205 | + | |
203 | 206 | | |
204 | 207 | | |
205 | 208 | | |
206 | 209 | | |
207 | 210 | | |
208 | 211 | | |
209 | 212 | | |
210 | | - | |
| 213 | + | |
211 | 214 | | |
212 | 215 | | |
213 | 216 | | |
214 | 217 | | |
215 | 218 | | |
216 | 219 | | |
217 | 220 | | |
218 | | - | |
219 | | - | |
| 221 | + | |
220 | 222 | | |
221 | 223 | | |
222 | 224 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
212 | 212 | | |
213 | 213 | | |
214 | 214 | | |
| 215 | + | |
| 216 | + | |
| 217 | + | |
215 | 218 | | |
216 | 219 | | |
217 | 220 | | |
218 | 221 | | |
219 | 222 | | |
220 | 223 | | |
221 | 224 | | |
222 | | - | |
| 225 | + | |
223 | 226 | | |
224 | 227 | | |
225 | 228 | | |
| |||
229 | 232 | | |
230 | 233 | | |
231 | 234 | | |
232 | | - | |
233 | | - | |
| 235 | + | |
234 | 236 | | |
235 | 237 | | |
236 | 238 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
88 | 88 | | |
89 | 89 | | |
90 | 90 | | |
| 91 | + | |
| 92 | + | |
| 93 | + | |
91 | 94 | | |
92 | 95 | | |
93 | 96 | | |
94 | 97 | | |
95 | 98 | | |
96 | 99 | | |
97 | 100 | | |
98 | | - | |
| 101 | + | |
99 | 102 | | |
100 | 103 | | |
101 | 104 | | |
| |||
104 | 107 | | |
105 | 108 | | |
106 | 109 | | |
107 | | - | |
108 | | - | |
| 110 | + | |
109 | 111 | | |
110 | 112 | | |
111 | 113 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
209 | 209 | | |
210 | 210 | | |
211 | 211 | | |
| 212 | + | |
| 213 | + | |
| 214 | + | |
212 | 215 | | |
213 | 216 | | |
214 | 217 | | |
| |||
219 | 222 | | |
220 | 223 | | |
221 | 224 | | |
222 | | - | |
| 225 | + | |
223 | 226 | | |
224 | 227 | | |
225 | 228 | | |
| |||
228 | 231 | | |
229 | 232 | | |
230 | 233 | | |
231 | | - | |
232 | | - | |
| 234 | + | |
233 | 235 | | |
234 | 236 | | |
235 | 237 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
91 | 91 | | |
92 | 92 | | |
93 | 93 | | |
| 94 | + | |
| 95 | + | |
| 96 | + | |
94 | 97 | | |
95 | 98 | | |
96 | 99 | | |
97 | 100 | | |
98 | 101 | | |
99 | 102 | | |
100 | 103 | | |
101 | | - | |
| 104 | + | |
102 | 105 | | |
103 | 106 | | |
104 | 107 | | |
| |||
107 | 110 | | |
108 | 111 | | |
109 | 112 | | |
110 | | - | |
111 | | - | |
| 113 | + | |
112 | 114 | | |
113 | 115 | | |
114 | 116 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
225 | 225 | | |
226 | 226 | | |
227 | 227 | | |
| 228 | + | |
| 229 | + | |
| 230 | + | |
228 | 231 | | |
229 | 232 | | |
230 | 233 | | |
231 | 234 | | |
232 | 235 | | |
233 | 236 | | |
234 | 237 | | |
235 | | - | |
| 238 | + | |
236 | 239 | | |
237 | 240 | | |
238 | 241 | | |
| |||
241 | 244 | | |
242 | 245 | | |
243 | 246 | | |
244 | | - | |
245 | | - | |
| 247 | + | |
246 | 248 | | |
247 | 249 | | |
248 | 250 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
215 | 215 | | |
216 | 216 | | |
217 | 217 | | |
| 218 | + | |
| 219 | + | |
| 220 | + | |
218 | 221 | | |
219 | 222 | | |
220 | 223 | | |
221 | 224 | | |
222 | 225 | | |
223 | 226 | | |
224 | 227 | | |
225 | | - | |
| 228 | + | |
226 | 229 | | |
227 | 230 | | |
228 | 231 | | |
| |||
231 | 234 | | |
232 | 235 | | |
233 | 236 | | |
234 | | - | |
235 | | - | |
| 237 | + | |
236 | 238 | | |
237 | 239 | | |
238 | 240 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
88 | 88 | | |
89 | 89 | | |
90 | 90 | | |
| 91 | + | |
| 92 | + | |
| 93 | + | |
91 | 94 | | |
92 | 95 | | |
93 | 96 | | |
94 | 97 | | |
95 | 98 | | |
96 | 99 | | |
97 | 100 | | |
98 | | - | |
| 101 | + | |
99 | 102 | | |
100 | 103 | | |
101 | 104 | | |
| |||
104 | 107 | | |
105 | 108 | | |
106 | 109 | | |
107 | | - | |
108 | | - | |
| 110 | + | |
109 | 111 | | |
110 | 112 | | |
111 | 113 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
213 | 213 | | |
214 | 214 | | |
215 | 215 | | |
| 216 | + | |
| 217 | + | |
| 218 | + | |
216 | 219 | | |
217 | 220 | | |
218 | 221 | | |
219 | 222 | | |
220 | 223 | | |
221 | 224 | | |
222 | 225 | | |
223 | | - | |
| 226 | + | |
224 | 227 | | |
225 | 228 | | |
226 | 229 | | |
| |||
229 | 232 | | |
230 | 233 | | |
231 | 234 | | |
232 | | - | |
233 | | - | |
| 235 | + | |
234 | 236 | | |
235 | 237 | | |
236 | 238 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
199 | 199 | | |
200 | 200 | | |
201 | 201 | | |
| 202 | + | |
| 203 | + | |
| 204 | + | |
202 | 205 | | |
203 | 206 | | |
204 | 207 | | |
205 | 208 | | |
206 | 209 | | |
207 | 210 | | |
208 | 211 | | |
209 | | - | |
| 212 | + | |
210 | 213 | | |
211 | 214 | | |
212 | 215 | | |
| |||
215 | 218 | | |
216 | 219 | | |
217 | 220 | | |
218 | | - | |
219 | | - | |
| 221 | + | |
220 | 222 | | |
221 | 223 | | |
222 | 224 | | |
| |||
0 commit comments