
Commit c887610

fix switching between legacy and new processing for llava (#970)
* fix switching between legacy and new processing for llava
* extend tests
* update legacy processing path
* replace llava test model
* Update tests/openvino/test_modeling.py
1 parent 222748e commit c887610
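The change moves the legacy-vs-new decision out of `merge_vision_text_embeddings` and into `get_multimodal_embeddings`, so it is made once per prefill step from the processor output. Below is a minimal, standalone sketch of that rule, not code from the commit; the helper name is hypothetical, and `config` is assumed to be a llava config exposing `image_token_index` and, for new processing, `image_seq_length`.

```python
def choose_legacy_processing(config, input_ids, pixel_values, past_key_values, support_new_processing):
    """Return True when the legacy (pre-transformers-4.45) merging path should be used.

    Hypothetical helper mirroring the dispatch introduced in this commit;
    `input_ids` is assumed to be a torch.LongTensor of shape (batch, seq_len).
    """
    if pixel_values is not None and support_new_processing and past_key_values is None:
        # New processing is possible: the processor should already have expanded
        # "<image>" into config.image_seq_length placeholder tokens, so the prompt
        # must contain at least that many image tokens per sample.
        num_image_tokens = (input_ids == config.image_token_index).sum(1).max()
        return bool(num_image_tokens < config.image_seq_length)
    # No pixels to merge, a config without image_seq_length, or a decoding step
    # that already has past key values: stay on the legacy path.
    return True
```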

File tree

3 files changed: +104 / -53 lines changed

* optimum/intel/openvino/modeling_visual_language.py
* tests/openvino/test_modeling.py
* tests/openvino/utils_tests.py

optimum/intel/openvino/modeling_visual_language.py

Lines changed: 47 additions & 51 deletions
```diff
@@ -697,6 +697,33 @@ def can_generate(self):
 
 
 class _OVLlavaForCausalLM(OVModelForVisualCausalLM):
+    def __init__(
+        self,
+        language_model: ov.Model,
+        text_embeddings: ov.Model,
+        vision_embeddings: ov.Model,
+        config: PretrainedConfig = None,
+        device: str = "CPU",
+        dynamic_shapes: bool = True,
+        ov_config: Optional[Dict[str, str]] = None,
+        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
+        quantization_config: Union[OVWeightQuantizationConfig, Dict] = None,
+        **kwargs,
+    ):
+        super().__init__(
+            language_model=language_model,
+            text_embeddings=text_embeddings,
+            vision_embeddings=vision_embeddings,
+            config=config,
+            device=device,
+            dynamic_shapes=dynamic_shapes,
+            ov_config=ov_config,
+            model_save_dir=model_save_dir,
+            quantization_config=quantization_config,
+            **kwargs,
+        )
+        self._support_new_processing = hasattr(self.config, "image_seq_length")
+
     def get_vision_embeddings(self, pixel_values, input_ids=None, **kwargs):
         if input_ids is not None and input_ids.shape[1] == 1:
             return None
@@ -725,17 +752,11 @@ def merge_vision_text_embeddings(
         input_ids,
         attention_mask,
         position_ids=None,
-        legacy_processing=None,
+        legacy_processing=False,
         **kwargs,
     ):
         image_features = torch.from_numpy(vision_embeds) if isinstance(vision_embeds, np.ndarray) else vision_embeds
         inputs_embeds = torch.from_numpy(inputs_embeds) if isinstance(inputs_embeds, np.ndarray) else inputs_embeds
-        if legacy_processing is None:
-            legacy_processing = (
-                not hasattr(self.config, "image_seq_length")
-                or ((input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length)
-                or (input_ids.shape[-1] == 1)
-            )
 
         if legacy_processing:
             pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
@@ -768,15 +789,6 @@ def merge_vision_text_embeddings(
             final_attention_mask = torch.zeros(
                 batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
             )
-            # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
-            # set the corresponding tensors into their correct target device.
-            target_device = inputs_embeds.device
-            batch_indices, non_image_indices, text_to_overwrite = (
-                batch_indices.to(target_device),
-                non_image_indices.to(target_device),
-                text_to_overwrite.to(target_device),
-            )
-            attention_mask = attention_mask.to(target_device)
 
             # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
             # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
@@ -787,15 +799,15 @@ def merge_vision_text_embeddings(
                 (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
             )
             image_to_overwrite[batch_indices, text_to_overwrite] = False
-            image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
+            image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None]
 
             if image_to_overwrite.sum() != image_features.shape[:-1].numel():
                 raise ValueError(
                     f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
                     f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
                 )
 
-            final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
+            final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim)
             final_attention_mask |= image_to_overwrite
             position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
 
@@ -815,11 +827,12 @@ def merge_vision_text_embeddings(
     def get_multimodal_embeddings(
         self, input_ids, pixel_values=None, attention_mask=None, position_ids=None, past_key_values=None, **kwargs
     ):
-        legacy_processing = (
-            not hasattr(self.config, "image_seq_length")
-            or ((input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length)
-            or (input_ids.shape[-1] == 1 and pixel_values is not None)
-        )
+        if pixel_values is not None and self._support_new_processing and past_key_values is None:
+            legacy_processing = (input_ids == self.config.image_token_index).sum(
+                1
+            ).max() < self.config.image_seq_length
+        else:
+            legacy_processing = True
         inputs_embeds, attention_mask, position_ids = super().get_multimodal_embeddings(
             input_ids, pixel_values, attention_mask, position_ids, legacy_processing=legacy_processing, **kwargs
         )
@@ -830,38 +843,19 @@ def get_multimodal_embeddings(
         return inputs_embeds, attention_mask, position_ids
 
     def _filter_unattended_tokens(self, input_ids, attention_mask, past_key_values):
-        if not self.language_model.stateful:
-            first_layer_past_key_value = torch.from_numpy(past_key_values[0][0][:, :, :, 0])
-        else:
-            first_layer_past_key_value = torch.from_numpy(
-                self.language_model.request.query_state()[0].state.data[:, :, :, 0]
-            )
-
-        # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
-        batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
-
         # Get the target length
         target_length = input_ids.shape[1]
-        past_length = first_layer_past_key_value.shape[-1]
+        past_length = self.language_model._get_past_length(past_key_values)
 
         extended_attention_mask = torch.ones(
             (attention_mask.shape[0], past_length),
             dtype=attention_mask.dtype,
            device=attention_mask.device,
         )
 
-        # Filter out only the tokens that can be un-attended, this can happen
-        # if one uses Llava + Fused modules where the cache on the
-        # first iteration is already big enough, or if one passes custom cache
-        valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
-        new_batch_index = batch_index[valid_indices]
-        new_non_attended_tokens = non_attended_tokens[valid_indices]
-
-        # Zero-out the places where we don't need to attend
-        extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
-
         attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
-        position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
+        position_ids = torch.cumsum(attention_mask, axis=1) - 1
+        position_ids[attention_mask == 0] = 1
         return attention_mask, position_ids
 
 
@@ -938,11 +932,13 @@ def get_multimodal_embeddings(
 
         inputs_embeds = self.get_text_embeddings(input_ids, **kwargs)
 
-        legacy_processing = (
-            not hasattr(self.config, "image_seq_length")
-            or ((input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length)
-            or (input_ids.shape[-1] == 1 and pixel_values is not None)
-        )
+        if pixel_values is not None and self._support_new_processing and past_key_values is None:
+            legacy_processing = (input_ids == self.config.image_token_index).sum(
+                1
+            ).max() < self.config.image_seq_length
+        else:
+            legacy_processing = True
+
         if pixel_values is not None and pixel_values.size(0) > 0:
             # ! infer image_num_patches from image_sizes
             image_num_patches = [
@@ -996,7 +992,7 @@ def merge_vision_text_embeddings(
         input_ids,
         attention_mask,
         position_ids=None,
-        legacy_processing=None,
+        legacy_processing=False,
         **kwargs,
     ):
         image_token_index = self.config.image_token_index
```
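The rewritten `_filter_unattended_tokens` above no longer inspects the first layer's past key values: the past length now comes from `self.language_model._get_past_length(past_key_values)`, and `position_ids` are rebuilt directly from the attention mask with a cumulative sum. A small standalone illustration of that cumsum step, using made-up tensors rather than anything from the commit:

```python
import torch

# Two sequences, the first left-padded by two positions.
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])

# Running count of attended tokens, shifted so the first attended token gets position 0.
position_ids = torch.cumsum(attention_mask, axis=1) - 1
# Padded positions receive a dummy value of 1, matching the patched code.
position_ids[attention_mask == 0] = 1

print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])
```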

tests/openvino/test_modeling.py

Lines changed: 56 additions & 1 deletion
```diff
@@ -1983,12 +1983,67 @@ def test_compare_to_transformers(self, model_arch):
                 torch.equal(ov_outputs, transformers_outputs),
                 f"generation config : {gen_config}, transformers output {transformers_outputs}, ov_model output {ov_outputs}",
             )
-
         del transformers_model
         del ov_model
 
         gc.collect()
 
+    @parameterized.expand(["llava", "llava_next"])
+    @unittest.skipIf(
+        is_transformers_version("<", "4.45.0"), reason="New preprocessing available only in transformers >= 4.45"
+    )
+    def test_llava_with_new_preprocessing(self, model_arch):
+        prompt = "<image>\n What is shown in this image?"
+        model_id = MODEL_NAMES[model_arch]
+        config = AutoConfig.from_pretrained(model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS)
+        processor = AutoProcessor.from_pretrained(
+            model_id,
+            patch_size=config.vision_config.patch_size,
+            vision_feature_select_strategy=config.vision_feature_select_strategy,
+            trust_remote_code=model_arch in self.REMOTE_CODE_MODELS,
+        )
+        transformers_model = self.get_transformer_model_class(model_arch).from_pretrained(model_id)
+        ov_model = OVModelForVisualCausalLM.from_pretrained(
+            model_id, export=True, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS
+        )
+        self.assertTrue(ov_model._support_new_processing)
+        self.assertTrue(processor.patch_size is not None)
+        self.assertTrue(processor.vision_feature_select_strategy is not None)
+        inputs = processor(images=self.IMAGE, text=prompt, return_tensors="pt")
+        self.assertTrue(
+            (inputs.input_ids == ov_model.config.image_token_index).sum(1).max() >= ov_model.config.image_seq_length
+        )
+        set_seed(SEED)
+        with torch.no_grad():
+            transformers_outputs = transformers_model(**inputs)
+        set_seed(SEED)
+        ov_outputs = ov_model(**inputs)
+        self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, atol=1e-4))
+        ov_model.generation_config.eos_token_id = None
+        transformers_model.generation_config.eos_token_id = None
+        ov_model.config.eos_token_id = None
+        transformers_model.config.eos_token_id = None
+        gen_config = GenerationConfig(
+            max_new_tokens=30,
+            min_new_tokens=30,
+            num_beams=3,
+            do_sample=False,
+            eos_token_id=None,
+        )
+        set_seed(SEED)
+        ov_outputs = ov_model.generate(**inputs, generation_config=gen_config)
+        set_seed(SEED)
+        with torch.no_grad():
+            transformers_outputs = transformers_model.generate(**inputs, generation_config=gen_config)
+        self.assertTrue(
+            torch.equal(ov_outputs, transformers_outputs),
+            f"generation config : {gen_config}, transformers output {transformers_outputs}, ov_model output {ov_outputs}",
+        )
+
+        del ov_model
+        del transformers_model
+        gc.collect()
+
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
     def test_generate_utils(self, model_arch):
         model_id = MODEL_NAMES[model_arch]
```
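The llava test checkpoint is swapped in utils_tests.py below for `katuni4ka/tiny-random-llava`, which, per the assertions in the new test, supports the new processing path. A rough end-to-end usage sketch under the same assumptions as the test (transformers >= 4.45; the dummy image and generation settings here are made up):

```python
# Rough sketch of the new-processing path exercised by test_llava_with_new_preprocessing.
import numpy as np
from PIL import Image
from transformers import AutoConfig, AutoProcessor
from optimum.intel import OVModelForVisualCausalLM

model_id = "katuni4ka/tiny-random-llava"
config = AutoConfig.from_pretrained(model_id)

# Passing patch_size and vision_feature_select_strategy lets the processor expand
# "<image>" into image_seq_length placeholder tokens (the new processing).
processor = AutoProcessor.from_pretrained(
    model_id,
    patch_size=config.vision_config.patch_size,
    vision_feature_select_strategy=config.vision_feature_select_strategy,
)
model = OVModelForVisualCausalLM.from_pretrained(model_id, export=True)

# Made-up dummy image; any PIL image works for the tiny random checkpoint.
image = Image.fromarray(np.zeros((32, 32, 3), dtype=np.uint8))
inputs = processor(images=image, text="<image>\n What is shown in this image?", return_tensors="pt")

generated = model.generate(**inputs, max_new_tokens=10)
print(processor.batch_decode(generated, skip_special_tokens=True))
```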

tests/openvino/utils_tests.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -76,7 +76,7 @@
     "llama": "HuggingFaceM4/tiny-random-LlamaForCausalLM",
     "llama_awq": "HuggingFaceH4/tiny-random-LlamaForCausalLM",
     "llama_gptq": "hf-internal-testing/TinyLlama-1.1B-Chat-v0.3-GPTQ",
-    "llava": "trl-internal-testing/tiny-random-LlavaForConditionalGeneration",
+    "llava": "katuni4ka/tiny-random-llava",
     "llava_next": "katuni4ka/tiny-random-llava-next",
     "m2m_100": "hf-internal-testing/tiny-random-m2m_100",
     "opt": "hf-internal-testing/tiny-random-OPTModel",
```
