Skip to content

Commit b18da36

Browse files
author
Liu Zhengyun
committed
update and recover some code
1 parent dd2afc6 commit b18da36

File tree

2 files changed

+16
-14
lines changed

iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -603,38 +603,36 @@ def prepare_inputs_for_generation(
603603
**kwargs,
604604
):
605605
# Omit tokens covered by past_key_values
606-
past_length = 0
607-
token_num = (
608-
input_ids.shape[1] + self.config.input_token_len - 1
609-
) // self.config.input_token_len
610-
611606
if past_key_values is not None:
612607
if isinstance(past_key_values, Cache):
613608
past_length = past_key_values.get_seq_length()
614609
else:
615610
past_length = past_key_values[0][0].shape[2]
616611

617-
if past_key_values is not None and past_length > 0:
618612
# Keep only the unprocessed tokens:
619613
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
620614
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
621615
# input)
622-
if attention_mask is not None and attention_mask.shape[1] > token_num:
623-
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
616+
if attention_mask is not None and attention_mask.shape[1] > (
617+
input_ids.shape[1] // self.config.input_token_len
618+
):
619+
input_ids = input_ids[
620+
:,
621+
-(attention_mask.shape[1] - past_length)
622+
* self.config.input_token_len :,
623+
]
624624
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
625625
# input_ids based on the past_length.
626-
elif past_length < token_num:
627-
# TODO: Actually, we need to know the output_token_lens used in the last generation step.
628-
# Sundial will pad the input when it is non-divisible, so we cannot use past_length to slice input_ids
629-
input_ids = input_ids[:, -self.config.output_token_lens[0] :]
626+
elif past_length < (input_ids.shape[1] // self.config.input_token_len):
627+
input_ids = input_ids[:, past_length * self.config.input_token_len :]
630628
# 3 - Otherwise (past_length >= (input_ids.shape[1] // self.config.input_token_len)), let's assume input_ids only has unprocessed tokens.
631629

632630
position_ids = kwargs.get("position_ids", None)
633631
if attention_mask is not None and position_ids is None:
634632
# create position_ids on the fly for batch generation
635633
position_ids = attention_mask.long().cumsum(-1) - 1
636634
position_ids.masked_fill_(attention_mask == 0, 1)
637-
if past_key_values is not None and past_length > 0:
635+
if past_key_values is not None:
638636
token_num = (
639637
input_ids.shape[1] + self.config.input_token_len - 1
640638
) // self.config.input_token_len

iotdb-core/ainode/iotdb/ainode/core/model/timerxl/modeling_timer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -606,7 +606,11 @@ def prepare_inputs_for_generation(
606606
if attention_mask is not None and attention_mask.shape[1] > (
607607
input_ids.shape[1] // self.config.input_token_len
608608
):
609-
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
609+
input_ids = input_ids[
610+
:,
611+
-(attention_mask.shape[1] - past_length)
612+
* self.config.input_token_len :,
613+
]
610614
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
611615
# input_ids based on the past_length.
612616
elif past_length < (input_ids.shape[1] // self.config.input_token_len):

0 commit comments

Comments (0)