diff --git a/iotdb-core/ainode/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
index 9eda1c22651b..a8109e278db0 100644
--- a/iotdb-core/ainode/ainode/core/manager/inference_manager.py
+++ b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
@@ -84,7 +84,7 @@ def infer(self, full_data, predict_length=96, **_):
 
 
 class BuiltInStrategy(InferenceStrategy):
-    def infer(self, full_data):
+    def infer(self, full_data, **_):
         data = pd.DataFrame(full_data[1]).T
         output = self.model.inference(data)
         df = pd.DataFrame(output)
@@ -92,7 +92,7 @@ def infer(self, full_data):
 
 
 class RegisteredStrategy(InferenceStrategy):
-    def infer(self, full_data, window_interval=None, window_step=None, **kwargs):
+    def infer(self, full_data, window_interval=None, window_step=None, **_):
         _, dataset, _, length = full_data
         if window_interval is None or window_step is None:
             window_interval = length
@@ -159,7 +159,7 @@ def _run(
 
         # inference by strategy
         strategy = self._get_strategy(model_id, model)
-        outputs = strategy.infer(full_data)
+        outputs = strategy.infer(full_data, **inference_attrs)
 
         # construct response
         status = get_status(TSStatusCode.SUCCESS_STATUS)
diff --git a/iotdb-core/ainode/ainode/core/model/sundial/ts_generation_mixin.py b/iotdb-core/ainode/ainode/core/model/sundial/ts_generation_mixin.py
index d894d3d5ed3d..045711616607 100644
--- a/iotdb-core/ainode/ainode/core/model/sundial/ts_generation_mixin.py
+++ b/iotdb-core/ainode/ainode/core/model/sundial/ts_generation_mixin.py
@@ -54,6 +54,11 @@ def generate(
     ) -> Union[GenerateOutput, torch.LongTensor]:
         if len(inputs.shape) != 2:
             raise ValueError("Input shape must be: [batch_size, seq_len]")
+        batch_size, cur_len = inputs.shape
+        if cur_len < self.config.input_token_len:
+            raise ValueError(
+                f"Input length must be at least {self.config.input_token_len}"
+            )
         if revin:
             means = inputs.mean(dim=-1, keepdim=True)
             stdev = inputs.std(dim=-1, keepdim=True, unbiased=False) + 1e-5
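
Note on the inference_manager.py changes: every strategy's infer now accepts and ignores unknown keyword arguments (**_), so _run can forward the caller-supplied inference attributes unconditionally and each strategy keeps only what it understands. The following is a minimal standalone sketch of that dispatch pattern, not the IoTDB code; the class names BuiltInLikeStrategy / WindowedLikeStrategy, the run() helper, and the attrs dict are illustrative assumptions.

# Sketch only: strategies swallow extra kwargs so the caller can always
# forward user-supplied attributes without knowing which strategy uses them.
from abc import ABC, abstractmethod


class InferenceStrategy(ABC):
    @abstractmethod
    def infer(self, full_data, **kwargs):
        ...


class BuiltInLikeStrategy(InferenceStrategy):
    # Mirrors BuiltInStrategy.infer(self, full_data, **_): ignores extra attrs.
    def infer(self, full_data, **_):
        return [x * 2 for x in full_data]


class WindowedLikeStrategy(InferenceStrategy):
    # Mirrors RegisteredStrategy: consumes the attrs it knows, ignores the rest.
    def infer(self, full_data, window_interval=None, window_step=None, **_):
        window_interval = window_interval or len(full_data)
        window_step = window_step or len(full_data)
        return full_data[:window_interval]


def run(strategy: InferenceStrategy, full_data, inference_attrs: dict):
    # Mirrors _run(): forward everything; each strategy picks what it needs.
    return strategy.infer(full_data, **inference_attrs)


if __name__ == "__main__":
    attrs = {"window_interval": 2, "window_step": 1, "predict_length": 96}
    print(run(BuiltInLikeStrategy(), [1, 2, 3], attrs))   # [2, 4, 6]
    print(run(WindowedLikeStrategy(), [1, 2, 3], attrs))  # [1, 2]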