
Commit c100996

Expose the processing interface of the pipeline
1 parent 4440993 · commit c100996

File tree: 6 files changed, +38 −32 lines
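
In short: the private _preprocess/_postprocess hooks become public preprocess/postprocess methods, and the callers (the inference request pool and the inference manager) now invoke them explicitly around forecast/classify/chat instead of each pipeline calling them internally. A minimal sketch of the resulting call pattern; `pipeline` is illustrative and would come from load_pipeline(model_info, device="cpu") in the real code:

    inputs = pipeline.preprocess(inputs)                   # validate / transform
    output = pipeline.forecast(inputs, predict_length=96)  # raw model output
    output = pipeline.postprocess(output)                  # e.g. numpy -> torch, sample mean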

iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py

Lines changed: 4 additions & 0 deletions

@@ -123,6 +123,7 @@ def _step(self):
         batch_inputs = self._batcher.batch_request(requests).to(
             "cpu"
         )  # The input data should first load to CPU in current version
+        batch_inputs = self._inference_pipeline.preprocess(batch_inputs)
         if isinstance(self._inference_pipeline, ForecastPipeline):
             batch_output = self._inference_pipeline.forecast(
                 batch_inputs,
@@ -140,7 +141,10 @@ def _step(self):
                 # more infer kwargs can be added here
             )
         else:
+            batch_output = None
             self._logger.error("[Inference] Unsupported pipeline type.")
+        batch_output = self._inference_pipeline.postprocess(batch_output)
+
         offset = 0
         for request in requests:
             request.output_tensor = request.output_tensor.to(self.device)
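
Two details worth noting in this hunk: postprocess is now called unconditionally after the type dispatch, so the unsupported-type branch must initialize batch_output = None to avoid a NameError; and preprocess runs on the CPU-resident batch before dispatch. A runnable sketch of the new control flow, with a hypothetical StubForecastPipeline standing in for the real pipelines:

    import torch

    class StubForecastPipeline:  # hypothetical stand-in for ForecastPipeline
        def preprocess(self, x):
            assert x.dim() == 2, "expects [batch_size, seq_len]"
            return x.float()
        def forecast(self, x, predict_length=96):
            return x[:, -1:].repeat(1, predict_length)  # naive last-value forecast
        def postprocess(self, y):
            return y

    pipeline = StubForecastPipeline()
    batch_inputs = pipeline.preprocess(torch.ones(2, 8))
    if isinstance(pipeline, StubForecastPipeline):
        batch_output = pipeline.forecast(batch_inputs, predict_length=4)
    else:
        batch_output = None  # keeps the unconditional postprocess call safe
    batch_output = pipeline.postprocess(batch_output)
    print(batch_output.shape)  # torch.Size([2, 4])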

iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py

Lines changed: 12 additions & 10 deletions

@@ -29,59 +29,61 @@ def __init__(self, model_info, **model_kwargs):
         self.device = model_kwargs.get("device", "cpu")
         self.model = load_model(model_info, device_map=self.device, **model_kwargs)

-    def _preprocess(self, inputs):
+    @abstractmethod
+    def preprocess(self, inputs):
         """
         Preprocess the input before inference, including shape validation and value transformation.
         """
-        return inputs
+        raise NotImplementedError("preprocess not implemented")

-    def _postprocess(self, output: torch.Tensor):
+    @abstractmethod
+    def postprocess(self, output: torch.Tensor):
         """
         Post-process the outputs after the entire inference task.
         """
-        return output
+        raise NotImplementedError("postprocess not implemented")


 class ForecastPipeline(BasicPipeline):
     def __init__(self, model_info, **model_kwargs):
         super().__init__(model_info, model_kwargs=model_kwargs)

-    def _preprocess(self, inputs):
+    def preprocess(self, inputs):
         return inputs

     @abstractmethod
     def forecast(self, inputs, **infer_kwargs):
         pass

-    def _postprocess(self, output: torch.Tensor):
+    def postprocess(self, output: torch.Tensor):
         return output


 class ClassificationPipeline(BasicPipeline):
     def __init__(self, model_info, **model_kwargs):
         super().__init__(model_info, model_kwargs=model_kwargs)

-    def _preprocess(self, inputs):
+    def preprocess(self, inputs):
         return inputs

     @abstractmethod
     def classify(self, inputs, **kwargs):
         pass

-    def _postprocess(self, output: torch.Tensor):
+    def postprocess(self, output: torch.Tensor):
         return output


 class ChatPipeline(BasicPipeline):
     def __init__(self, model_info, **model_kwargs):
         super().__init__(model_info, model_kwargs=model_kwargs)

-    def _preprocess(self, inputs):
+    def preprocess(self, inputs):
         return inputs

     @abstractmethod
     def chat(self, inputs, **kwargs):
         pass

-    def _postprocess(self, output: torch.Tensor):
+    def postprocess(self, output: torch.Tensor):
         return output
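
With @abstractmethod on both hooks, instantiating BasicPipeline directly becomes a TypeError (assuming the class uses ABCMeta, which the existing @abstractmethod on forecast/classify/chat implies); the NotImplementedError bodies only fire if a subclass calls super(). A small sketch of the contract a new subclass now has to satisfy (Base and MyPipeline are hypothetical stand-ins):

    from abc import ABC, abstractmethod
    import torch

    class Base(ABC):  # stand-in for BasicPipeline
        @abstractmethod
        def preprocess(self, inputs):
            raise NotImplementedError("preprocess not implemented")
        @abstractmethod
        def postprocess(self, output: torch.Tensor):
            raise NotImplementedError("postprocess not implemented")

    class MyPipeline(Base):  # hypothetical subclass
        def preprocess(self, inputs):
            return inputs.float()
        def postprocess(self, output: torch.Tensor):
            return output.cpu()

    MyPipeline()   # ok: both hooks overridden
    # Base()       # TypeError: can't instantiate abstract class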

iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py

Lines changed: 3 additions & 0 deletions

@@ -215,6 +215,7 @@ def _run(
         else:
             model_info = self._model_manager.get_model_info(model_id)
             inference_pipeline = load_pipeline(model_info, device="cpu")
+            inputs = inference_pipeline.preprocess(inputs)
             if isinstance(inference_pipeline, ForecastPipeline):
                 outputs = inference_pipeline.forecast(
                     inputs, predict_length=output_length, **inference_attrs
@@ -224,7 +225,9 @@ def _run(
             elif isinstance(inference_pipeline, ChatPipeline):
                 outputs = inference_pipeline.chat(inputs)
             else:
+                outputs = None
                 logger.error("[Inference] Unsupported pipeline type.")
+            outputs = inference_pipeline.postprocess(outputs)
             outputs = convert_to_binary(pd.DataFrame(outputs[0]))

         # construct response
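
The non-pooled path in _run mirrors the pool: preprocess before the dispatch, outputs = None in the fallback branch, postprocess before serialization. The last context line also implies postprocess must return something indexable whose first element pandas can frame; a quick shape check under that assumption (convert_to_binary is elided):

    import pandas as pd
    import torch

    outputs = torch.randn(2, 96)           # assumed [batch_size, predict_length]
    df = pd.DataFrame(outputs[0].numpy())  # first batch element: 96 rows, 1 column
    print(df.shape)                        # (96, 1)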

iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py

Lines changed: 11 additions & 12 deletions

@@ -28,41 +28,40 @@ def __init__(self, model_info, **model_kwargs):
         model_kwargs.pop("device", None)  # sktime models run on CPU
         super().__init__(model_info, model_kwargs=model_kwargs)

-    def _preprocess(self, inputs):
+    def preprocess(self, inputs):
         return inputs

     def forecast(self, inputs, **infer_kwargs):
         predict_length = infer_kwargs.get("predict_length", 96)
-        input_ids = self._preprocess(inputs)

         # Convert to pandas Series for sktime (sktime expects Series or DataFrame)
         # Handle batch dimension: if batch_size > 1, process each sample separately
-        if len(input_ids.shape) == 2 and input_ids.shape[0] > 1:
+        if len(inputs.shape) == 2 and inputs.shape[0] > 1:
             # Batch processing: convert each row to Series
             outputs = []
-            for i in range(input_ids.shape[0]):
+            for i in range(inputs.shape[0]):
                 series = pd.Series(
-                    input_ids[i].cpu().numpy()
-                    if isinstance(input_ids, torch.Tensor)
-                    else input_ids[i]
+                    inputs[i].cpu().numpy()
+                    if isinstance(inputs, torch.Tensor)
+                    else inputs[i]
                 )
                 output = self.model.generate(series, predict_length=predict_length)
                 outputs.append(output)
             output = np.array(outputs)
         else:
             # Single sample: convert to Series
-            if isinstance(input_ids, torch.Tensor):
-                series = pd.Series(input_ids.squeeze().cpu().numpy())
+            if isinstance(inputs, torch.Tensor):
+                series = pd.Series(inputs.squeeze().cpu().numpy())
             else:
-                series = pd.Series(input_ids.squeeze())
+                series = pd.Series(inputs.squeeze())
             output = self.model.generate(series, predict_length=predict_length)
             # Add batch dimension if needed
             if len(output.shape) == 1:
                 output = output[np.newaxis, :]

-        return self._postprocess(output)
+        return output

-    def _postprocess(self, output):
+    def postprocess(self, output):
         if isinstance(output, np.ndarray):
             return torch.from_numpy(output).float()
         return output
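
Since sktime forecasters return numpy arrays, this pipeline's postprocess is where results re-enter torch land; a quick check of that conversion path under the same assumption:

    import numpy as np
    import torch

    output = np.ones((2, 96))                 # float64, as sktime typically returns
    tensor = torch.from_numpy(output).float() # cast to float32 for downstream code
    print(tensor.dtype, tensor.shape)         # torch.float32 torch.Size([2, 96])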

iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py

Lines changed: 4 additions & 5 deletions

@@ -26,7 +26,7 @@ class SundialPipeline(ForecastPipeline):
     def __init__(self, model_info, **model_kwargs):
         super().__init__(model_info, model_kwargs=model_kwargs)

-    def _preprocess(self, inputs):
+    def preprocess(self, inputs):
         if len(inputs.shape) != 2:
             raise InferenceModelInternalException(
                 f"[Inference] Input shape must be: [batch_size, seq_len], but receives {inputs.shape}"
@@ -38,14 +38,13 @@ def forecast(self, inputs, **infer_kwargs):
         num_samples = infer_kwargs.get("num_samples", 10)
         revin = infer_kwargs.get("revin", True)

-        input_ids = self._preprocess(inputs)
         output = self.model.generate(
-            input_ids,
+            inputs,
             max_new_tokens=predict_length,
             num_samples=num_samples,
             revin=revin,
         )
-        return self._postprocess(output)
+        return output

-    def _postprocess(self, output: torch.Tensor):
+    def postprocess(self, output: torch.Tensor):
         return output.mean(dim=1)
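
Sundial generates num_samples trajectories per input series, so its postprocess collapses them into a point forecast by averaging; assuming the generate output is shaped [batch_size, num_samples, predict_length]:

    import torch

    output = torch.randn(2, 10, 96)  # [batch, num_samples, predict_length] (assumed)
    point_forecast = output.mean(dim=1)
    print(point_forecast.shape)      # torch.Size([2, 96])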

iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py

Lines changed: 4 additions & 5 deletions

@@ -26,7 +26,7 @@ class TimerPipeline(ForecastPipeline):
     def __init__(self, model_info, **model_kwargs):
         super().__init__(model_info, model_kwargs=model_kwargs)

-    def _preprocess(self, inputs):
+    def preprocess(self, inputs):
         if len(inputs.shape) != 2:
             raise InferenceModelInternalException(
                 f"[Inference] Input shape must be: [batch_size, seq_len], but receives {inputs.shape}"
@@ -37,11 +37,10 @@ def forecast(self, inputs, **infer_kwargs):
         predict_length = infer_kwargs.get("predict_length", 96)
         revin = infer_kwargs.get("revin", True)

-        input_ids = self._preprocess(inputs)
         output = self.model.generate(
-            input_ids, max_new_tokens=predict_length, revin=revin
+            inputs, max_new_tokens=predict_length, revin=revin
         )
-        return self._postprocess(output)
+        return output

-    def _postprocess(self, output: torch.Tensor):
+    def postprocess(self, output: torch.Tensor):
         return output
