Commit 222a6b6

Use logits_to_keep logic for training runs (#696)
## Summary

This PR fixes #694: it slices the model outputs according to `logits_to_keep` before calculating the loss during training.

## Testing Done

- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

Co-authored-by: Shivam Sahni <shivam15800@gmail.com>
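For context, `logits_to_keep` follows the usual Hugging Face convention: an `int` keeps only the last N sequence positions (with 0 keeping everything, since `slice(-0, None)` is `slice(0, None)`), while a tensor is used directly as an index. A minimal standalone sketch of the slicing this PR now applies on the training path too (shapes are made up for illustration):

```python
import torch

# hypothetical shapes: (batch, seq_len, hidden_size)
hidden_states = torch.randn(2, 16, 64)
logits_to_keep = 4  # keep only the last 4 positions

# the exact expression used in the patched forward passes below
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
kept_hidden_states = hidden_states[:, slice_indices, :]

print(kept_hidden_states.shape)  # torch.Size([2, 4, 64])
```

Slicing the hidden states (width `hidden_size`) rather than the logits (width `vocab_size`) is the cheap place to do it, since `vocab_size` is typically far larger than `hidden_size`.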
1 parent 7c7edeb commit 222a6b6

10 files changed, +50 -30 lines changed

src/liger_kernel/transformers/model/gemma.py

Lines changed: 5 additions & 3 deletions
```diff
@@ -200,23 +200,25 @@ def lce_forward(
     )
 
     hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
 
     shift_labels = loss_kwargs.pop("shift_labels", None)
     logits = None
     loss = None
     # if in training mode, don't materialize logits
     if self.training and (labels is not None or shift_labels is not None):
         loss = LigerForCausalLMLoss(
-            hidden_states=hidden_states,
+            hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
             shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
             **loss_kwargs,
         )
     else:  # if in inference mode materialize logits
-        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
-        logits = self.lm_head(hidden_states[:, slice_indices, :])
+        logits = self.lm_head(kept_hidden_states)
         if labels is not None:
             loss = self.loss_function(
                 logits=logits,
```
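The inference branch now reuses the same `kept_hidden_states`, so only the kept positions ever hit the LM head. A hedged sketch of the saving, with illustrative sizes (none of these numbers come from the repo):

```python
import torch

batch, seq, hidden, vocab = 1, 2048, 1024, 32000
lm_head = torch.nn.Linear(hidden, vocab, bias=False)
hidden_states = torch.randn(batch, seq, hidden)

# decoding typically needs only the final position's logits
logits_to_keep = 1
kept = hidden_states[:, slice(-logits_to_keep, None), :]

full = lm_head(hidden_states)  # (1, 2048, 32000): 2048 rows of logits
last = lm_head(kept)           # (1, 1, 32000): just the row we need
assert torch.allclose(full[:, -1:, :], last)
```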

src/liger_kernel/transformers/model/gemma2.py

Lines changed: 5 additions & 3 deletions
```diff
@@ -212,14 +212,17 @@ def lce_forward(
     )
 
     hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
 
     shift_labels = loss_kwargs.pop("shift_labels", None)
     logits = None
     loss = None
     # if in training mode, don't materialize logits
     if self.training and (labels is not None or shift_labels is not None):
         loss = LigerForCausalLMLoss(
-            hidden_states=hidden_states,
+            hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
             shift_labels=shift_labels,
@@ -229,8 +232,7 @@ def lce_forward(
         )
 
     else:  # if in inference mode materialize logits
-        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
-        logits = self.lm_head(hidden_states[:, slice_indices, :])
+        logits = self.lm_head(kept_hidden_states)
         if self.config.final_logit_softcapping is not None:
             logits = logits / self.config.final_logit_softcapping
             logits = torch.tanh(logits)
```
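This hunk also sits next to Gemma 2's final-logit softcapping, which the inference branch applies to the materialized logits. As implemented upstream in `transformers`, the transform divides by the cap, applies `tanh`, then rescales by the cap; a standalone sketch (the `softcap` helper name is ours):

```python
import torch

def softcap(logits: torch.Tensor, cap: float) -> torch.Tensor:
    # tanh squashing keeps every logit strictly inside (-cap, cap)
    return cap * torch.tanh(logits / cap)

x = torch.tensor([-100.0, -1.0, 0.0, 1.0, 100.0])
print(softcap(x, 30.0))  # extremes saturate near ±30; small values pass through almost unchanged
```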

src/liger_kernel/transformers/model/glm4.py

Lines changed: 5 additions & 3 deletions
```diff
@@ -88,14 +88,17 @@ def lce_forward(
     )
 
     hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
 
     shift_labels = loss_kwargs.pop("shift_labels", None)
     logits = None
     loss = None
     # if in training mode, don't materialize logits
     if self.training and (labels is not None or shift_labels is not None):
         loss = LigerForCausalLMLoss(
-            hidden_states=hidden_states,
+            hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
             shift_labels=shift_labels,
@@ -104,8 +107,7 @@ def lce_forward(
         )
 
     else:  # if in inference mode materialize logits
-        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
-        logits = self.lm_head(hidden_states[:, slice_indices, :])
+        logits = self.lm_head(kept_hidden_states)
         if labels is not None:
             loss = self.loss_function(
                 logits=logits,
```

src/liger_kernel/transformers/model/llama.py

Lines changed: 5 additions & 3 deletions
```diff
@@ -209,6 +209,9 @@ def lce_forward(
     )
 
     hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
 
     if self.config.pretraining_tp > 1:
         raise Exception("Liger Kernel does not support pretraining_tp!!")
@@ -219,7 +222,7 @@ def lce_forward(
     # if in training mode, don't materialize logits
     if self.training and (labels is not None or shift_labels is not None):
         loss = LigerForCausalLMLoss(
-            hidden_states=hidden_states,
+            hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
             shift_labels=shift_labels,
@@ -228,8 +231,7 @@ def lce_forward(
         )
 
     else:  # if in inference mode materialize logits
-        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
-        logits = self.lm_head(hidden_states[:, slice_indices, :])
+        logits = self.lm_head(kept_hidden_states)
         if labels is not None:
             loss = self.loss_function(
                 logits=logits,
```

src/liger_kernel/transformers/model/mistral.py

Lines changed: 5 additions & 3 deletions
```diff
@@ -91,14 +91,17 @@ def lce_forward(
     )
 
     hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
 
     shift_labels = loss_kwargs.pop("shift_labels", None)
     loss = None
     logits = None
 
     if self.training and (labels is not None or shift_labels is not None):
         loss = LigerForCausalLMLoss(
-            hidden_states=hidden_states,
+            hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
             shift_labels=shift_labels,
@@ -107,8 +110,7 @@ def lce_forward(
         )
 
     else:
-        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
-        logits = self.lm_head(hidden_states[:, slice_indices, :])
+        logits = self.lm_head(kept_hidden_states)
 
         loss = None
         if labels is not None:
```

src/liger_kernel/transformers/model/mixtral.py

Lines changed: 5 additions & 3 deletions
```diff
@@ -225,14 +225,17 @@ def lce_forward(
     )
 
     hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
 
     shift_labels = loss_kwargs.pop("shift_labels", None)
     logits = None
     loss = None
     # if in training mode, don't materialize logits
     if self.training and (labels is not None or shift_labels is not None):
         loss = LigerForCausalLMLoss(
-            hidden_states=hidden_states,
+            hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
             shift_labels=shift_labels,
@@ -241,8 +244,7 @@ def lce_forward(
         )
 
     else:  # if in inference mode materialize logits
-        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
-        logits = self.lm_head(hidden_states[:, slice_indices, :])
+        logits = self.lm_head(kept_hidden_states)
 
         loss = None
         if labels is not None:
```

src/liger_kernel/transformers/model/mllama.py

Lines changed: 5 additions & 3 deletions
```diff
@@ -215,14 +215,17 @@ def lce_forward(
     )
 
     hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
 
     shift_labels = loss_kwargs.pop("shift_labels", None)
     logits = None
     loss = None
     # if in training mode, don't materialize logits
     if self.training and (labels is not None or shift_labels is not None):
         loss = LigerForCausalLMLoss(
-            hidden_states=hidden_states,
+            hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
             shift_labels=shift_labels,
@@ -231,8 +234,7 @@ def lce_forward(
         )
 
     else:  # if in inference mode materialize logits
-        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
-        logits = self.lm_head(hidden_states[:, slice_indices, :])
+        logits = self.lm_head(kept_hidden_states)
         if labels is not None:
             loss = self.loss_function(
                 logits=logits,
```

src/liger_kernel/transformers/model/olmo2.py

Lines changed: 5 additions & 3 deletions
```diff
@@ -88,14 +88,17 @@ def lce_forward(
     )
 
     hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
 
     shift_labels = loss_kwargs.pop("shift_labels", None)
     logits = None
     loss = None
     # if in training mode, don't materialize logits
     if self.training and (labels is not None or shift_labels is not None):
         loss = LigerForCausalLMLoss(
-            hidden_states=hidden_states,
+            hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
             shift_labels=shift_labels,
@@ -104,8 +107,7 @@ def lce_forward(
        )
 
     else:  # if in inference mode materialize logits
-        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
-        logits = self.lm_head(hidden_states[:, slice_indices, :])
+        logits = self.lm_head(kept_hidden_states)
         if labels is not None:
             loss = self.loss_function(
                 logits=logits,
```

src/liger_kernel/transformers/model/phi3.py

Lines changed: 5 additions & 3 deletions
```diff
@@ -213,14 +213,17 @@ def lce_forward(
     )
 
     hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
 
     shift_labels = loss_kwargs.pop("shift_labels", None)
     logits = None
     loss = None
     # if in training mode, don't materialize logits
     if self.training and (labels is not None or shift_labels is not None):
         loss = LigerForCausalLMLoss(
-            hidden_states=hidden_states,
+            hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
             shift_labels=shift_labels,
@@ -229,8 +232,7 @@ def lce_forward(
         )
 
     else:  # if in inference mode materialize logits
-        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
-        logits = self.lm_head(hidden_states[:, slice_indices, :])
+        logits = self.lm_head(kept_hidden_states)
         if labels is not None:
             loss = self.loss_function(
                 logits=logits,
```

src/liger_kernel/transformers/model/qwen2.py

Lines changed: 5 additions & 3 deletions
```diff
@@ -199,14 +199,17 @@ def lce_forward(
     )
 
     hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
 
     shift_labels = loss_kwargs.pop("shift_labels", None)
     logits = None
     loss = None
     # if in training mode, don't materialize logits
     if self.training and (labels is not None or shift_labels is not None):
         loss = LigerForCausalLMLoss(
-            hidden_states=hidden_states,
+            hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
             shift_labels=shift_labels,
@@ -215,8 +218,7 @@ def lce_forward(
         )
 
     else:  # if in inference mode materialize logits
-        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
-        logits = self.lm_head(hidden_states[:, slice_indices, :])
+        logits = self.lm_head(kept_hidden_states)
         if labels is not None:
             loss = self.loss_function(
                 logits=logits,
```
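All ten patches make the same mechanical change: compute `kept_hidden_states` once, pass it to the fused loss during training, and to `lm_head` during inference. For intuition, here is a rough unfused reference of what the training branch computes (a hypothetical helper with simplified shift handling; the real `LigerForCausalLMLoss` fuses the projection and cross entropy so the `(batch, seq, vocab)` logits tensor is never materialized):

```python
import torch
import torch.nn.functional as F

def reference_causal_lm_loss(kept_hidden_states, lm_head_weight, labels):
    # Unfused reference: materializes the logits the fused kernel avoids.
    logits = kept_hidden_states @ lm_head_weight.T      # (B, T, V)
    shift_logits = logits[:, :-1, :].contiguous()       # position t predicts token t+1
    shift_labels = labels[:, 1:].contiguous()
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,  # standard HF padding/ignore label
    )
```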
