
Commit 4d11fee

fix: don't drop kwargs from huggingface forward (#708)
## Summary

HuggingFace forward passes kwargs through to the underlying model:
https://github.com/huggingface/transformers/blob/716819b8309324302e00a3488a3c3d6faa427f79/src/transformers/models/qwen2/modeling_qwen2.py#L712

This matters for computing FlashAttention kwargs once, outside the forward, so they are not recomputed in every attention layer, which causes a number of issues: huggingface/transformers#35588

## Testing Done

- Hardware Type: H100
- [ ] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
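For illustration, here is a minimal, self-contained sketch of the pass-through pattern this commit adopts. `TinyBackbone`, `TinyCausalLM`, and `some_precomputed_kwarg` are illustrative stand-ins, not the real HuggingFace or Liger Kernel classes: before the change, the wrapper's `**loss_kwargs` were used only for the loss and never reached `self.model`; after it, leftover `**kwargs` are forwarded, so attention metadata computed once by the caller survives.

```python
# Minimal sketch of the kwargs pass-through pattern (illustrative names only,
# not the actual HuggingFace or Liger Kernel code).
from typing import Optional

import torch
import torch.nn as nn


class TinyBackbone(nn.Module):
    """Stand-in for `self.model`, which accepts and distributes **kwargs."""

    def __init__(self, vocab_size: int = 128, hidden_size: int = 16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)

    def forward(self, input_ids: torch.LongTensor, **kwargs) -> torch.Tensor:
        # In transformers, extra kwargs (e.g. precomputed attention metadata)
        # are handed from here to every decoder layer.
        return self.embed(input_ids)


class TinyCausalLM(nn.Module):
    """Stand-in for the patched lce_forward wrapper."""

    def __init__(self, vocab_size: int = 128, hidden_size: int = 16):
        super().__init__()
        self.model = TinyBackbone(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(
        self,
        input_ids: torch.LongTensor,
        labels: Optional[torch.LongTensor] = None,
        **kwargs,  # previously **loss_kwargs, which never reached self.model
    ):
        # Forward everything to the backbone instead of silently dropping it.
        hidden_states = self.model(input_ids, **kwargs)
        # Loss-only kwargs are popped afterwards, mirroring the diffs below.
        shift_labels = kwargs.pop("shift_labels", None)
        logits = self.lm_head(hidden_states)
        loss = None
        if labels is not None or shift_labels is not None:
            loss = logits.float().mean()  # placeholder loss for the sketch
        return logits, loss


# Kwargs computed once per batch by the caller now reach the backbone because
# the wrapper no longer swallows them.
model = TinyCausalLM()
ids = torch.randint(0, 128, (1, 8))
logits, loss = model(ids, labels=ids, some_precomputed_kwarg=None)
```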
1 parent 00fa58a · commit 4d11fee

13 files changed (+65, -50 lines)


src/liger_kernel/transformers/model/gemma.py

Lines changed: 5 additions & 4 deletions
```diff
@@ -138,7 +138,7 @@ def lce_forward(
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
-    **loss_kwargs,
+    **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -190,14 +190,15 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )
 
     hidden_states = outputs[0]
     # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
     kept_hidden_states = hidden_states[:, slice_indices, :]
 
-    shift_labels = loss_kwargs.pop("shift_labels", None)
+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
 
@@ -215,7 +216,7 @@ def lce_forward(
             labels=labels,
             shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
-            **loss_kwargs,
+            **kwargs,
         )
     else:
         logits = self.lm_head(kept_hidden_states)
@@ -224,7 +225,7 @@ def lce_forward(
             logits=logits,
             labels=labels,
             vocab_size=self.config.vocab_size,
-            **loss_kwargs,
+            **kwargs,
         )
 
     if not return_dict:
```

src/liger_kernel/transformers/model/gemma2.py

Lines changed: 7 additions & 4 deletions
```diff
@@ -30,6 +30,7 @@ def lce_forward_deprecated(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
+    **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -76,6 +77,7 @@ def lce_forward_deprecated(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )
 
     hidden_states = outputs[0]
@@ -147,7 +149,7 @@ def lce_forward(
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
-    **loss_kwargs,
+    **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -204,14 +206,15 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )
 
     hidden_states = outputs[0]
     # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
     kept_hidden_states = hidden_states[:, slice_indices, :]
 
-    shift_labels = loss_kwargs.pop("shift_labels", None)
+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
 
@@ -230,7 +233,7 @@ def lce_forward(
             shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
             final_logit_softcapping=self.config.final_logit_softcapping,
-            **loss_kwargs,
+            **kwargs,
         )
 
     else:
@@ -242,7 +245,7 @@ def lce_forward(
 
         loss = None
         if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
 
     if not return_dict:
         output = (logits,) + outputs[1:]
```

src/liger_kernel/transformers/model/glm4.py

Lines changed: 5 additions & 4 deletions
```diff
@@ -27,7 +27,7 @@ def lce_forward(
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
-    **loss_kwargs,
+    **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -80,14 +80,15 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )
 
     hidden_states = outputs[0]
     # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
     kept_hidden_states = hidden_states[:, slice_indices, :]
 
-    shift_labels = loss_kwargs.pop("shift_labels", None)
+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
 
@@ -105,7 +106,7 @@ def lce_forward(
             labels=labels,
             shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
-            **loss_kwargs,
+            **kwargs,
         )
 
     else:
@@ -115,7 +116,7 @@ def lce_forward(
             logits=logits,
             labels=labels,
             vocab_size=self.config.vocab_size,
-            **loss_kwargs,
+            **kwargs,
         )
 
     return CausalLMOutputWithPast(
```

src/liger_kernel/transformers/model/llama.py

Lines changed: 5 additions & 4 deletions
```diff
@@ -152,7 +152,7 @@ def lce_forward(
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
-    **loss_kwargs,
+    **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -205,6 +205,7 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )
 
     hidden_states = outputs[0]
@@ -215,7 +216,7 @@ def lce_forward(
     if self.config.pretraining_tp > 1:
         raise Exception("Liger Kernel does not support pretraining_tp!!")
 
-    shift_labels = loss_kwargs.pop("shift_labels", None)
+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
     # if in training mode, don't materialize logits
@@ -233,7 +234,7 @@ def lce_forward(
             hidden_size=self.config.hidden_size,
             labels=labels,
             shift_labels=shift_labels,
-            **loss_kwargs,
+            **kwargs,
         )
 
     else:
@@ -243,7 +244,7 @@ def lce_forward(
             logits=logits,
             labels=labels,
             vocab_size=self.config.vocab_size,
-            **loss_kwargs,
+            **kwargs,
         )
 
     if not return_dict:
```

src/liger_kernel/transformers/model/mistral.py

Lines changed: 5 additions & 4 deletions
```diff
@@ -28,7 +28,7 @@ def lce_forward(
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
-    **loss_kwargs,
+    **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Copy paste Mistral's forward but replace torch cross entropy with liger fused linear cross entropy
@@ -83,14 +83,15 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )
 
     hidden_states = outputs[0]
     # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
     kept_hidden_states = hidden_states[:, slice_indices, :]
 
-    shift_labels = loss_kwargs.pop("shift_labels", None)
+    shift_labels = kwargs.pop("shift_labels", None)
     loss = None
     logits = None
 
@@ -107,7 +108,7 @@ def lce_forward(
             labels=labels,
             shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
-            **loss_kwargs,
+            **kwargs,
         )
 
     else:
@@ -119,7 +120,7 @@ def lce_forward(
             logits=logits,
             labels=labels,
             vocab_size=self.config.vocab_size,
-            **loss_kwargs,
+            **kwargs,
         )
     if not return_dict:
         output = (logits,) + outputs[1:]
```

src/liger_kernel/transformers/model/mixtral.py

Lines changed: 5 additions & 4 deletions
```diff
@@ -157,7 +157,7 @@ def lce_forward(
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
-    **loss_kwargs,
+    **kwargs,
 ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
     r"""
     Args:
@@ -215,14 +215,15 @@ def lce_forward(
         output_router_logits=output_router_logits,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )
 
     hidden_states = outputs[0]
     # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
     kept_hidden_states = hidden_states[:, slice_indices, :]
 
-    shift_labels = loss_kwargs.pop("shift_labels", None)
+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
 
@@ -240,15 +241,15 @@ def lce_forward(
             labels=labels,
             shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
-            **loss_kwargs,
+            **kwargs,
         )
 
     else:
         logits = self.lm_head(kept_hidden_states)
 
         loss = None
         if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
     aux_loss = None
     if output_router_logits:
         aux_loss = load_balancing_loss_func(
```

src/liger_kernel/transformers/model/mllama.py

Lines changed: 5 additions & 4 deletions
```diff
@@ -148,7 +148,7 @@ def lce_forward(
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
-    **loss_kwargs,
+    **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -206,14 +206,15 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )
 
     hidden_states = outputs[0]
     # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
     kept_hidden_states = hidden_states[:, slice_indices, :]
 
-    shift_labels = loss_kwargs.pop("shift_labels", None)
+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
 
@@ -231,7 +232,7 @@ def lce_forward(
             labels=labels,
             shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
-            **loss_kwargs,
+            **kwargs,
         )
 
     else:
@@ -241,7 +242,7 @@ def lce_forward(
             logits=logits,
             labels=labels,
             vocab_size=self.config.vocab_size,
-            **loss_kwargs,
+            **kwargs,
         )
 
     if not return_dict:
```

src/liger_kernel/transformers/model/olmo2.py

Lines changed: 5 additions & 4 deletions
```diff
@@ -27,7 +27,7 @@ def lce_forward(
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
-    **loss_kwargs,
+    **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -80,14 +80,15 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )
 
     hidden_states = outputs[0]
     # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
     kept_hidden_states = hidden_states[:, slice_indices, :]
 
-    shift_labels = loss_kwargs.pop("shift_labels", None)
+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
 
@@ -105,7 +106,7 @@ def lce_forward(
             labels=labels,
             shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
-            **loss_kwargs,
+            **kwargs,
         )
 
     else:
@@ -115,7 +116,7 @@ def lce_forward(
             logits=logits,
             labels=labels,
             vocab_size=self.config.vocab_size,
-            **loss_kwargs,
+            **kwargs,
        )
 
     return CausalLMOutputWithPast(
```
