
Commit 92894c6 ("runnable version")
1 parent 580ee77

3 files changed: 6 additions, 3 deletions

iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
Lines changed: 2 additions & 0 deletions

@@ -209,9 +209,11 @@ def _run(
     inference_pipeline = load_pipeline(model_info, device="cpu")
     inputs = inference_pipeline.preprocess(inputs)
     if isinstance(inference_pipeline, ForecastPipeline):
+        inputs = inference_pipeline._preprocess(inputs)
         outputs = inference_pipeline.forecast(
             inputs, predict_length=output_length, **inference_attrs
         )
+        outputs = inference_pipeline._postprocess(outputs)
     elif isinstance(inference_pipeline, ClassificationPipeline):
         outputs = inference_pipeline.classify(inputs)
     elif isinstance(inference_pipeline, ChatPipeline):
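
For reference, a minimal self-contained sketch of the shape flow the new forecast branch relies on: _preprocess adds a leading dimension, forecast emits per-sample trajectories, and _postprocess averages them back down to [batch_size, predict_length]. This is not the actual ainode code; ToyForecastPipeline, the sample count, and all tensor shapes are illustrative assumptions.

import torch

class ToyForecastPipeline:
    def _preprocess(self, inputs: torch.Tensor) -> torch.Tensor:
        # mirror the Chronos-2 check: expect [batch_size, seq_len]
        if inputs.dim() != 2:
            raise ValueError(f"expected [batch_size, seq_len], got {tuple(inputs.shape)}")
        return inputs.unsqueeze(0)          # -> [1, batch_size, seq_len]

    def forecast(self, inputs: torch.Tensor, predict_length: int) -> torch.Tensor:
        groups, batch, _ = inputs.shape
        num_samples = 4                     # hypothetical sample count
        return torch.randn(groups, batch, num_samples, predict_length)

    def _postprocess(self, output: torch.Tensor) -> torch.Tensor:
        # drop the leading group dim, average over samples -> [batch_size, predict_length]
        return output[0].mean(dim=1)

pipe = ToyForecastPipeline()
x = torch.randn(2, 96)                      # [batch_size, seq_len]
y = pipe._postprocess(pipe.forecast(pipe._preprocess(x), predict_length=24))
print(y.shape)                              # torch.Size([2, 24])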

iotdb-core/ainode/iotdb/ainode/core/model/chronos2/pipeline_chronos2.py
Lines changed: 3 additions & 2 deletions

@@ -42,6 +42,7 @@ def _preprocess(self, inputs):
             raise InferenceModelInternalException(
                 f"[Inference] Input shape must be: [batch_size, seq_len], but receives {inputs.shape}"
             )
+        inputs = inputs.unsqueeze(0)
         return inputs

     @property
@@ -239,7 +240,7 @@ def forecast(self, inputs, **infer_kwargs):
             )
             logger.warning(msg)

-        context_length = len(inputs)
+        context_length = inputs.shape[-1]
         if context_length > self.model_context_length:
             logger.warning(
                 f"The specified context_length {context_length} is greater than the model's default context length {self.model_context_length}. "
@@ -391,4 +392,4 @@ def _predict_step(
         return prediction

     def _postprocess(self, output: torch.Tensor):
-        return output
+        return output[0].mean(dim=1)
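
A quick illustration of why context_length is now taken from inputs.shape[-1] rather than len(inputs): once _preprocess has unsqueezed the input, len() returns the size of the leading dimension, not the sequence length. The tensor shape below is an assumed example, not taken from the real pipeline.

import torch

batched = torch.randn(1, 2, 512)   # assumed shape: [groups, batch_size, seq_len]
print(len(batched))                # 1   -- size of the first dimension, understates the context
print(batched.shape[-1])           # 512 -- the sequence length the model actually sees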

iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py
Lines changed: 1 addition & 1 deletion

@@ -43,7 +43,7 @@


 def load_model(model_info: ModelInfo, **model_kwargs) -> Any:
-    if model_info.auto_map is not None:
+    if model_info.category == ModelCategory.BUILTIN or model_info.auto_map is not None:
         model = load_model_from_transformers(model_info, **model_kwargs)
     else:
         if model_info.model_type == "sktime":
