Skip to content

Commit 9d58cd2

Browse files
authored
fix coverity (#3438)
1 parent 3fbf37d commit 9d58cd2

File tree

1 file changed

+2
-2
lines changed
  • intel_extension_for_pytorch/transformers/models/reference

1 file changed

+2
-2
lines changed

intel_extension_for_pytorch/transformers/models/reference/models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -803,7 +803,7 @@ def LlavaForConditionalGeneration_forward(
803803
legacy_processing = (
804804
(input_ids == self.config.image_token_index).sum(1).max()
805805
< self.config.image_seq_length
806-
) or (input_ids.shape[-1] == 1 and pixel_values is not None)
806+
) or (inputs_embeds.shape[-2] == 1 and pixel_values is not None)
807807

808808
image_features = None
809809
if pixel_values is not None:
@@ -814,7 +814,7 @@ def LlavaForConditionalGeneration_forward(
814814
)
815815
if legacy_processing:
816816
# prefill stage vs decoding stage (legacy behavior copied)
817-
if input_ids.shape[1] != 1:
817+
if inputs_embeds.shape[-2] != 1:
818818
inputs_embeds, attention_mask, labels, position_ids = (
819819
self._merge_input_ids_with_image_features(
820820
image_features, inputs_embeds, input_ids, attention_mask, labels

0 commit comments

Comments
 (0)