Skip to content

Commit 14b4382

Browse files
authored
[AINode]Fix the parameter "predict_length" (#15900)
1 parent 1ad56c7 commit 14b4382

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

iotdb-core/ainode/ainode/core/manager/inference_manager.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,15 +84,15 @@ def infer(self, full_data, predict_length=96, **_):
8484

8585

8686
class BuiltInStrategy(InferenceStrategy):
87-
def infer(self, full_data):
87+
def infer(self, full_data, **_):
8888
data = pd.DataFrame(full_data[1]).T
8989
output = self.model.inference(data)
9090
df = pd.DataFrame(output)
9191
return convert_to_binary(df)
9292

9393

9494
class RegisteredStrategy(InferenceStrategy):
95-
def infer(self, full_data, window_interval=None, window_step=None, **kwargs):
95+
def infer(self, full_data, window_interval=None, window_step=None, **_):
9696
_, dataset, _, length = full_data
9797
if window_interval is None or window_step is None:
9898
window_interval = length
@@ -159,7 +159,7 @@ def _run(
159159

160160
# inference by strategy
161161
strategy = self._get_strategy(model_id, model)
162-
outputs = strategy.infer(full_data)
162+
outputs = strategy.infer(full_data, **inference_attrs)
163163

164164
# construct response
165165
status = get_status(TSStatusCode.SUCCESS_STATUS)

iotdb-core/ainode/ainode/core/model/sundial/ts_generation_mixin.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,11 @@ def generate(
5454
) -> Union[GenerateOutput, torch.LongTensor]:
5555
if len(inputs.shape) != 2:
5656
raise ValueError("Input shape must be: [batch_size, seq_len]")
57+
batch_size, cur_len = inputs.shape
58+
if cur_len < self.config.input_token_len:
59+
raise ValueError(
60+
f"Input length must be at least {self.config.input_token_len}"
61+
)
5762
if revin:
5863
means = inputs.mean(dim=-1, keepdim=True)
5964
stdev = inputs.std(dim=-1, keepdim=True, unbiased=False) + 1e-5

0 commit comments

Comments
 (0)