
Commit 7a1aeec

Fixes in check_model_inputs, GPTBigCodeModel and ImageGPTModel (#40811)
* misc fixes
* fix
* Update src/transformers/models/imagegpt/modeling_imagegpt.py
* Apply suggestion from @IlyasMoutawwakil
* pickup use_cache from args input as well
* fix
1 parent 297a41a commit 7a1aeec

File tree

3 files changed: +34 additions, -27 deletions

src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py

Lines changed: 1 addition & 8 deletions
@@ -472,14 +472,7 @@ def forward(
             raise ValueError("batch_size has to be defined and > 0")
 
         if use_cache and past_key_values is None:
-            past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
-        if use_cache and isinstance(past_key_values, tuple):
-            logger.warning_once(
-                "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
-                "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
-                "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
-            )
-            past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
+            past_key_values = DynamicCache(config=self.config)
 
         if inputs_embeds is None:
             inputs_embeds = self.wte(input_ids)
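Since GPTBigCodeModel is decoder-only, the default cache it builds no longer needs the encoder-decoder wrapper. Below is a minimal sketch of the effect, assuming a transformers version that includes this commit; the tiny random config is only there so the snippet runs without downloading a checkpoint.

import torch
from transformers import GPTBigCodeConfig, GPTBigCodeForCausalLM
from transformers.cache_utils import DynamicCache

# Tiny randomly initialized model so the sketch runs offline.
config = GPTBigCodeConfig(vocab_size=64, n_positions=32, n_embd=32, n_layer=2, n_head=2)
model = GPTBigCodeForCausalLM(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 8))
with torch.no_grad():
    out = model(input_ids, use_cache=True)

# With this change, a plain DynamicCache is created when no cache is passed in.
print(type(out.past_key_values))  # expected: DynamicCache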

src/transformers/models/imagegpt/modeling_imagegpt.py

Lines changed: 10 additions & 14 deletions
@@ -517,24 +517,20 @@ def forward(
             )
             use_cache = False
 
-        if use_cache and past_key_values is None:
-            past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
-        if use_cache and isinstance(past_key_values, tuple):
-            logger.warning_once(
-                "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
-                "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
-                "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
-            )
-            past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
-
-        past_length = past_key_values.get_seq_length() if past_key_values is not None else past_key_values
-
         if token_type_ids is not None:
             token_type_ids = token_type_ids.view(-1, input_shape[-1])
 
+        if use_cache and past_key_values is None:
+            past_key_values = DynamicCache(config=self.config)
+
+        if cache_position is None:
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position: torch.Tensor = torch.arange(
+                past_seen_tokens, past_seen_tokens + input_shape[-1], device=device
+            )
+
         if position_ids is None:
-            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
-            position_ids = position_ids.unsqueeze(0)
+            position_ids = cache_position.unsqueeze(0)
 
         # ImageGPTAttention mask.
         if attention_mask is not None:
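The replacement derives position_ids from cache_position instead of from past_length. A toy sketch of the index arithmetic follows; the numbers are made up for illustration and are not taken from the model.

import torch

past_seen_tokens = 6  # e.g. past_key_values.get_seq_length()
new_tokens = 4        # input_shape[-1] for the current forward pass

# Positions of the new tokens, continuing after the cached ones.
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + new_tokens)
position_ids = cache_position.unsqueeze(0)

print(cache_position)  # tensor([6, 7, 8, 9])
print(position_ids)    # tensor([[6, 7, 8, 9]])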

src/transformers/utils/generic.py

Lines changed: 23 additions & 5 deletions
@@ -797,17 +797,34 @@ def check_model_inputs(tie_last_hidden_states=True):
     def wrapped_fn(func):
         @wraps(func)
         def wrapper(self, *args, **kwargs):
-            use_cache = (
-                kwargs["use_cache"] if kwargs.get("use_cache") is not None else getattr(self.config, "use_cache", None)
-            )
+            use_cache_arg_index = None
+            if "use_cache" in func.__code__.co_varnames:
+                use_cache_arg_index = func.__code__.co_varnames.index("use_cache") - 1  # -1 for self
+
+            if (
+                use_cache_arg_index is not None
+                and len(args) > use_cache_arg_index
+                and args[use_cache_arg_index] is not None
+            ):
+                use_cache = args[use_cache_arg_index]
+            elif kwargs.get("use_cache") is not None:
+                use_cache = kwargs["use_cache"]
+            else:
+                use_cache = getattr(self.config, "use_cache", None)
+
             if use_cache is not None:
                 if getattr(self, "gradient_checkpointing", False) and self.training and use_cache:
                     logger.warning_once(
                         "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
                     )
                     use_cache = False
 
-                kwargs["use_cache"] = use_cache
+                if use_cache_arg_index is not None and len(args) > use_cache_arg_index:
+                    args = list(args)
+                    args[use_cache_arg_index] = use_cache
+                    args = tuple(args)
+                else:
+                    kwargs["use_cache"] = use_cache
 
             return_dict = kwargs.pop("return_dict", None)
             if return_dict is None:

@@ -818,7 +835,8 @@ def wrapper(self, *args, **kwargs):
                 for k, v in all_args["kwargs"].items():
                     all_args[k] = v
 
-            capture_flags = _CAN_RECORD_REGISTRY.get(str(self.__class__), {})  # there is a weak ref for executorch
+            # _can_record_outputs is None by default
+            capture_flags = _CAN_RECORD_REGISTRY.get(str(self.__class__)) or {}  # there is a weak ref for executorch
             recordable_keys = {
                 f"output_{k}": all_args.get(
                     f"output_{k}",
