Skip to content

Commit b6c13d7

Browse files
CRZbulabulaLiu Zhengyun
andauthored
[AINode] Forecast table function version2 (#16922)
--------- Co-authored-by: Liu Zhengyun <[email protected]>
1 parent 76d02a5 commit b6c13d7

File tree

11 files changed

+723
-226
lines changed

11 files changed

+723
-226
lines changed

integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ public class AINodeConcurrentForecastIT {
4949
private static final Logger LOGGER = LoggerFactory.getLogger(AINodeConcurrentForecastIT.class);
5050

5151
private static final String FORECAST_TABLE_FUNCTION_SQL_TEMPLATE =
52-
"SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time, output_length=>%d)";
52+
"SELECT * FROM FORECAST(model_id=>'%s', targets=>(SELECT time,s FROM root.AI) ORDER BY time, output_length=>%d)";
5353

5454
@BeforeClass
5555
public static void setUp() throws Exception {

integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java

Lines changed: 95 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,21 @@
3838
import java.sql.Statement;
3939

4040
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_MODEL_MAP;
41+
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest;
4142

4243
@RunWith(IoTDBTestRunner.class)
4344
@Category({AIClusterIT.class})
4445
public class AINodeForecastIT {
4546

4647
private static final String FORECAST_TABLE_FUNCTION_SQL_TEMPLATE =
47-
"SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time, s%d FROM db.AI) ORDER BY time)";
48+
"SELECT * FROM FORECAST("
49+
+ "model_id=>'%s', "
50+
+ "targets=>(SELECT time, s%d FROM db.AI WHERE time<%d ORDER BY time DESC LIMIT %d) ORDER BY time, "
51+
+ "output_start_time=>%d, "
52+
+ "output_length=>%d, "
53+
+ "output_interval=>%d, "
54+
+ "timecol=>'%s'"
55+
+ ")";
4856

4957
@BeforeClass
5058
public static void setUp() throws Exception {
@@ -55,7 +63,7 @@ public static void setUp() throws Exception {
5563
statement.execute("CREATE DATABASE db");
5664
statement.execute(
5765
"CREATE TABLE db.AI (s0 FLOAT FIELD, s1 DOUBLE FIELD, s2 INT32 FIELD, s3 INT64 FIELD)");
58-
for (int i = 0; i < 2880; i++) {
66+
for (int i = 0; i < 5760; i++) {
5967
statement.execute(
6068
String.format(
6169
"INSERT INTO db.AI(time,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)",
@@ -81,18 +89,100 @@ public void forecastTableFunctionTest() throws SQLException {
8189

8290
public void forecastTableFunctionTest(
8391
Statement statement, AINodeTestUtils.FakeModelInfo modelInfo) throws SQLException {
84-
// Invoke call inference for specified models, there should exist result.
92+
// Invoke forecast table function for specified models, there should exist result.
8593
for (int i = 0; i < 4; i++) {
8694
String forecastTableFunctionSQL =
87-
String.format(FORECAST_TABLE_FUNCTION_SQL_TEMPLATE, modelInfo.getModelId(), i);
95+
String.format(
96+
FORECAST_TABLE_FUNCTION_SQL_TEMPLATE,
97+
modelInfo.getModelId(),
98+
i,
99+
5760,
100+
2880,
101+
5760,
102+
96,
103+
1,
104+
"time");
88105
try (ResultSet resultSet = statement.executeQuery(forecastTableFunctionSQL)) {
89106
int count = 0;
90107
while (resultSet.next()) {
91108
count++;
92109
}
93-
// Ensure the call inference return results
110+
// Ensure the forecast sentence return results
94111
Assert.assertTrue(count > 0);
95112
}
96113
}
97114
}
115+
116+
@Test
117+
public void forecastTableFunctionErrorTest() throws SQLException {
118+
for (AINodeTestUtils.FakeModelInfo modelInfo : BUILTIN_MODEL_MAP.values()) {
119+
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
120+
Statement statement = connection.createStatement()) {
121+
forecastTableFunctionErrorTest(statement, modelInfo);
122+
}
123+
}
124+
}
125+
126+
public void forecastTableFunctionErrorTest(
127+
Statement statement, AINodeTestUtils.FakeModelInfo modelInfo) throws SQLException {
128+
// OUTPUT_START_TIME error
129+
String invalidOutputStartTimeSQL =
130+
String.format(
131+
FORECAST_TABLE_FUNCTION_SQL_TEMPLATE,
132+
modelInfo.getModelId(),
133+
0,
134+
5760,
135+
2880,
136+
5759,
137+
96,
138+
1,
139+
"time");
140+
errorTest(
141+
statement,
142+
invalidOutputStartTimeSQL,
143+
"701: The OUTPUT_START_TIME should be greater than the maximum timestamp of target time series. Expected greater than [5759] but found [5759].");
144+
145+
// OUTPUT_LENGTH error
146+
String invalidOutputLengthSQL =
147+
String.format(
148+
FORECAST_TABLE_FUNCTION_SQL_TEMPLATE,
149+
modelInfo.getModelId(),
150+
0,
151+
5760,
152+
2880,
153+
5760,
154+
0,
155+
1,
156+
"time");
157+
errorTest(statement, invalidOutputLengthSQL, "701: OUTPUT_LENGTH should be greater than 0");
158+
159+
// OUTPUT_INTERVAL error
160+
String invalidOutputIntervalSQL =
161+
String.format(
162+
FORECAST_TABLE_FUNCTION_SQL_TEMPLATE,
163+
modelInfo.getModelId(),
164+
0,
165+
5760,
166+
2880,
167+
5760,
168+
96,
169+
-1,
170+
"time");
171+
errorTest(statement, invalidOutputIntervalSQL, "701: OUTPUT_INTERVAL should be greater than 0");
172+
173+
// TIMECOL error
174+
String invalidTimecolSQL2 =
175+
String.format(
176+
FORECAST_TABLE_FUNCTION_SQL_TEMPLATE,
177+
modelInfo.getModelId(),
178+
0,
179+
5760,
180+
2880,
181+
5760,
182+
96,
183+
1,
184+
"s0");
185+
errorTest(
186+
statement, invalidTimecolSQL2, "701: The type of the column [s0] is not as expected.");
187+
}
98188
}

iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -610,7 +610,11 @@ def prepare_inputs_for_generation(
610610
if attention_mask is not None and attention_mask.shape[1] > (
611611
input_ids.shape[1] // self.config.input_token_len
612612
):
613-
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
613+
input_ids = input_ids[
614+
:,
615+
-(attention_mask.shape[1] - past_length)
616+
* self.config.input_token_len :,
617+
]
614618
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
615619
# input_ids based on the past_length.
616620
elif past_length < (input_ids.shape[1] // self.config.input_token_len):
@@ -623,9 +627,10 @@ def prepare_inputs_for_generation(
623627
position_ids = attention_mask.long().cumsum(-1) - 1
624628
position_ids.masked_fill_(attention_mask == 0, 1)
625629
if past_key_values:
626-
position_ids = position_ids[
627-
:, -(input_ids.shape[1] // self.config.input_token_len) :
628-
]
630+
token_num = (
631+
input_ids.shape[1] + self.config.input_token_len - 1
632+
) // self.config.input_token_len
633+
position_ids = position_ids[:, -token_num:]
629634

630635
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
631636
if inputs_embeds is not None and past_key_values is None:

iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/modeling_timer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -603,7 +603,11 @@ def prepare_inputs_for_generation(
603603
if attention_mask is not None and attention_mask.shape[1] > (
604604
input_ids.shape[1] // self.config.input_token_len
605605
):
606-
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
606+
input_ids = input_ids[
607+
:,
608+
-(attention_mask.shape[1] - past_length)
609+
* self.config.input_token_len :,
610+
]
607611
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
608612
# input_ids based on the past_length.
609613
elif past_length < (input_ids.shape[1] // self.config.input_token_len):

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TableOperatorGenerator.java

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2013,15 +2013,22 @@ public Operator visitStreamSort(StreamSortNode node, LocalExecutionPlanContext c
20132013

20142014
@Override
20152015
public Operator visitGroup(GroupNode node, LocalExecutionPlanContext context) {
2016-
StreamSortNode streamSortNode =
2017-
new StreamSortNode(
2018-
node.getPlanNodeId(),
2019-
node.getChild(),
2020-
node.getOrderingScheme(),
2021-
false,
2022-
false,
2023-
node.getPartitionKeyCount() - 1);
2024-
return visitStreamSort(streamSortNode, context);
2016+
if (node.getPartitionKeyCount() == 0) {
2017+
SortNode sortNode =
2018+
new SortNode(
2019+
node.getPlanNodeId(), node.getChild(), node.getOrderingScheme(), false, false);
2020+
return visitSort(sortNode, context);
2021+
} else {
2022+
StreamSortNode streamSortNode =
2023+
new StreamSortNode(
2024+
node.getPlanNodeId(),
2025+
node.getChild(),
2026+
node.getOrderingScheme(),
2027+
false,
2028+
false,
2029+
node.getPartitionKeyCount() - 1);
2030+
return visitStreamSort(streamSortNode, context);
2031+
}
20252032
}
20262033

20272034
@Override

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/TableBuiltinTableFunction.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.apache.iotdb.commons.udf.builtin.relational.tvf.SessionTableFunction;
2626
import org.apache.iotdb.commons.udf.builtin.relational.tvf.TumbleTableFunction;
2727
import org.apache.iotdb.commons.udf.builtin.relational.tvf.VariationTableFunction;
28+
import org.apache.iotdb.db.queryengine.plan.relational.function.tvf.ClassifyTableFunction;
2829
import org.apache.iotdb.db.queryengine.plan.relational.function.tvf.ForecastTableFunction;
2930
import org.apache.iotdb.db.queryengine.plan.relational.function.tvf.PatternMatchTableFunction;
3031
import org.apache.iotdb.udf.api.relational.TableFunction;
@@ -42,7 +43,8 @@ public enum TableBuiltinTableFunction {
4243
VARIATION("variation"),
4344
CAPACITY("capacity"),
4445
FORECAST("forecast"),
45-
PATTERN_MATCH("pattern_match");
46+
PATTERN_MATCH("pattern_match"),
47+
CLASSIFY("classify");
4648

4749
private final String functionName;
4850

@@ -86,6 +88,8 @@ public static TableFunction getBuiltinTableFunction(String functionName) {
8688
return new CapacityTableFunction();
8789
case "forecast":
8890
return new ForecastTableFunction();
91+
case "classify":
92+
return new ClassifyTableFunction();
8993
default:
9094
throw new UnsupportedOperationException("Unsupported table function: " + functionName);
9195
}

0 commit comments

Comments
 (0)