
Commit fe16a22

Author: Liu Zhengyun (committed)
redefine parameter input to targets and fix some bugs

1 parent 48bd352

File tree

5 files changed: +180, -207 lines

iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py

Lines changed: 9 additions & 4 deletions
@@ -610,7 +610,11 @@ def prepare_inputs_for_generation(
             if attention_mask is not None and attention_mask.shape[1] > (
                 input_ids.shape[1] // self.config.input_token_len
             ):
-                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+                input_ids = input_ids[
+                    :,
+                    -(attention_mask.shape[1] - past_length)
+                    * self.config.input_token_len :,
+                ]
             # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
             # input_ids based on the past_length.
             elif past_length < (input_ids.shape[1] // self.config.input_token_len):
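The fix multiplies the slice bound by self.config.input_token_len. The surrounding code suggests that attention_mask and past_length count patch tokens while input_ids indexes raw timesteps (each token covering input_token_len timesteps), so the old slice kept only (attention_mask.shape[1] - past_length) timesteps instead of that many whole tokens. A minimal sketch of the mismatch; the names mirror the diff, but all shapes and values here are made-up illustrations, not taken from the repository:

import torch

input_token_len = 16   # timesteps covered by one patch token (assumed value)
past_length = 4        # tokens already cached in past_key_values
attention_mask = torch.ones(1, 7)                 # one entry per token
input_ids = torch.randn(1, 7 * input_token_len)   # one entry per timestep

new_tokens = attention_mask.shape[1] - past_length   # 3 unprocessed tokens

old = input_ids[:, -new_tokens:]                     # old slice: only 3 timesteps survive
new = input_ids[:, -new_tokens * input_token_len:]   # fixed slice: 3 whole tokens survive

assert old.shape[1] == 3
assert new.shape[1] == 3 * input_token_len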
@@ -623,9 +627,10 @@ def prepare_inputs_for_generation(
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
-                position_ids = position_ids[
-                    :, -(input_ids.shape[1] // self.config.input_token_len) :
-                ]
+                token_num = (
+                    input_ids.shape[1] + self.config.input_token_len - 1
+                ) // self.config.input_token_len
+                position_ids = position_ids[:, -token_num:]
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
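The second hunk swaps a floor division for a ceiling division when deciding how many position ids to keep. Presumably this preserves a position id for a trailing, partially filled token whenever input_ids.shape[1] is not an exact multiple of input_token_len; under the old floor division that last partial token would lose its position. A standalone comparison with assumed lengths:

input_token_len = 16   # assumed value, as above

for series_len in (32, 33, 48):
    floor_tokens = series_len // input_token_len                       # old behaviour
    token_num = (series_len + input_token_len - 1) // input_token_len  # new, ceiling
    print(series_len, floor_tokens, token_num)

# 32 -> 2 2 and 48 -> 3 3 (exact multiples are unchanged);
# 33 -> 2 vs 3: the one leftover timestep now counts as a token of its own.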

iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/modeling_timer.py

Lines changed: 5 additions & 1 deletion
@@ -603,7 +603,11 @@ def prepare_inputs_for_generation(
             if attention_mask is not None and attention_mask.shape[1] > (
                 input_ids.shape[1] // self.config.input_token_len
             ):
-                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+                input_ids = input_ids[
+                    :,
+                    -(attention_mask.shape[1] - past_length)
+                    * self.config.input_token_len :,
+                ]
             # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
             # input_ids based on the past_length.
             elif past_length < (input_ids.shape[1] // self.config.input_token_len):
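The modeling_timer.py hunk is the same input_ids rescaling applied to Timer-XL's prepare_inputs_for_generation. Where both fixes are present (as in modeling_sundial.py above), the rescaled input_ids slice and the ceiling-based position_ids slice agree on the number of tokens fed forward; a rough composed check, again with assumed values rather than anything from the commit:

import torch

input_token_len = 16
past_length = 4
attention_mask = torch.ones(1, 7, dtype=torch.long)   # 7 tokens seen so far
input_ids = torch.randn(1, 7 * input_token_len)

# fixed input_ids slice: keep the 3 unprocessed tokens' worth of timesteps
input_ids = input_ids[
    :, -(attention_mask.shape[1] - past_length) * input_token_len:
]

# fixed position_ids slice: ceiling division over the remaining timesteps
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
token_num = (input_ids.shape[1] + input_token_len - 1) // input_token_len
position_ids = position_ids[:, -token_num:]

assert position_ids.shape[1] == input_ids.shape[1] // input_token_len == 3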

0 commit comments
