
Commit 41fd7b8

[AINode] Modify timer_xl and sundial for transformersv4.56.2 (apache#16568)
1 parent c983221 commit 41fd7b8

6 files changed: +1009 -1022 lines changed


iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py

Lines changed: 30 additions & 50 deletions
@@ -156,7 +156,7 @@ def forward(
         past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         **kwargs,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
         bsz, q_len, _ = hidden_states.size()

         query_states = self.q_proj(hidden_states)
@@ -175,7 +175,7 @@ def forward(

         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+            kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(
             query_states, key_states, cos, sin, position_ids
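
The hunk above swaps past_key_value.get_usable_length(...) for get_seq_length(...), matching the Cache API exposed by the newer transformers releases this commit targets. A minimal sketch of the new lookup, assuming transformers >= 4.56 and torch are installed; the tensors are placeholders, not the model's real states:

import torch
from transformers import DynamicCache

# Placeholder key/value states shaped (batch, num_heads, seq_len, head_dim).
key_states = torch.zeros(1, 2, 5, 8)
value_states = torch.zeros(1, 2, 5, 8)

cache = DynamicCache()
cache.update(key_states, value_states, layer_idx=0)

# get_seq_length(layer_idx) is the supported length query on Cache objects.
print(cache.get_seq_length(0))  # 5
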
@@ -238,15 +238,13 @@ def forward(
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        past_key_value: Optional[Cache] = None,
         output_attentions: Optional[bool] = False,
-        use_cache: Optional[bool] = False,
         **kwargs,
     ) -> Tuple[
         torch.FloatTensor,
-        torch.FloatTensor,
-        Optional[torch.FloatTensor],
-        Optional[torch.FloatTensor],
+        Optional[torch.Tensor],
+        Optional[Cache],
     ]:
         residual = hidden_states

@@ -259,7 +257,6 @@ def forward(
             position_ids=position_ids,
             past_key_value=past_key_value,
             output_attentions=output_attentions,
-            use_cache=use_cache,
         )
         hidden_states = residual + hidden_states

@@ -272,8 +269,6 @@ def forward(
         if not output_attentions:
             self_attn_weights = None

-        if not use_cache:
-            present_key_value = None
         return hidden_states, self_attn_weights, present_key_value


@@ -317,9 +312,10 @@ def forward(
         input_ids: torch.FloatTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        past_key_values: Optional[
+            Union[Cache, tuple[tuple[torch.Tensor, torch.Tensor]]]
+        ] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
@@ -335,7 +331,6 @@ def forward(
             if output_hidden_states is not None
             else self.config.output_hidden_states
         )
-        use_cache = use_cache if use_cache is not None else self.config.use_cache

         return_dict = (
             return_dict if return_dict is not None else self.config.use_return_dict
@@ -359,17 +354,25 @@ def forward(
             inputs_embeds = self.embed_layer(input_ids)
         seq_length = inputs_embeds.shape[1]

-        if self.gradient_checkpointing and self.training:
-            if use_cache:
-                use_cache = False
-
         past_key_values_length = 0
+        use_legacy_cache = False

-        if use_cache:
+        if past_key_values is not None:
             use_legacy_cache = not isinstance(past_key_values, Cache)
+            # Converts the legacy cache which is tuple into an equivalent Cache. Used for backward compatibility.
             if use_legacy_cache:
                 past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            past_key_values_length = past_key_values.get_usable_length(seq_length)
+            # Suppose the sequence length of each layer is the same
+            past_key_values_length = past_key_values.get_seq_length()
+
+        # When training + checkpoints, caching is usually disabled (just do not transfer)
+        if (
+            self.gradient_checkpointing
+            and self.training
+            and isinstance(past_key_values, Cache)
+        ):
+            past_key_values = None
+            past_key_values_length = 0

         if position_ids is None:
             device = input_ids.device if input_ids is not None else inputs_embeds.device
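
The reworked block above drops the use_cache gate: whenever past_key_values arrives in the legacy tuple-of-tuples form it is promoted to a DynamicCache before any length bookkeeping, and caching under gradient checkpointing is disabled by clearing the cache rather than flipping a flag. A minimal sketch of the promotion step, assuming transformers >= 4.56; the tensor shapes are illustrative only:

import torch
from transformers import Cache, DynamicCache

# One layer of legacy cache: a (key, value) pair shaped (batch, heads, seq_len, head_dim).
legacy_cache = ((torch.zeros(1, 2, 3, 8), torch.zeros(1, 2, 3, 8)),)

past_key_values = DynamicCache.from_legacy_cache(legacy_cache)
assert isinstance(past_key_values, Cache)

# The model assumes every layer shares one length, so no per-layer index is needed.
past_key_values_length = past_key_values.get_seq_length()  # 3
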
@@ -412,7 +415,6 @@ def forward(
                     position_ids,
                     past_key_values,
                     output_attentions,
-                    use_cache,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -421,15 +423,14 @@ def forward(
                     position_ids=position_ids,
                     past_key_value=past_key_values,
                     output_attentions=output_attentions,
-                    use_cache=use_cache,
                 )

             hidden_states = layer_outputs[0]

             if output_attentions:
                 all_self_attns += (layer_outputs[1],)

-            if use_cache:
+            if isinstance(past_key_values, Cache):
                 next_decoder_cache = layer_outputs[2]

         hidden_states = self.norm(hidden_states)
@@ -438,7 +439,7 @@ def forward(
             all_hidden_states += (hidden_states,)

         next_cache = None
-        if use_cache:
+        if isinstance(past_key_values, Cache):
             next_cache = (
                 next_decoder_cache.to_legacy_cache()
                 if use_legacy_cache
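
As the hunk above shows, the returned cache is only materialised when a real Cache object was threaded through the layers, and to_legacy_cache() converts it back for callers that passed the tuple format. A small round-trip sketch, assuming transformers >= 4.56:

import torch
from transformers import DynamicCache

cache = DynamicCache()
cache.update(torch.zeros(1, 2, 4, 8), torch.zeros(1, 2, 4, 8), layer_idx=0)

# Convert back to the tuple-of-(key, value) layout expected by legacy callers.
legacy = cache.to_legacy_cache()
print(legacy[0][0].shape)  # torch.Size([1, 2, 4, 8])
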
@@ -484,12 +485,13 @@ def forward(
         input_ids: torch.FloatTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        past_key_values: Optional[
+            Union[Cache, tuple[tuple[torch.Tensor, torch.Tensor]]]
+        ] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.FloatTensor] = None,
         loss_masks: Optional[torch.FloatTensor] = None,
         mask_y: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
@@ -525,7 +527,6 @@ def forward(
             position_ids=position_ids,
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
@@ -604,16 +605,9 @@ def prepare_inputs_for_generation(
         # Omit tokens covered by past_key_values
         if past_key_values is not None:
             if isinstance(past_key_values, Cache):
-                cache_length = past_key_values.get_seq_length()
-                if isinstance(past_key_values, DynamicCache):
-                    past_length = past_key_values.seen_tokens
-                else:
-                    past_length = cache_length
-
-                max_cache_length = past_key_values.get_max_length()
+                past_length = past_key_values.get_seq_length()
             else:
-                cache_length = past_length = past_key_values[0][0].shape[2]
-                max_cache_length = None
+                past_length = past_key_values[0][0].shape[2]

             # Keep only the unprocessed tokens:
             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
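
The simplification above keeps only get_seq_length() plus the tuple fallback, dropping the seen_tokens and get_max_length() paths that newer transformers releases deprecate. A hypothetical helper (the name resolve_past_length is illustrative, not part of the commit) sketching the same two-branch lookup:

from typing import Optional, Tuple, Union

import torch
from transformers import Cache


def resolve_past_length(
    past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]],
) -> int:
    # Number of tokens already covered by the cache, for either representation.
    if past_key_values is None:
        return 0
    if isinstance(past_key_values, Cache):
        return past_key_values.get_seq_length()
    # Legacy tuple cache: entry [layer][0] is the key tensor (batch, heads, seq_len, head_dim).
    return past_key_values[0][0].shape[2]
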
@@ -622,26 +616,13 @@ def prepare_inputs_for_generation(
             if attention_mask is not None and attention_mask.shape[1] > (
                 input_ids.shape[1] // self.config.input_token_len
             ):
-                input_ids = input_ids[
-                    :,
-                    -(attention_mask.shape[1] - past_length)
-                    * self.config.input_token_len :,
-                ]
+                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
             # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
             # input_ids based on the past_length.
             elif past_length < (input_ids.shape[1] // self.config.input_token_len):
                 input_ids = input_ids[:, past_length * self.config.input_token_len :]
             # 3 - Otherwise (past_length >= (input_ids.shape[1] // self.config.input_token_len)), let's assume input_ids only has unprocessed tokens.

-            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
-            if (
-                max_cache_length is not None
-                and attention_mask is not None
-                and cache_length + (input_ids.shape[1] // self.config.input_token_len)
-                > max_cache_length
-            ):
-                attention_mask = attention_mask[:, -max_cache_length:]
-
         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None:
             # create position_ids on the fly for batch generation
@@ -662,7 +643,6 @@ def prepare_inputs_for_generation(
             {
                 "position_ids": position_ids,
                 "past_key_values": past_key_values,
-                "use_cache": kwargs.get("use_cache"),
                 "attention_mask": attention_mask,
                 "revin": revin,
                 "num_samples": num_samples,

iotdb-core/ainode/iotdb/ainode/core/model/sundial/ts_generation_mixin.py

Lines changed: 9 additions & 5 deletions
@@ -83,7 +83,7 @@ def generate(
         outputs = (outputs * stdev) + means
         return outputs

-    def _greedy_search(
+    def _sample(
         self,
         input_ids: torch.Tensor,
         logits_processor: Optional[LogitsProcessorList] = None,
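
Renaming _greedy_search to _sample follows the upstream GenerationMixin, where greedy decoding is now routed through _sample rather than a separate private method, so the override must carry the new name to keep being picked up. A quick sanity check, assuming transformers >= 4.56 is importable (the expected values are an assumption about that release, not verified here):

from transformers import GenerationMixin

print(hasattr(GenerationMixin, "_sample"))         # True
print(hasattr(GenerationMixin, "_greedy_search"))  # expected False on 4.56.x
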
@@ -269,7 +269,7 @@ def _greedy_search(
             horizon_length = next_tokens.shape[-1] // self.config.input_token_len

             past_key_values = model_kwargs.get("past_key_values")
-            if past_key_values is None:
+            if past_key_values is None or generate_results is None:
                 generate_results = next_tokens
             else:
                 generate_results = torch.cat([generate_results, next_tokens], dim=-1)
@@ -328,9 +328,13 @@ def _update_model_kwargs_for_generation(
         standardize_cache_format: bool = False,
     ) -> Dict[str, Any]:
         # update past_key_values
-        model_kwargs["past_key_values"] = self._extract_past_from_model_output(
-            outputs, standardize_cache_format=standardize_cache_format
-        )
+        if "past_key_values" in outputs:
+            model_kwargs["past_key_values"] = outputs.past_key_values
+        elif "mems" in outputs:
+            model_kwargs["past_key_values"] = outputs.mems
+        elif "past_buckets_states" in outputs:
+            model_kwargs["past_key_values"] = outputs.past_buckets_states
+
         if getattr(outputs, "state", None) is not None:
             model_kwargs["state"] = outputs.state

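The removed call relied on GenerationMixin._extract_past_from_model_output, which newer transformers releases no longer provide; the replacement probes the model output directly for whichever cache field is present (past_key_values, mems, or past_buckets_states). A minimal sketch of that probing, assuming transformers >= 4.56; the output instance below is a stand-in, not the model's real output:

import torch
from transformers import DynamicCache
from transformers.modeling_outputs import CausalLMOutputWithPast

outputs = CausalLMOutputWithPast(
    logits=torch.zeros(1, 1, 16),
    past_key_values=DynamicCache(),
)

model_kwargs = {}
# ModelOutput behaves like a mapping over its non-None fields, so "in" checks work.
if "past_key_values" in outputs:
    model_kwargs["past_key_values"] = outputs.past_key_values
elif "mems" in outputs:
    model_kwargs["past_key_values"] = outputs.mems

print("past_key_values" in model_kwargs)  # True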