Skip to content

Commit dd2afc6

Browse files
author
Liu Zhengyun
committed
fix sundial and forecast
1 parent cd443ba commit dd2afc6

File tree

2 files changed

+16
-10
lines changed

2 files changed

+16
-10
lines changed

iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -603,35 +603,42 @@ def prepare_inputs_for_generation(
603603
**kwargs,
604604
):
605605
# Omit tokens covered by past_key_values
606+
past_length = 0
607+
token_num = (
608+
input_ids.shape[1] + self.config.input_token_len - 1
609+
) // self.config.input_token_len
610+
606611
if past_key_values is not None:
607612
if isinstance(past_key_values, Cache):
608613
past_length = past_key_values.get_seq_length()
609614
else:
610615
past_length = past_key_values[0][0].shape[2]
611616

617+
if past_key_values is not None and past_length > 0:
612618
# Keep only the unprocessed tokens:
613619
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
614620
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
615621
# input)
616-
if attention_mask is not None and attention_mask.shape[1] > (
617-
input_ids.shape[1] // self.config.input_token_len
618-
):
622+
if attention_mask is not None and attention_mask.shape[1] > token_num:
619623
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
620624
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
621625
# input_ids based on the past_length.
622-
elif past_length < (input_ids.shape[1] // self.config.input_token_len):
623-
input_ids = input_ids[:, past_length * self.config.input_token_len :]
626+
elif past_length < token_num:
627+
# TODO: Actually, we need to know the output_token_lens used in the last generation step.
628+
# Sundial will pad the input when it is non-divisible, so we cannot use past_length to slice input_ids
629+
input_ids = input_ids[:, -self.config.output_token_lens[0] :]
624630
# 3 - Otherwise (past_length >= (input_ids.shape[1] // self.config.input_token_len)), let's assume input_ids only has unprocessed tokens.
625631

626632
position_ids = kwargs.get("position_ids", None)
627633
if attention_mask is not None and position_ids is None:
628634
# create position_ids on the fly for batch generation
629635
position_ids = attention_mask.long().cumsum(-1) - 1
630636
position_ids.masked_fill_(attention_mask == 0, 1)
631-
if past_key_values:
632-
position_ids = position_ids[
633-
:, -(input_ids.shape[1] // self.config.input_token_len) :
634-
]
637+
if past_key_values is not None and past_length > 0:
638+
token_num = (
639+
input_ids.shape[1] + self.config.input_token_len - 1
640+
) // self.config.input_token_len
641+
position_ids = position_ids[:, -token_num:]
635642

636643
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
637644
if inputs_embeds is not None and past_key_values is None:

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,6 @@ public void beforeStart(UDFParameters parameters, UDTFConfigurations configurati
114114
}
115115
ModelInferenceDescriptor descriptor = modelFetcher.fetchModel(this.model_id);
116116
this.targetAINode = descriptor.getTargetAINode();
117-
this.maxInputLength = descriptor.getModelInformation().getInputShape()[0];
118117

119118
this.outputInterval = parameters.getLongOrDefault(OUTPUT_INTERVAL, DEFAULT_OUTPUT_INTERVAL);
120119
this.outputLength =

0 commit comments

Comments (0)