Skip to content

Commit db973ea

Browse files
authored
[AINode] Fix call inference bug (#17011)
1 parent 92308f2 commit db973ea

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ public class AINodeCallInferenceIT {
4848

4949
private static final String CALL_INFERENCE_SQL_TEMPLATE =
5050
"CALL INFERENCE(%s, \"SELECT s%d FROM root.AI LIMIT %d\", generateTime=true, outputLength=%d)";
51+
private static final String CALL_INFERENCE_BY_DEFAULT_SQL_TEMPLATE =
52+
"CALL INFERENCE(%s, \"SELECT s%d FROM root.AI LIMIT 256\")";
5153
private static final int DEFAULT_INPUT_LENGTH = 256;
5254
private static final int DEFAULT_OUTPUT_LENGTH = 48;
5355

@@ -69,6 +71,7 @@ public void callInferenceTest() throws SQLException {
6971
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
7072
Statement statement = connection.createStatement()) {
7173
callInferenceTest(statement, modelInfo);
74+
callInferenceByDefaultTest(statement, modelInfo);
7275
}
7376
}
7477
}
@@ -96,4 +99,23 @@ public static void callInferenceTest(Statement statement, AINodeTestUtils.FakeMo
9699
}
97100
}
98101
}
102+
103+
public static void callInferenceByDefaultTest(
104+
Statement statement, AINodeTestUtils.FakeModelInfo modelInfo) throws SQLException {
105+
// Invoke call inference for specified models, there should exist result.
106+
for (int i = 0; i < 4; i++) {
107+
String callInferenceSQL =
108+
String.format(CALL_INFERENCE_BY_DEFAULT_SQL_TEMPLATE, modelInfo.getModelId(), i);
109+
try (ResultSet resultSet = statement.executeQuery(callInferenceSQL)) {
110+
ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
111+
checkHeader(resultSetMetaData, "output");
112+
int count = 0;
113+
while (resultSet.next()) {
114+
count++;
115+
}
116+
// Ensure the call inference return results
117+
Assert.assertTrue(count > 0);
118+
}
119+
}
120+
}
99121
}

iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,11 @@ def inference(self, req: TInferenceReq):
265265
req,
266266
data_getter=lambda r: r.dataset,
267267
extract_attrs=lambda r: {
268-
"output_length": int(r.inferenceAttributes.pop("outputLength", 96)),
268+
"output_length": (
269+
96
270+
if r.inferenceAttributes is None
271+
else int(r.inferenceAttributes.pop("outputLength", 96))
272+
),
269273
**(r.inferenceAttributes or {}),
270274
},
271275
resp_cls=TInferenceResp,

0 commit comments

Comments
 (0)