Skip to content

Commit 2b47be7

Browse files
authored
[AINode] Fix bug of sundial and forecast udf (apache#16768)
1 parent 2c381fe commit 2b47be7

File tree

3 files changed

+14
-6
lines changed

3 files changed

+14
-6
lines changed

iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -616,7 +616,11 @@ def prepare_inputs_for_generation(
616616
if attention_mask is not None and attention_mask.shape[1] > (
617617
input_ids.shape[1] // self.config.input_token_len
618618
):
619-
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
619+
input_ids = input_ids[
620+
:,
621+
-(attention_mask.shape[1] - past_length)
622+
* self.config.input_token_len :,
623+
]
620624
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
621625
# input_ids based on the past_length.
622626
elif past_length < (input_ids.shape[1] // self.config.input_token_len):
@@ -629,9 +633,10 @@ def prepare_inputs_for_generation(
629633
position_ids = attention_mask.long().cumsum(-1) - 1
630634
position_ids.masked_fill_(attention_mask == 0, 1)
631635
if past_key_values:
632-
position_ids = position_ids[
633-
:, -(input_ids.shape[1] // self.config.input_token_len) :
634-
]
636+
token_num = (
637+
input_ids.shape[1] + self.config.input_token_len - 1
638+
) // self.config.input_token_len
639+
position_ids = position_ids[:, -token_num:]
635640

636641
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
637642
if inputs_embeds is not None and past_key_values is None:

iotdb-core/ainode/iotdb/ainode/core/model/timerxl/modeling_timer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -606,7 +606,11 @@ def prepare_inputs_for_generation(
606606
if attention_mask is not None and attention_mask.shape[1] > (
607607
input_ids.shape[1] // self.config.input_token_len
608608
):
609-
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
609+
input_ids = input_ids[
610+
:,
611+
-(attention_mask.shape[1] - past_length)
612+
* self.config.input_token_len :,
613+
]
610614
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
611615
# input_ids based on the past_length.
612616
elif past_length < (input_ids.shape[1] // self.config.input_token_len):

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,6 @@ public void beforeStart(UDFParameters parameters, UDTFConfigurations configurati
114114
}
115115
ModelInferenceDescriptor descriptor = modelFetcher.fetchModel(this.model_id);
116116
this.targetAINode = descriptor.getTargetAINode();
117-
this.maxInputLength = descriptor.getModelInformation().getInputShape()[0];
118117

119118
this.outputInterval = parameters.getLongOrDefault(OUTPUT_INTERVAL, DEFAULT_OUTPUT_INTERVAL);
120119
this.outputLength =

0 commit comments

Comments (0)