Commit 7b40ed8

Allow option to choose between fused LCE and CE loss (#704)
## Summary

More context in #703. This adds a kwarg to the forward method of models that lets users choose whether to use the fused LCE or to materialize logits and use regular CE loss. Currently this decision is made implicitly from model state.

## Testing Done

Ran this on a downstream transformers pipeline and measured peak memory consumption; peak memory was reduced by as much as 80% on some runs, with no effect on quality. Repo tests are still running, but no surprises are expected.

- Hardware Type: H100
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
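For illustration only (not part of the commit), here is a minimal usage sketch of the new kwarg. The checkpoint name is a placeholder and the patching entry point is assumed to be liger_kernel's `apply_liger_kernel_to_llama`; only the `skip_logits` kwarg comes from this change.

```python
# Hedged usage sketch: model name is illustrative; only `skip_logits` is new here.
from transformers import AutoModelForCausalLM, AutoTokenizer
from liger_kernel.transformers import apply_liger_kernel_to_llama

apply_liger_kernel_to_llama()  # patch LlamaForCausalLM.forward with the Liger lce_forward

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
model.train()

batch = tokenizer("The quick brown fox", return_tensors="pt")
labels = batch["input_ids"].clone()

# Default (skip_logits=None): in training mode with labels, the fused
# linear-cross-entropy path is used and logits are never materialized.
fused_out = model(**batch, labels=labels)
print(fused_out.loss, fused_out.logits)  # logits is None on the fused path

# Explicit opt-out: materialize logits and compute regular CE loss instead.
eager_out = model(**batch, labels=labels, skip_logits=False)
print(eager_out.loss, eager_out.logits.shape)
```

Passing `skip_logits=True` without labels raises a `ValueError`, per the guard added in the patched forwards.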
1 parent 8807a54 commit 7b40ed8

File tree

19 files changed, +225 −72 lines changed


src/liger_kernel/transformers/model/gemma.py

Lines changed: 11 additions & 3 deletions
@@ -137,6 +137,7 @@ def lce_forward(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
     **loss_kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
@@ -199,8 +200,15 @@ def lce_forward(
     shift_labels = loss_kwargs.pop("shift_labels", None)
     logits = None
     loss = None
-    # if in training mode, don't materialize logits
-    if self.training and (labels is not None or shift_labels is not None):
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
         loss = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
@@ -209,7 +217,7 @@ def lce_forward(
             hidden_size=self.config.hidden_size,
             **loss_kwargs,
         )
-    else:  # if in inference mode materialize logits
+    else:
         logits = self.lm_head(kept_hidden_states)
         if labels is not None:
             loss = self.loss_function(

src/liger_kernel/transformers/model/gemma2.py

Lines changed: 11 additions & 3 deletions
@@ -146,6 +146,7 @@ def lce_forward(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
     **loss_kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
@@ -213,8 +214,15 @@ def lce_forward(
     shift_labels = loss_kwargs.pop("shift_labels", None)
     logits = None
     loss = None
-    # if in training mode, don't materialize logits
-    if self.training and (labels is not None or shift_labels is not None):
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
         loss = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
@@ -225,7 +233,7 @@ def lce_forward(
             **loss_kwargs,
         )

-    else:  # if in inference mode materialize logits
+    else:
         logits = self.lm_head(kept_hidden_states)
         if self.config.final_logit_softcapping is not None:
             logits = logits / self.config.final_logit_softcapping

src/liger_kernel/transformers/model/gemma3.py

Lines changed: 14 additions & 2 deletions
@@ -35,6 +35,7 @@ def causal_forward(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
     **loss_kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
@@ -101,7 +102,11 @@ def causal_forward(
     shift_labels = loss_kwargs.pop("shift_labels", None)
     loss = None
     logits = None
-    if self.training and (labels is not None or shift_labels is not None):
+
+    if skip_logits is None:
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
         loss = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
@@ -151,6 +156,7 @@ def multimodal_forward(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
     **lm_kwargs,
 ) -> Union[Tuple, Gemma3CausalLMOutputWithPast]:
     r"""
@@ -272,7 +278,13 @@ def multimodal_forward(
     loss = None
     logits = None

-    if self.training and (labels is not None):
+    if skip_logits and labels is None:
+        raise ValueError("skip_logits is True, but labels is None")
+
+    if skip_logits is None:
+        skip_logits = self.training and (labels is not None)
+
+    if skip_logits:
         shift_hidden_states = hidden_states[..., :-1, :]
         shift_labels = labels[..., 1:]

src/liger_kernel/transformers/model/glm4.py

Lines changed: 11 additions & 3 deletions
@@ -26,6 +26,7 @@ def lce_forward(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
     **loss_kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
@@ -89,8 +90,15 @@ def lce_forward(
     shift_labels = loss_kwargs.pop("shift_labels", None)
     logits = None
     loss = None
-    # if in training mode, don't materialize logits
-    if self.training and (labels is not None or shift_labels is not None):
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
         loss = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
@@ -100,7 +108,7 @@ def lce_forward(
             **loss_kwargs,
         )

-    else:  # if in inference mode materialize logits
+    else:
         logits = self.lm_head(kept_hidden_states)
         if labels is not None:
             loss = self.loss_function(

src/liger_kernel/transformers/model/llama.py

Lines changed: 10 additions & 2 deletions
@@ -151,6 +151,7 @@ def lce_forward(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
     **loss_kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
@@ -218,7 +219,14 @@ def lce_forward(
     logits = None
     loss = None
     # if in training mode, don't materialize logits
-    if self.training and (labels is not None or shift_labels is not None):
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
         loss = lce_maybe_trainable_lm_head(
             self,
             hidden_states=kept_hidden_states,
@@ -228,7 +236,7 @@ def lce_forward(
             **loss_kwargs,
         )

-    else:  # if in inference mode materialize logits
+    else:
         logits = self.lm_head(kept_hidden_states)
         if labels is not None:
             loss = self.loss_function(

src/liger_kernel/transformers/model/llava.py

Lines changed: 5 additions & 1 deletion
@@ -223,6 +223,7 @@ def lce_forward(
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
     image_sizes: torch.Tensor = None,
+    skip_logits: Optional[bool] = None,
     **lm_kwargs,
 ) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
     r"""
@@ -325,7 +326,10 @@ def lce_forward(
     loss = None
     logits = None

-    if self.training and (labels is not None):
+    # Overwrite skip_logits, since llava never materializes logits
+    skip_logits = labels is not None
+
+    if skip_logits:
         # Shift so that tokens < n predict n
         if attention_mask is not None:
             # we use the input attention mask to shift the logits and labels, because it is 2D.

src/liger_kernel/transformers/model/mistral.py

Lines changed: 8 additions & 1 deletion
@@ -27,6 +27,7 @@ def lce_forward(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
     **loss_kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
@@ -93,7 +94,13 @@ def lce_forward(
     loss = None
     logits = None

-    if self.training and (labels is not None or shift_labels is not None):
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
         loss = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,

src/liger_kernel/transformers/model/mixtral.py

Lines changed: 11 additions & 3 deletions
@@ -156,6 +156,7 @@ def lce_forward(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
     **loss_kwargs,
 ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
     r"""
@@ -224,8 +225,15 @@ def lce_forward(
     shift_labels = loss_kwargs.pop("shift_labels", None)
     logits = None
     loss = None
-    # if in training mode, don't materialize logits
-    if self.training and (labels is not None or shift_labels is not None):
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
         loss = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
@@ -235,7 +243,7 @@ def lce_forward(
             **loss_kwargs,
         )

-    else:  # if in inference mode materialize logits
+    else:
         logits = self.lm_head(kept_hidden_states)

         loss = None

src/liger_kernel/transformers/model/mllama.py

Lines changed: 11 additions & 3 deletions
@@ -147,6 +147,7 @@ def lce_forward(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
     **loss_kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
@@ -215,8 +216,15 @@ def lce_forward(
     shift_labels = loss_kwargs.pop("shift_labels", None)
     logits = None
     loss = None
-    # if in training mode, don't materialize logits
-    if self.training and (labels is not None or shift_labels is not None):
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
         loss = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
@@ -226,7 +234,7 @@ def lce_forward(
             **loss_kwargs,
         )

-    else:  # if in inference mode materialize logits
+    else:
         logits = self.lm_head(kept_hidden_states)
         if labels is not None:
             loss = self.loss_function(

src/liger_kernel/transformers/model/olmo2.py

Lines changed: 11 additions & 3 deletions
@@ -26,6 +26,7 @@ def lce_forward(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
     **loss_kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
@@ -89,8 +90,15 @@ def lce_forward(
     shift_labels = loss_kwargs.pop("shift_labels", None)
     logits = None
     loss = None
-    # if in training mode, don't materialize logits
-    if self.training and (labels is not None or shift_labels is not None):
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
         loss = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
@@ -100,7 +108,7 @@ def lce_forward(
             **loss_kwargs,
        )

-    else:  # if in inference mode materialize logits
+    else:
         logits = self.lm_head(kept_hidden_states)
         if labels is not None:
             loss = self.loss_function(
