diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py
index 544193e4d9c6..3ebf516f705e 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py
@@ -616,7 +616,11 @@ def prepare_inputs_for_generation(
         if attention_mask is not None and attention_mask.shape[1] > (
             input_ids.shape[1] // self.config.input_token_len
         ):
-            input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+            input_ids = input_ids[
+                :,
+                -(attention_mask.shape[1] - past_length)
+                * self.config.input_token_len :,
+            ]
         # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
         # input_ids based on the past_length.
         elif past_length < (input_ids.shape[1] // self.config.input_token_len):
@@ -629,9 +633,10 @@
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
-                position_ids = position_ids[
-                    :, -(input_ids.shape[1] // self.config.input_token_len) :
-                ]
+                token_num = (
+                    input_ids.shape[1] + self.config.input_token_len - 1
+                ) // self.config.input_token_len
+                position_ids = position_ids[:, -token_num:]
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/modeling_timer.py b/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/modeling_timer.py
index 37bf56dfc59a..0a33c682742a 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/modeling_timer.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/modeling_timer.py
@@ -606,7 +606,11 @@ def prepare_inputs_for_generation(
         if attention_mask is not None and attention_mask.shape[1] > (
             input_ids.shape[1] // self.config.input_token_len
         ):
-            input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+            input_ids = input_ids[
+                :,
+                -(attention_mask.shape[1] - past_length)
+                * self.config.input_token_len :,
+            ]
         # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
         # input_ids based on the past_length.
         elif past_length < (input_ids.shape[1] // self.config.input_token_len):
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java
index e77e0641ae97..260410954d4c 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java
@@ -114,7 +114,6 @@ public void beforeStart(UDFParameters parameters, UDTFConfigurations configurati
     }
     ModelInferenceDescriptor descriptor = modelFetcher.fetchModel(this.model_id);
     this.targetAINode = descriptor.getTargetAINode();
-    this.maxInputLength = descriptor.getModelInformation().getInputShape()[0];
     this.outputInterval = parameters.getLongOrDefault(OUTPUT_INTERVAL, DEFAULT_OUTPUT_INTERVAL);
     this.outputLength =