
Commit c5c89f8
Standardize the dimensions of the input and output of forecast task to 3D
1 parent c4a9866
File tree: 7 files changed, +125 / -67 lines


iotdb-core/ainode/iotdb/ainode/core/inference/inference_request.py
10 additions, 9 deletions

@@ -42,7 +42,7 @@ def __init__(
         output_length: int = 96,
         **infer_kwargs,
     ):
-        if inputs.ndim == 1:
+        while inputs.ndim < 3:
             inputs = inputs.unsqueeze(0)
 
         self.req_id = req_id
@@ -54,15 +54,16 @@ def __init__(
         )
 
         self.batch_size = inputs.size(0)
+        self.variable_size = inputs.size(1)
         self.state = InferenceRequestState.WAITING
         self.cur_step_idx = 0  # Current write position in the output step index
         self.assigned_pool_id = -1  # The pool handling this request
         self.assigned_device_id = -1  # The device handling this request
 
         # Preallocate output buffer [batch_size, target_count, output_length]
         self.output_tensor = torch.zeros(
-            self.batch_size, output_length, device="cpu"
-        )  # shape: [self.batch_size, max_new_steps]
+            self.batch_size, self.variable_size, output_length, device="cpu"
+        )  # shape: [batch_size, target_count, predict_length]
 
     def mark_running(self):
         self.state = InferenceRequestState.RUNNING
@@ -77,26 +78,26 @@ def is_finished(self) -> bool:
         )
 
     def write_step_output(self, step_output: torch.Tensor):
-        if step_output.ndim == 1:
+        while step_output.ndim < 3:
             step_output = step_output.unsqueeze(0)
 
-        batch_size, step_size = step_output.shape
+        batch_size, variable_size, step_size = step_output.shape
         end_idx = self.cur_step_idx + step_size
 
         if end_idx > self.output_length:
-            self.output_tensor[:, self.cur_step_idx :] = step_output[
-                :, : self.output_length - self.cur_step_idx
+            self.output_tensor[:, :, self.cur_step_idx :] = step_output[
+                :, :, : self.output_length - self.cur_step_idx
             ]
             self.cur_step_idx = self.output_length
         else:
-            self.output_tensor[:, self.cur_step_idx : end_idx] = step_output
+            self.output_tensor[:, :, self.cur_step_idx : end_idx] = step_output
             self.cur_step_idx = end_idx
 
         if self.is_finished():
             self.mark_finished()
 
     def get_final_output(self) -> torch.Tensor:
-        return self.output_tensor[:, : self.cur_step_idx]
+        return self.output_tensor[:, :, : self.cur_step_idx]
 
 
 class InferenceRequestProxy:
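For context, a minimal sketch (not part of the commit; tensor sizes are illustrative) of the shape handling this file now enforces: any 1D or 2D input is promoted to the standard 3D layout, and step outputs are written into the preallocated 3D buffer along the last dimension.

import torch

# Promote a raw 1D series to the standardized 3D layout.
x = torch.randn(96)           # [sequence_length]
while x.ndim < 3:
    x = x.unsqueeze(0)        # -> [1, 96] -> [1, 1, 96]
assert x.shape == (1, 1, 96)  # [batch_size, target_count, sequence_length]

# Step outputs land in a preallocated buffer, indexed on the last dim.
output_tensor = torch.zeros(1, 1, 96)    # [batch_size, target_count, predict_length]
step_output = torch.randn(1, 1, 24)      # one generation step of 24 points
output_tensor[:, :, 0:24] = step_output  # cur_step_idx advances 0 -> 24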

iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py
22 additions, 7 deletions

@@ -20,6 +20,7 @@
 
 import torch
 
+from iotdb.ainode.core.exception import InferenceModelInternalException
 from iotdb.ainode.core.model.model_loader import load_model
 
 
@@ -37,7 +38,7 @@ def preprocess(self, inputs):
         raise NotImplementedError("preprocess not implemented")
 
     @abstractmethod
-    def postprocess(self, output: torch.Tensor):
+    def postprocess(self, outputs: torch.Tensor):
         """
         Post-process the outputs after the entire inference task.
         """
@@ -49,14 +50,28 @@ def __init__(self, model_info, **model_kwargs):
         super().__init__(model_info, model_kwargs=model_kwargs)
 
     def preprocess(self, inputs):
+        """
+        The inputs should be a 3D tensor: [batch_size, target_count, sequence_length].
+        """
+        if len(inputs.shape) != 3:
+            raise InferenceModelInternalException(
+                f"[Inference] Input must be: [batch_size, target_count, sequence_length], but receives {inputs.shape}"
+            )
         return inputs
 
     @abstractmethod
     def forecast(self, inputs, **infer_kwargs):
         pass
 
-    def postprocess(self, output: torch.Tensor):
-        return output
+    def postprocess(self, outputs: torch.Tensor):
+        """
+        The outputs should be a 3D tensor: [batch_size, target_count, predict_length].
+        """
+        if len(outputs.shape) != 3:
+            raise InferenceModelInternalException(
+                f"[Inference] Output must be: [batch_size, target_count, predict_length], but receives {outputs.shape}"
+            )
+        return outputs
 
 
 class ClassificationPipeline(BasicPipeline):
@@ -70,8 +85,8 @@ def preprocess(self, inputs):
     def classify(self, inputs, **kwargs):
         pass
 
-    def postprocess(self, output: torch.Tensor):
-        return output
+    def postprocess(self, outputs: torch.Tensor):
+        return outputs
 
 
 class ChatPipeline(BasicPipeline):
@@ -85,5 +100,5 @@ def preprocess(self, inputs):
     def chat(self, inputs, **kwargs):
         pass
 
-    def postprocess(self, output: torch.Tensor):
-        return output
+    def postprocess(self, outputs: torch.Tensor):
+        return outputs
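A hypothetical standalone check mirroring what ForecastPipeline.preprocess and postprocess now enforce at both ends of the pipeline; the helper name _require_3d and the ValueError are mine, not the repo's (the real code raises InferenceModelInternalException).

import torch

def _require_3d(t: torch.Tensor, layout: str) -> torch.Tensor:
    # Mirrors the new pipeline contract: reject anything that is not
    # [batch_size, target_count, length].
    if t.ndim != 3:
        raise ValueError(f"[Inference] Expected {layout}, but received {tuple(t.shape)}")
    return t

inputs = _require_3d(torch.randn(4, 2, 336), "[batch_size, target_count, sequence_length]")
outputs = _require_3d(torch.randn(4, 2, 96), "[batch_size, target_count, predict_length]")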

iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
18 additions, 25 deletions

@@ -46,7 +46,10 @@
 from iotdb.ainode.core.manager.model_manager import ModelManager
 from iotdb.ainode.core.rpc.status import get_status
 from iotdb.ainode.core.util.gpu_mapping import get_available_devices
-from iotdb.ainode.core.util.serde import convert_to_binary
+from iotdb.ainode.core.util.serde import (
+    convert_tensor_to_tsblock,
+    convert_tsblock_to_tensor,
+)
 from iotdb.thrift.ainode.ttypes import (
     TForecastReq,
     TForecastResp,
@@ -58,7 +61,6 @@
     TUnloadModelReq,
 )
 from iotdb.thrift.common.ttypes import TSStatus
-from iotdb.tsfile.utils.tsblock_serde import deserialize
 
 logger = Logger()
 
@@ -170,23 +172,14 @@ def _run(
         self,
         req,
         data_getter,
-        deserializer,
         extract_attrs,
         resp_cls,
-        single_output: bool,
+        single_batch: bool,
     ):
         model_id = req.modelId
         try:
             raw = data_getter(req)
-            # full data deserialized from iotdb is composed of [timestampList, valueList, None, length], we only get valueList currently.
-            full_data = deserializer(raw)
-            # TODO: TSBlock -> Tensor codes should be unified
-            data = full_data[1][0]  # get valueList in ndarray
-            if data.dtype.byteorder not in ("=", "|"):
-                np_data = data.byteswap()
-                data = np_data.view(np_data.dtype.newbyteorder())
-            # the inputs should be on CPU before passing to the inference request
-            inputs = torch.tensor(data).unsqueeze(0).float().to("cpu")
+            inputs = convert_tsblock_to_tensor(raw)
 
             inference_attrs = extract_attrs(req)
             output_length = int(inference_attrs.pop("output_length", 96))
@@ -211,7 +204,6 @@ def _run(
                     output_length=output_length,
                 )
                 outputs = self._process_request(infer_req)
-                outputs = convert_to_binary(pd.DataFrame(outputs[0]))
             else:
                 model_info = self._model_manager.get_model_info(model_id)
                 inference_pipeline = load_pipeline(model_info, device="cpu")
@@ -228,45 +220,46 @@ def _run(
                     outputs = None
                     logger.error("[Inference] Unsupported pipeline type.")
                 outputs = inference_pipeline.postprocess(outputs)
-                outputs = convert_to_binary(pd.DataFrame(outputs[0]))
 
-            # construct response
-            status = get_status(TSStatusCode.SUCCESS_STATUS)
+            # Convert each batch's output tensor into a TsBlock
+            output_list = []
+            for batch_idx in range(outputs.size(0)):
+                output = convert_tensor_to_tsblock(outputs[batch_idx])
+                output_list.append(output)
 
-            if isinstance(outputs, list):
-                return resp_cls(status, outputs[0] if single_output else outputs)
-            return resp_cls(status, outputs if single_output else [outputs])
+            return resp_cls(
+                get_status(TSStatusCode.SUCCESS_STATUS),
+                output_list[0] if single_batch else output_list,
+            )
 
         except Exception as e:
             logger.error(e)
             status = get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e))
-            empty = b"" if single_output else []
+            empty = b"" if single_batch else []
            return resp_cls(status, empty)
 
     def forecast(self, req: TForecastReq):
         return self._run(
             req,
             data_getter=lambda r: r.inputData,
-            deserializer=deserialize,
             extract_attrs=lambda r: {
                 "output_length": r.outputLength,
                 **(r.options or {}),
             },
             resp_cls=TForecastResp,
-            single_output=True,
+            single_batch=True,
         )
 
     def inference(self, req: TInferenceReq):
         return self._run(
             req,
             data_getter=lambda r: r.dataset,
-            deserializer=deserialize,
             extract_attrs=lambda r: {
                 "output_length": int(r.inferenceAttributes.pop("outputLength", 96)),
                 **(r.inferenceAttributes or {}),
             },
             resp_cls=TInferenceResp,
-            single_output=False,
+            single_batch=False,
         )
 
     def stop(self):
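A sketch of the new response assembly: postprocessed 3D outputs are sliced per batch item and each slice is serialized to one TsBlock. The serializer below is a stand-in of my own; the real convert_tensor_to_tsblock lives in iotdb.ainode.core.util.serde and encodes IoTDB's TsBlock format.

import torch

def to_tsblock_stub(t: torch.Tensor) -> bytes:
    # Stand-in for convert_tensor_to_tsblock; only the shapes matter here.
    return t.numpy().tobytes()

outputs = torch.randn(3, 1, 96)  # [batch_size, target_count, predict_length]
output_list = [to_tsblock_stub(outputs[i]) for i in range(outputs.size(0))]

single_batch = True              # forecast() returns one payload; inference() a list
payload = output_list[0] if single_batch else output_list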

iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py
17 additions, 9 deletions

@@ -20,6 +20,7 @@
 import pandas as pd
 import torch
 
+from iotdb.ainode.core.exception import InferenceModelInternalException
 from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline
 
 
@@ -29,6 +30,12 @@ def __init__(self, model_info, **model_kwargs):
         super().__init__(model_info, model_kwargs=model_kwargs)
 
     def preprocess(self, inputs):
+        inputs = super().preprocess(inputs)
+        if inputs.shape[1] != 1:
+            raise InferenceModelInternalException(
+                f"[Inference] Sktime model only supports single target, but receives {inputs.shape[1]} series."
+            )
+        inputs = inputs.squeeze(1)
         return inputs
 
     def forecast(self, inputs, **infer_kwargs):
@@ -47,21 +54,22 @@ def forecast(self, inputs, **infer_kwargs):
                 )
                 output = self.model.generate(series, predict_length=predict_length)
                 outputs.append(output)
-            output = np.array(outputs)
+            outputs = np.array(outputs)
         else:
             # Single sample: convert to Series
             if isinstance(inputs, torch.Tensor):
                 series = pd.Series(inputs.squeeze().cpu().numpy())
             else:
                 series = pd.Series(inputs.squeeze())
-            output = self.model.generate(series, predict_length=predict_length)
+            outputs = self.model.generate(series, predict_length=predict_length)
             # Add batch dimension if needed
-            if len(output.shape) == 1:
-                output = output[np.newaxis, :]
+            if len(outputs.shape) == 1:
+                outputs = outputs[np.newaxis, :]
 
-        return output
+        return outputs
 
-    def postprocess(self, output):
-        if isinstance(output, np.ndarray):
-            return torch.from_numpy(output).float()
-        return output
+    def postprocess(self, outputs):
+        if isinstance(outputs, np.ndarray):
+            outputs = torch.from_numpy(outputs).float()
+        outputs = super().postprocess(outputs.unsqueeze(1))
+        return outputs
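The squeeze/unsqueeze round trip that all three single-target pipelines now share, sketched with dummy tensors standing in for the model:

import numpy as np
import torch

inputs = torch.randn(2, 1, 336)      # [batch_size, target_count=1, sequence_length]
assert inputs.shape[1] == 1          # multi-target inputs raise in preprocess
model_in = inputs.squeeze(1)         # [batch_size, sequence_length] for the model

model_out = np.random.randn(2, 96)   # pretend the sktime model returned a 2D ndarray
restored = torch.from_numpy(model_out).float().unsqueeze(1)
assert restored.shape == (2, 1, 96)  # back to [batch_size, target_count, predict_length]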

iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py
18 additions, 7 deletions

@@ -27,24 +27,35 @@ def __init__(self, model_info, **model_kwargs):
         super().__init__(model_info, model_kwargs=model_kwargs)
 
     def preprocess(self, inputs):
-        if len(inputs.shape) != 2:
+        """
+        The inputs should be 3D, but Sundial only supports a 2D tensor: [batch_size, sequence_length],
+        so we squeeze out the target_count dimension.
+        """
+        inputs = super().preprocess(inputs)
+        if inputs.shape[1] != 1:
             raise InferenceModelInternalException(
-                f"[Inference] Input shape must be: [batch_size, seq_len], but receives {inputs.shape}"
+                f"[Inference] Sundial model only supports single target, but receives {inputs.shape[1]} series."
             )
+        inputs = inputs.squeeze(1)
         return inputs
 
     def forecast(self, inputs, **infer_kwargs):
         predict_length = infer_kwargs.get("predict_length", 96)
         num_samples = infer_kwargs.get("num_samples", 10)
         revin = infer_kwargs.get("revin", True)
 
-        output = self.model.generate(
+        outputs = self.model.generate(
             inputs,
             max_new_tokens=predict_length,
             num_samples=num_samples,
             revin=revin,
         )
-        return output
-
-    def postprocess(self, output: torch.Tensor):
-        return output.mean(dim=1)
+        return outputs
+
+    def postprocess(self, outputs: torch.Tensor):
+        """
+        The outputs should be 3D, so we take the mean across the num_samples dimension and expand dims.
+        """
+        outputs = outputs.mean(dim=1).unsqueeze(1)
+        outputs = super().postprocess(outputs)
+        return outputs
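Sundial's generate is assumed here, based on the mean(dim=1) in postprocess, to emit one trajectory per sample, i.e. [batch_size, num_samples, predict_length]; the new postprocess collapses the sample dimension and reinstates target_count:

import torch

raw = torch.randn(2, 10, 96)                 # [batch_size, num_samples, predict_length]
standardized = raw.mean(dim=1).unsqueeze(1)  # average samples, restore target_count
assert standardized.shape == (2, 1, 96)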

iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py
19 additions, 7 deletions

@@ -27,18 +27,30 @@ def __init__(self, model_info, **model_kwargs):
         super().__init__(model_info, model_kwargs=model_kwargs)
 
     def preprocess(self, inputs):
-        if len(inputs.shape) != 2:
+        """
+        The inputs should be 3D, but Timer-XL only supports a 2D tensor: [batch_size, sequence_length],
+        so we squeeze out the target_count dimension.
+        """
+        inputs = super().preprocess(inputs)
+        if inputs.shape[1] != 1:
             raise InferenceModelInternalException(
-                f"[Inference] Input shape must be: [batch_size, seq_len], but receives {inputs.shape}"
+                f"[Inference] Timer-XL model only supports single target, but receives {inputs.shape[1]} series."
             )
+        inputs = inputs.squeeze(1)
         return inputs
 
     def forecast(self, inputs, **infer_kwargs):
         predict_length = infer_kwargs.get("predict_length", 96)
         revin = infer_kwargs.get("revin", True)
 
-        output = self.model.generate(inputs, max_new_tokens=predict_length, revin=revin)
-        return output
-
-    def postprocess(self, output: torch.Tensor):
-        return output
+        outputs = self.model.generate(
+            inputs, max_new_tokens=predict_length, revin=revin
+        )
+        return outputs
+
+    def postprocess(self, outputs: torch.Tensor):
+        """
+        The outputs should be 3D, so we expand dims before the base-class check.
+        """
+        outputs = super().postprocess(outputs.unsqueeze(1))
+        return outputs
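Putting it together, the per-stage shape contract after this commit, sketched with dummy tensors standing in for the Timer-XL model call:

import torch

batch = torch.randn(4, 1, 672)    # pipeline input: [batch_size, target_count, seq_len]
model_in = batch.squeeze(1)       # preprocess: what the model's generate() consumes
model_out = torch.randn(4, 96)    # stand-in for generate(max_new_tokens=96)
final = model_out.unsqueeze(1)    # postprocess: restore target_count
assert final.shape == (4, 1, 96)  # [batch_size, target_count, predict_length]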
