
Commit 381aea8

Fix call inference cannot specify outputLength
1 parent 71daf3b commit 381aea8

File tree: 9 files changed, +44 −42 lines changed

integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java

Lines changed: 13 additions & 3 deletions
@@ -34,10 +34,12 @@
 
 import java.sql.Connection;
 import java.sql.ResultSet;
+import java.sql.ResultSetMetaData;
 import java.sql.SQLException;
 import java.sql.Statement;
 
 import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_MODEL_MAP;
+import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader;
 import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest;
 import static org.apache.iotdb.db.it.utils.TestUtils.prepareData;
 

@@ -55,7 +57,8 @@ public class AINodeCallInferenceIT {
   };
 
   private static final String CALL_INFERENCE_SQL_TEMPLATE =
-      "CALL INFERENCE(%s, \"select s%d from root.AI\")";
+      "CALL INFERENCE(%s, \"SELECT s%d FROM root.AI LIMIT %d\", generateTime=true, outputLength=%d)";
+  private static final int DEFAULT_OUTPUT_LENGTH = 48;
 
   @BeforeClass
   public static void setUp() throws Exception {

@@ -93,14 +96,21 @@ public void callInferenceTest(Statement statement, AINodeTestUtils.FakeModelInfo
     // Invoke call inference for specified models, there should exist result.
     for (int i = 0; i < 4; i++) {
       String callInferenceSQL =
-          String.format(CALL_INFERENCE_SQL_TEMPLATE, modelInfo.getModelId(), i);
+          String.format(
+              CALL_INFERENCE_SQL_TEMPLATE,
+              modelInfo.getModelId(),
+              i,
+              DEFAULT_OUTPUT_LENGTH,
+              DEFAULT_OUTPUT_LENGTH);
       try (ResultSet resultSet = statement.executeQuery(callInferenceSQL)) {
+        ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
+        checkHeader(resultSetMetaData, "Time,output");
         int count = 0;
         while (resultSet.next()) {
           count++;
         }
         // Ensure the call inference return results
-        Assert.assertTrue(count > 0);
+        Assert.assertEquals(DEFAULT_OUTPUT_LENGTH, count);
       }
     }
   }
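Note: with DEFAULT_OUTPUT_LENGTH = 48, the template renders to statements such as CALL INFERENCE(<modelId>, "SELECT s0 FROM root.AI LIMIT 48", generateTime=true, outputLength=48). The test can therefore assert an exact row count of 48 rather than merely a non-empty result, and checkHeader verifies the Time,output header (presumably the Time column contributed by generateTime=true).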

iotdb-core/ainode/iotdb/ainode/core/config.py

Lines changed: 7 additions & 9 deletions
@@ -32,7 +32,7 @@
     AINODE_CONF_POM_FILE_NAME,
     AINODE_INFERENCE_BATCH_INTERVAL_IN_MS,
     AINODE_INFERENCE_EXTRA_MEMORY_RATIO,
-    AINODE_INFERENCE_MAX_PREDICT_LENGTH,
+    AINODE_INFERENCE_MAX_OUTPUT_LENGTH,
     AINODE_INFERENCE_MEMORY_USAGE_RATIO,
     AINODE_INFERENCE_MODEL_MEM_USAGE_MAP,
     AINODE_LOG_DIR,

@@ -75,9 +75,7 @@ def __init__(self):
         self._ain_inference_batch_interval_in_ms: int = (
             AINODE_INFERENCE_BATCH_INTERVAL_IN_MS
         )
-        self._ain_inference_max_predict_length: int = (
-            AINODE_INFERENCE_MAX_PREDICT_LENGTH
-        )
+        self._ain_inference_max_output_length: int = AINODE_INFERENCE_MAX_OUTPUT_LENGTH
         self._ain_inference_model_mem_usage_map: dict[str, int] = (
             AINODE_INFERENCE_MODEL_MEM_USAGE_MAP
         )

@@ -160,13 +158,13 @@ def set_ain_inference_batch_interval_in_ms(
     ) -> None:
         self._ain_inference_batch_interval_in_ms = ain_inference_batch_interval_in_ms
 
-    def get_ain_inference_max_predict_length(self) -> int:
-        return self._ain_inference_max_predict_length
+    def get_ain_inference_max_output_length(self) -> int:
+        return self._ain_inference_max_output_length
 
-    def set_ain_inference_max_predict_length(
-        self, ain_inference_max_predict_length: int
+    def set_ain_inference_max_output_length(
+        self, ain_inference_max_output_length: int
     ) -> None:
-        self._ain_inference_max_predict_length = ain_inference_max_predict_length
+        self._ain_inference_max_output_length = ain_inference_max_output_length
 
     def get_ain_inference_model_mem_usage_map(self) -> dict[str, int]:
         return self._ain_inference_model_mem_usage_map

iotdb-core/ainode/iotdb/ainode/core/constant.py

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@
 
 # AINode inference configuration
 AINODE_INFERENCE_BATCH_INTERVAL_IN_MS = 15
-AINODE_INFERENCE_MAX_PREDICT_LENGTH = 2880
+AINODE_INFERENCE_MAX_OUTPUT_LENGTH = 2880
 
 # TODO: Should be optimized
 AINODE_INFERENCE_MODEL_MEM_USAGE_MAP = {

iotdb-core/ainode/iotdb/ainode/core/inference/inference_request.py

Lines changed: 8 additions & 8 deletions
@@ -39,7 +39,7 @@ def __init__(
         req_id: str,
         model_id: str,
         inputs: torch.Tensor,
-        max_new_tokens: int = 96,
+        output_length: int = 96,
         **infer_kwargs,
     ):
         if inputs.ndim == 1:

@@ -49,8 +49,8 @@ def __init__(
         self.model_id = model_id
         self.inputs = inputs
         self.infer_kwargs = infer_kwargs
-        self.max_new_tokens = (
-            max_new_tokens  # Number of time series data points to generate
+        self.output_length = (
+            output_length  # Number of time series data points to generate
         )
 
         self.batch_size = inputs.size(0)

@@ -61,7 +61,7 @@ def __init__(
 
         # Preallocate output buffer [batch_size, max_new_tokens]
         self.output_tensor = torch.zeros(
-            self.batch_size, max_new_tokens, device="cpu"
+            self.batch_size, output_length, device="cpu"
         )  # shape: [self.batch_size, max_new_steps]
 
     def mark_running(self):

@@ -73,7 +73,7 @@ def mark_finished(self):
     def is_finished(self) -> bool:
         return (
             self.state == InferenceRequestState.FINISHED
-            or self.cur_step_idx >= self.max_new_tokens
+            or self.cur_step_idx >= self.output_length
         )
 
     def write_step_output(self, step_output: torch.Tensor):

@@ -83,11 +83,11 @@ def write_step_output(self, step_output: torch.Tensor):
         batch_size, step_size = step_output.shape
         end_idx = self.cur_step_idx + step_size
 
-        if end_idx > self.max_new_tokens:
+        if end_idx > self.output_length:
             self.output_tensor[:, self.cur_step_idx :] = step_output[
-                :, : self.max_new_tokens - self.cur_step_idx
+                :, : self.output_length - self.cur_step_idx
             ]
-            self.cur_step_idx = self.max_new_tokens
+            self.cur_step_idx = self.output_length
         else:
             self.output_tensor[:, self.cur_step_idx : end_idx] = step_output
             self.cur_step_idx = end_idx
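
The renamed field drives both the preallocated buffer and the finish check above: output_length caps how many points a request accumulates. A minimal stand-alone sketch of the same truncation rule (the shapes and step size here are illustrative, not the module's defaults):

import torch

# A request with batch_size=1 and output_length=5, fed by a model that
# emits 3 points per step; the second step overflows and is truncated.
output_length = 5
output_tensor = torch.zeros(1, output_length)
cur_step_idx = 0

for step_output in (torch.ones(1, 3), torch.full((1, 3), 2.0)):
    end_idx = cur_step_idx + step_output.shape[1]
    if end_idx > output_length:
        # Keep only the remaining slots of the final step.
        output_tensor[:, cur_step_idx:] = step_output[:, : output_length - cur_step_idx]
        cur_step_idx = output_length
    else:
        output_tensor[:, cur_step_idx:end_idx] = step_output
        cur_step_idx = end_idx

print(output_tensor)                  # tensor([[1., 1., 1., 2., 2.]])
print(cur_step_idx >= output_length)  # True, i.e. the request is finished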

iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py

Lines changed: 2 additions & 2 deletions
@@ -115,7 +115,7 @@ def _step(self):
 
         grouped_requests = defaultdict(list)
         for req in all_requests:
-            key = (req.inputs.shape[1], req.max_new_tokens)
+            key = (req.inputs.shape[1], req.output_length)
             grouped_requests[key].append(req)
         grouped_requests = list(grouped_requests.values())
 

@@ -124,7 +124,7 @@ def _step(self):
             if isinstance(self._inference_pipeline, ForecastPipeline):
                 batch_output = self._inference_pipeline.forecast(
                     batch_inputs,
-                    predict_length=requests[0].max_new_tokens,
+                    predict_length=requests[0].output_length,
                     revin=True,
                 )
             elif isinstance(self._inference_pipeline, ClassificationPipeline):
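
Batching now keys on output_length as well as input width, so requests are only fused when both match. A small sketch of that grouping step, with a hypothetical Req dataclass standing in for InferenceRequest:

from collections import defaultdict
from dataclasses import dataclass

import torch

@dataclass
class Req:  # hypothetical stand-in for InferenceRequest
    inputs: torch.Tensor
    output_length: int

all_requests = [
    Req(torch.zeros(1, 96), 48),
    Req(torch.zeros(1, 96), 48),   # same width and length: batched with the first
    Req(torch.zeros(1, 96), 96),   # same input width, different output_length
    Req(torch.zeros(1, 192), 48),  # different input width
]

grouped_requests = defaultdict(list)
for req in all_requests:
    key = (req.inputs.shape[1], req.output_length)
    grouped_requests[key].append(req)

print([len(group) for group in grouped_requests.values()])  # [2, 1, 1]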

iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py

Lines changed: 9 additions & 10 deletions
@@ -189,26 +189,26 @@ def _run(
         inputs = torch.tensor(data).unsqueeze(0).float().to("cpu")
 
         inference_attrs = extract_attrs(req)
-        predict_length = int(inference_attrs.pop("predict_length", 96))
+        output_length = int(inference_attrs.pop("output_length", 96))
         if (
-            predict_length
-            > AINodeDescriptor().get_config().get_ain_inference_max_predict_length()
+            output_length
+            > AINodeDescriptor().get_config().get_ain_inference_max_output_length()
         ):
             raise NumericalRangeException(
                 "output_length",
                 1,
                 AINodeDescriptor()
                 .get_config()
-                .get_ain_inference_max_predict_length(),
-                predict_length,
+                .get_ain_inference_max_output_length(),
+                output_length,
             )
 
         if self._pool_controller.has_request_pools(model_id):
             infer_req = InferenceRequest(
                 req_id=generate_req_id(),
                 model_id=model_id,
                 inputs=inputs,
-                max_new_tokens=predict_length,
+                output_length=output_length,
             )
             outputs = self._process_request(infer_req)
             outputs = convert_to_binary(pd.DataFrame(outputs[0]))

@@ -217,7 +217,7 @@ def _run(
             inference_pipeline = load_pipeline(model_info, device="cpu")
             if isinstance(inference_pipeline, ForecastPipeline):
                 outputs = inference_pipeline.forecast(
-                    inputs, predict_length=predict_length, **inference_attrs
+                    inputs, predict_length=output_length, **inference_attrs
                 )
             elif isinstance(inference_pipeline, ClassificationPipeline):
                 outputs = inference_pipeline.classify(inputs)

@@ -246,7 +246,7 @@ def forecast(self, req: TForecastReq):
             data_getter=lambda r: r.inputData,
             deserializer=deserialize,
             extract_attrs=lambda r: {
-                "predict_length": r.outputLength,
+                "output_length": r.outputLength,
                 **(r.options or {}),
             },
             resp_cls=TForecastResp,

@@ -259,8 +259,7 @@ def inference(self, req: TInferenceReq):
             data_getter=lambda r: r.dataset,
             deserializer=deserialize,
             extract_attrs=lambda r: {
-                "window_interval": getattr(r.windowParams, "windowInterval", None),
-                "window_step": getattr(r.windowParams, "windowStep", None),
+                "output_length": int(r.inferenceAttributes.pop("outputLength", 96)),
                 **(r.inferenceAttributes or {}),
             },
             resp_cls=TInferenceResp,
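
The manager pops output_length from the extracted attributes and validates it against the configured ceiling before building the InferenceRequest. A simplified sketch of that flow, with the descriptor lookup and NumericalRangeException stubbed out:

# MAX_OUTPUT_LENGTH stands in for
# AINodeDescriptor().get_config().get_ain_inference_max_output_length(),
# and ValueError stands in for NumericalRangeException.
MAX_OUTPUT_LENGTH = 2880  # AINODE_INFERENCE_MAX_OUTPUT_LENGTH in constant.py

def resolve_output_length(inference_attrs: dict) -> int:
    # Default to 96 points when the caller does not specify outputLength.
    output_length = int(inference_attrs.pop("output_length", 96))
    if output_length > MAX_OUTPUT_LENGTH:
        raise ValueError(
            f"output_length must be in [1, {MAX_OUTPUT_LENGTH}], got {output_length}"
        )
    return output_length

print(resolve_output_length({"output_length": "48"}))  # 48
print(resolve_output_length({}))                       # 96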

iotdb-core/ainode/pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -79,7 +79,7 @@ exclude = [
 python = ">=3.11.0,<3.14.0"
 
 # ---- DL / HF stack ----
-torch = ">=2.7.0"
+torch = "^2.7.1"
 torchmetrics = "^1.8.0"
 transformers = "==4.56.2"
 tokenizers = ">=0.22.0,<=0.23.0"
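
In Poetry's constraint syntax, ^2.7.1 expands to >=2.7.1,<3.0.0, so this change both raises the minimum torch patch version and caps the dependency below the next major release, which the previous open-ended >=2.7.0 did not.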

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java

Lines changed: 2 additions & 1 deletion
@@ -245,7 +245,8 @@ private void submitInferenceTask() {
             .borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) {
       return client.inference(
           new TInferenceReq(
-                  modelInferenceDescriptor.getModelId(), serde.serialize(inputTsBlock)));
+                  modelInferenceDescriptor.getModelId(), serde.serialize(inputTsBlock))
+              .setInferenceAttributes(modelInferenceDescriptor.getInferenceAttributes()));
     } catch (Exception e) {
       throw new ModelInferenceProcessException(e.getMessage());
     }

iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift

Lines changed: 1 addition & 7 deletions
@@ -60,13 +60,7 @@ struct TRegisterModelResp {
 struct TInferenceReq {
   1: required string modelId
   2: required binary dataset
-  3: optional TWindowParams windowParams
-  4: optional map<string, string> inferenceAttributes
-}
-
-struct TWindowParams {
-  1: required i32 windowInterval
-  2: required i32 windowStep
+  3: optional map<string, string> inferenceAttributes
 }
 
 struct TInferenceResp {
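
With TWindowParams removed, the string-to-string inferenceAttributes map is the only side channel left on TInferenceReq; the manager's extract_attrs lambda (see inference_manager.py above) reads outputLength out of it. A sketch of that mapping in isolation (the attribute values here are illustrative):

# inferenceAttributes arrives as a Thrift map<string, string>; the manager pops
# "outputLength", coerces it to int, and forwards the remaining attributes.
inference_attributes = {"outputLength": "48", "generateTime": "true"}

attrs = {
    "output_length": int(inference_attributes.pop("outputLength", 96)),
    **(inference_attributes or {}),
}
print(attrs)  # {'output_length': 48, 'generateTime': 'true'}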
