Skip to content

Commit 6853d5d

Browse files
authored
Fix llava format (#751)
## Summary <!--- This is a required section; please describe the main purpose of this proposed code change. ---> <!--- ## Details This is an optional section; is there anything specific that reviewers should be aware of? ---> ## Testing Done <!--- This is a required section; please describe how this change was tested. ---> <!-- Replace BLANK with your device type. For example, A100-80G-PCIe Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> - Hardware Type: <BLANK> - [ ] run `make test` to ensure correctness - [ ] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence Signed-off-by: Tcc0403 <[email protected]>
1 parent 1449c19 commit 6853d5d

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

src/liger_kernel/transformers/model/llava.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from torch.nn import CrossEntropyLoss
99
from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast
10+
from transformers.utils import is_torchdynamo_compiling
1011

1112
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
1213
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
@@ -193,6 +194,7 @@ def lce_forward_deprecated(
193194
image_hidden_states=image_features if pixel_values is not None else None,
194195
)
195196

197+
196198
def lce_forward(
197199
self,
198200
input_ids: torch.LongTensor = None,
@@ -316,7 +318,6 @@ def lce_forward(
316318
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **lm_kwargs
317319
)
318320

319-
320321
if not return_dict:
321322
output = (logits,) + outputs[1:]
322323
return (loss,) + output if loss is not None else output

src/liger_kernel/transformers/monkey_patch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ def apply_liger_kernel_to_llava(
316316
if fused_linear_cross_entropy:
317317
if transformer_version >= version.parse("4.52.0"):
318318
modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward
319-
elif transformer_version >= version.parse("4.49.0") and transformer_version < version.parse("4.52.0"):
319+
elif transformer_version >= version.parse("4.49.0") and transformer_version < version.parse("4.52.0"):
320320
modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated
321321
else: # if version < 4.49.0
322322
logger.warning(

0 commit comments

Comments
 (0)