
Commit 2589a52

Fix aria tests (#39879)
* fix aria tests
* awful bug
* fix copies
* fix tests
* fix style
* revert this

Parent: 6e4a9a5

3 files changed (+4 additions, -52 deletions)
src/transformers/models/aria/modeling_aria.py

Lines changed: 1 addition & 25 deletions
```diff
@@ -1014,18 +1014,9 @@ def forward(
         past_key_values: Optional[Cache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Union[tuple, AriaModelOutputWithPast]:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
 
         if inputs_embeds is None:
             inputs_embeds = self.get_input_embeddings()(input_ids)
```
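For context on what the deleted boilerplate did: each `forward` used to resolve `output_attentions`, `output_hidden_states`, and `return_dict` against the config by hand, and after this commit those flags arrive through `**kwargs` instead. Below is a minimal, self-contained sketch of the two styles; the classes are invented for illustration, and the actual resolution in transformers happens in shared machinery behind `Unpack[TransformersKwargs]`, not per model.

```python
from typing import Optional


class Config:
    output_attentions = False  # model-wide default


class OldStyle:
    """Old pattern: each forward re-implements the config fallback."""

    def __init__(self, config: Config):
        self.config = config

    def forward(self, output_attentions: Optional[bool] = None) -> bool:
        output_attentions = (
            output_attentions if output_attentions is not None else self.config.output_attentions
        )
        return output_attentions


class NewStyle:
    """New pattern: the flag travels through **kwargs and is resolved once,
    upstream of the model body. This toy only shows the idea."""

    def __init__(self, config: Config):
        self.config = config

    def forward(self, **kwargs) -> bool:
        return kwargs.get("output_attentions", self.config.output_attentions)


config = Config()
assert OldStyle(config).forward(output_attentions=True) is True
assert NewStyle(config).forward(output_attentions=True) is True
assert NewStyle(config).forward() is False  # falls back to the config default
```

Callers keep per-call control either way; only the plumbing changes.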

```diff
@@ -1037,7 +1028,7 @@ def forward(
                 vision_feature_layer=self.config.vision_feature_layer,
             )
             image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
-            special_image_mask = self._get_image_mask(
+            special_image_mask = self.get_placeholder_mask(
                 input_ids, inputs_embeds=inputs_embeds, image_features=image_features
             )
             inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
@@ -1048,9 +1039,6 @@ def forward(
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=True,
             cache_position=cache_position,
             **kwargs,
         )
```
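The other change in this file swaps the private `_get_image_mask` helper for the shared `get_placeholder_mask`. The splice itself is unchanged: a boolean mask marks the placeholder token positions, and `masked_scatter` writes the projected image features into those slots. Below is a runnable toy version of that splice; the tensor sizes and the `image_token_index` value are invented for the example, and it skips any validation the shared helper performs.

```python
import torch

# Toy sizes and token id, invented for this example.
batch, seq_len, hidden = 1, 6, 4
image_token_index = 9  # id of the image placeholder token in this toy vocab

input_ids = torch.tensor([[1, 9, 9, 5, 6, 2]])      # two placeholder slots
inputs_embeds = torch.randn(batch, seq_len, hidden)  # text embeddings
image_features = torch.randn(2, hidden)              # one row per placeholder

# Boolean mask over the embedding positions that hold image placeholders,
# expanded across the hidden dimension so it matches inputs_embeds.
special_image_mask = (input_ids == image_token_index).unsqueeze(-1).expand_as(inputs_embeds)

# masked_scatter fills the True positions, in order, from the flattened
# source -- the same splice the hunk above performs.
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

assert torch.equal(inputs_embeds[0, 1:3], image_features)
```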
````diff
@@ -1156,9 +1144,6 @@ def forward(
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[TransformersKwargs],
@@ -1223,12 +1208,6 @@ def forward(
         >>> print(generated_texts[1])
         Assistant: The bridge is in San Francisco.
         ```"""
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         outputs = self.model(
             input_ids=input_ids,
             pixel_values=pixel_values,
@@ -1238,9 +1217,6 @@ def forward(
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
             cache_position=cache_position,
             **kwargs,
         )
````

src/transformers/models/aria/modular_aria.py

Lines changed: 1 addition & 25 deletions. This mirrors modeling_aria.py: in the modular transformers setup, modeling_aria.py is generated from modular_aria.py, so the same edit appears in both files.
````diff
@@ -1414,18 +1414,9 @@ def forward(
         past_key_values: Optional[Cache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Union[tuple, AriaModelOutputWithPast]:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
 
         if inputs_embeds is None:
             inputs_embeds = self.get_input_embeddings()(input_ids)
@@ -1437,7 +1428,7 @@ def forward(
                 vision_feature_layer=self.config.vision_feature_layer,
             )
             image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
-            special_image_mask = self._get_image_mask(
+            special_image_mask = self.get_placeholder_mask(
                 input_ids, inputs_embeds=inputs_embeds, image_features=image_features
             )
             inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
@@ -1448,9 +1439,6 @@ def forward(
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=True,
             cache_position=cache_position,
             **kwargs,
         )
@@ -1498,9 +1486,6 @@ def forward(
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[TransformersKwargs],
@@ -1565,12 +1550,6 @@ def forward(
         >>> print(generated_texts[1])
         Assistant: The bridge is in San Francisco.
         ```"""
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         outputs = self.model(
             input_ids=input_ids,
             pixel_values=pixel_values,
@@ -1580,9 +1559,6 @@ def forward(
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
             cache_position=cache_position,
             **kwargs,
         )
````

tests/models/aria/test_modeling_aria.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -137,8 +137,8 @@ def __init__(
 
     def get_config(self):
         return AriaConfig(
-            text_config=self.text_config,
-            vision_config=self.vision_config,
+            text_config=self.text_config.to_dict(),
+            vision_config=self.vision_config.to_dict(),
             ignore_index=self.ignore_index,
             image_token_index=self.image_token_index,
             projector_hidden_act=self.projector_hidden_act,
```
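The test fix passes the sub-configs as plain dicts rather than config objects. The diff doesn't spell out the failure behind the "awful bug" note, but composite configs in transformers generally normalize dict inputs into the appropriate sub-config classes in their `__init__`; the toy below illustrates that dict-based pattern (classes invented for the example, not the real `AriaConfig`).

```python
# Toy composite config illustrating the dict-based pattern; not the real classes.
class ToyTextConfig:
    def __init__(self, hidden_size: int = 32, **kwargs):
        self.hidden_size = hidden_size

    def to_dict(self) -> dict:
        return {"hidden_size": self.hidden_size}


class ToyCompositeConfig:
    def __init__(self, text_config=None, **kwargs):
        # Composite configs typically rebuild sub-configs from dicts, so the
        # parent stays in control of which class gets constructed.
        if isinstance(text_config, dict):
            text_config = ToyTextConfig(**text_config)
        self.text_config = text_config if text_config is not None else ToyTextConfig()


config = ToyCompositeConfig(text_config=ToyTextConfig(hidden_size=16).to_dict())
assert config.text_config.hidden_size == 16
```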
