Skip to content

Commit 1366410

Browse files
committed
Update pipeline_chronos2.py
1 parent 795bbbb commit 1366410

File tree

1 file changed

+3
-7
lines changed

1 file changed

+3
-7
lines changed

iotdb-core/ainode/iotdb/ainode/core/model/chronos2/pipeline_chronos2.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,8 @@ class Chronos2Pipeline(ForecastPipeline):
3737
def __init__(self, model_info, **model_kwargs):
3838
super().__init__(model_info, model_kwargs=model_kwargs)
3939

40-
def _preprocess(self, inputs):
41-
if len(inputs.shape) != 2:
42-
raise InferenceModelInternalException(
43-
f"[Inference] Input shape must be: [batch_size, seq_len], but receives {inputs.shape}"
44-
)
45-
inputs = inputs.unsqueeze(0)
40+
def preprocess(self, inputs):
41+
inputs = super().preprocess(inputs)
4642
return inputs
4743

4844
@property
@@ -391,5 +387,5 @@ def _predict_step(
391387

392388
return prediction
393389

394-
def _postprocess(self, output: torch.Tensor):
390+
def postprocess(self, output: torch.Tensor):
395391
return output[0].mean(dim=1)

0 commit comments

Comments
 (0)