
Commit ea3ac1b

Tcc0403, lancerts, yundai424, and vaibhavjindal authored

Fix llava eval mode (#714)
## Summary

Llava is missing logits in eval mode: the fused linear cross entropy path computes the loss directly from the hidden states and the `lm_head` weight without ever materializing logits, so the non-training path now computes them via `lm_head` and falls back to a standard shifted `CrossEntropyLoss` when labels are provided.

## Testing Done

- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

---------

Signed-off-by: Tcc0403 <76503978+Tcc0403@users.noreply.github.com>
Co-authored-by: Shao Tang <tangshao28@gmail.com>
Co-authored-by: Yun Dai <yundai424@gmail.com>
Co-authored-by: Vaibhav Jindal <vaibhav.jndl@gmail.com>
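To make the gap concrete, here is a minimal, hypothetical sketch of the train/eval split this commit introduces (simplified names, not the actual patched forward; `fused_loss_fn` stands in for the Liger fused linear cross entropy used in the real code): the fused training path never materializes logits, so the eval path has to compute them explicitly and fall back to a plain shifted cross entropy.

```python
# A minimal, hypothetical sketch of the control flow this commit introduces
# (simplified names; `fused_loss_fn` stands in for the fused linear cross
# entropy used in the real forward).
import torch
from torch.nn import CrossEntropyLoss


def loss_and_logits(lm_head, hidden_states, labels, training, fused_loss_fn=None):
    loss, logits = None, None
    if training and labels is not None and fused_loss_fn is not None:
        # Training path: the fused loss consumes the lm_head weight and the
        # hidden states directly, so full logits are never materialized.
        shift_hidden = hidden_states[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        loss = fused_loss_fn(
            lm_head.weight,
            shift_hidden.view(-1, shift_hidden.size(-1)),
            shift_labels.view(-1),
        )
    else:
        # Eval path (the case this commit fixes): materialize logits so the
        # caller gets them back, and compute a standard shifted cross entropy
        # when labels are provided.
        logits = lm_head(hidden_states)
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = CrossEntropyLoss()(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1).to(shift_logits.device),
            )
    return loss, logits


# Eval-mode usage: logits are now populated instead of being left unset.
lm_head = torch.nn.Linear(16, 32, bias=False)
hidden = torch.randn(2, 6, 16)
labels = torch.randint(0, 32, (2, 6))
loss, logits = loss_and_logits(lm_head, hidden, labels, training=False)
print(logits.shape, loss.item())  # torch.Size([2, 6, 32]) and a scalar loss
```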
1 parent b53d954 commit ea3ac1b

File tree

1 file changed (+37 -1 lines changed)

  • src/liger_kernel/transformers/model/llava.py

src/liger_kernel/transformers/model/llava.py

Lines changed: 37 additions & 1 deletion
@@ -5,6 +5,7 @@
 
 import torch
 
+from torch.nn import CrossEntropyLoss
 from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast
 from transformers.utils import is_torchdynamo_compiling
 from transformers.utils.deprecation import deprecate_kwarg
@@ -189,7 +190,20 @@ def lce_forward_deprecated(
 
             lce = LigerFusedLinearCrossEntropyLoss()
             loss = lce(self.language_model.lm_head.weight, shift_hidden_states, shift_labels)
-
+        else:
+            logits = self.language_model.lm_head(hidden_states)
+            if labels is not None:
+                # Shift so that tokens < n predict n
+                if attention_mask is not None:
+                    shift_attention_mask = attention_mask[..., 1:]
+                    shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
+                    shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
+                else:
+                    shift_logits = logits[..., :-1, :].contiguous()
+                    shift_labels = labels[..., 1:].contiguous()
+                # Flatten the tokens
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device))
         if not return_dict:
             # NOTE: This part has not been tested.
             output = outputs[1:]
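The eval-path loss in the hunk above follows the stock shift-and-mask cross entropy pattern from Hugging Face's LLaVA modeling code: drop the last position's logits, drop the first position's labels, and keep only positions the attention mask marks as real. A small standalone illustration with toy, hypothetical shapes:

```python
# Toy illustration (hypothetical values, not part of the patch) of the
# "tokens < n predict n" shift plus attention-mask filtering used above.
import torch
from torch.nn import CrossEntropyLoss

batch, seq_len, vocab = 2, 5, 11
logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))
attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                               [1, 1, 1, 0, 0]])  # second sequence is padded

# Drop the last position's logits and the first position's labels, then keep
# only positions whose shifted attention mask is non-zero.
shift_attention_mask = attention_mask[..., 1:]
shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous()
shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()

loss = CrossEntropyLoss()(shift_logits.view(-1, vocab), shift_labels.view(-1))
print(shift_logits.shape, shift_labels.shape, loss.item())
# shift_logits is (6, 11): 4 kept positions from the first sequence + 2 from the second.
```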
@@ -349,6 +363,28 @@ def lce_forward(
                 shift_hidden_states.view(-1, shift_hidden_states.size(-1)),
                 shift_labels.view(-1).to(shift_hidden_states.device),
             )
+        else:
+            logits = self.language_model.lm_head(hidden_states)
+            if labels is not None:
+                # Upcast to float if we need to compute the loss to avoid potential precision issues
+                logits = logits.float()
+                shift_logits = logits[..., :-1, :]
+                shift_labels = labels[..., 1:]
+                if attention_mask is not None:
+                    # we use the input attention mask to shift the logits and labels, because it is 2D.
+                    # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
+                    shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
+                    shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
+                    shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
+                else:
+                    shift_logits = shift_logits.contiguous()
+                    shift_labels = shift_labels.contiguous()
+                # Flatten the tokens
+                loss_fct = CrossEntropyLoss()
+
+                flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
+                flat_labels = shift_labels.view(-1).to(shift_logits.device)
+                loss = loss_fct(flat_logits, flat_labels)
 
         if not return_dict:
             # NOTE: This part has not been tested.
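Compared with the deprecated path, this hunk adds two details: logits are upcast to `float` before the loss, and the attention mask is cropped from the right so it still lines up when it is longer than the logits sequence (the PrefixTuning-with-peft case the comment mentions). A toy, hypothetical illustration of the crop:

```python
# Toy illustration (hypothetical shapes, not part of the patch) of cropping a
# longer attention mask from the right before the shifted cross entropy.
import torch
from torch.nn import CrossEntropyLoss

seq_len, num_virtual_tokens, vocab = 5, 3, 7
logits = torch.randn(1, seq_len, vocab).float()   # upcast mirrors the diff
labels = torch.randint(0, vocab, (1, seq_len))
# With prefix tuning the mask also covers virtual prefix tokens, so it is
# longer than the logits sequence.
attention_mask = torch.ones(1, num_virtual_tokens + seq_len)

shift_logits = logits[..., :-1, :]                # (1, 4, vocab)
shift_labels = labels[..., 1:]                    # (1, 4)
# Crop the mask to the last shift_logits.shape[1] positions so shapes match.
shift_attention_mask = attention_mask[:, -shift_logits.shape[1]:]
shift_logits = shift_logits[shift_attention_mask != 0].contiguous()
shift_labels = shift_labels[shift_attention_mask != 0].contiguous()

loss = CrossEntropyLoss()(shift_logits.view(-1, vocab), shift_labels.view(-1))
print(shift_logits.shape, shift_labels.shape, loss.item())  # (4, 7), (4,)
```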

0 commit comments
