Skip to content

Commit 7487f4c

Browse files
ycycseCRZbulabula
authored andcommitted
Support built-in forecast function through UDTF for tree model (#15682)
(cherry picked from commit c158092)
1 parent 8f6d4b5 commit 7487f4c

File tree

16 files changed

+87
-33
lines changed

16 files changed

+87
-33
lines changed

integration-test/src/main/java/org/apache/iotdb/itbase/constant/BuiltinAggregationFunctionEnum.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,11 @@ public String getFunctionName() {
5656

5757
private static final Set<String> NATIVE_FUNCTION_NAMES =
5858
new HashSet<>(
59-
Arrays.stream(org.apache.iotdb.commons.udf.builtin.BuiltinAggregationFunction.values())
60-
.map(org.apache.iotdb.commons.udf.builtin.BuiltinAggregationFunction::getFunctionName)
59+
Arrays.stream(
60+
org.apache.iotdb.db.queryengine.plan.udf.BuiltinAggregationFunction.values())
61+
.map(
62+
org.apache.iotdb.db.queryengine.plan.udf.BuiltinAggregationFunction
63+
::getFunctionName)
6164
.collect(Collectors.toList()));
6265

6366
public static Set<String> getNativeFunctionNames() {

integration-test/src/main/java/org/apache/iotdb/itbase/constant/BuiltinTimeSeriesGeneratingFunctionEnum.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ public enum BuiltinTimeSeriesGeneratingFunctionEnum {
7474
EQUAL_SIZE_BUCKET_OUTLIER_SAMPLE("EQUAL_SIZE_BUCKET_OUTLIER_SAMPLE"),
7575
JEXL("JEXL"),
7676
MASTER_REPAIR("MASTER_REPAIR"),
77+
FORECAST("FORECAST"),
7778
M4("M4");
7879

7980
private final String functionName;

integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ public static void tearDown() throws Exception {
6565

6666
private static void prepareDataForTreeModel() throws SQLException {
6767
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
68-
Statement statement = connection.createStatement()) {
68+
Statement statement = connection.createStatement()) {
6969
statement.execute("CREATE DATABASE root.AI");
7070
statement.execute("CREATE TIMESERIES root.AI.s WITH DATATYPE=DOUBLE, ENCODING=RLE");
7171
for (int i = 0; i < 2880; i++) {
@@ -79,7 +79,7 @@ private static void prepareDataForTreeModel() throws SQLException {
7979

8080
private static void prepareDataForTableModel() throws SQLException {
8181
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
82-
Statement statement = connection.createStatement()) {
82+
Statement statement = connection.createStatement()) {
8383
statement.execute("CREATE DATABASE root");
8484
statement.execute("CREATE TABLE root.AI (s DOUBLE FIELD)");
8585
for (int i = 0; i < 2880; i++) {
@@ -99,7 +99,7 @@ public void concurrentGPUCallInferenceTest() throws SQLException, InterruptedExc
9999
private void concurrentGPUCallInferenceTest(String modelId)
100100
throws SQLException, InterruptedException {
101101
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
102-
Statement statement = connection.createStatement()) {
102+
Statement statement = connection.createStatement()) {
103103
final int threadCnt = 10;
104104
final int loop = 100;
105105
final int predictLength = 512;
@@ -118,15 +118,23 @@ private void concurrentGPUCallInferenceTest(String modelId)
118118
}
119119
}
120120

121+
String forecastTableFunctionSql =
122+
"SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time), predict_length=>%d";
123+
String forecastUDTFSql =
124+
"SELECT forecast(s, 'MODEL_ID'='%s', 'PREDICT_LENGTH'='%d') FROM root.AI";
125+
121126
@Test
122127
public void concurrentGPUForecastTest() throws SQLException, InterruptedException {
123-
concurrentGPUForecastTest("timer_xl");
124-
concurrentGPUForecastTest("sundial");
128+
concurrentGPUForecastTest("timer_xl", forecastUDTFSql);
129+
concurrentGPUForecastTest("sundial", forecastUDTFSql);
130+
concurrentGPUForecastTest("timer_xl", forecastTableFunctionSql);
131+
concurrentGPUForecastTest("sundial", forecastTableFunctionSql);
125132
}
126133

127-
public void concurrentGPUForecastTest(String modelId) throws SQLException, InterruptedException {
134+
public void concurrentGPUForecastTest(String modelId, String selectSql)
135+
throws SQLException, InterruptedException {
128136
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
129-
Statement statement = connection.createStatement()) {
137+
Statement statement = connection.createStatement()) {
130138
final int threadCnt = 10;
131139
final int loop = 100;
132140
final int predictLength = 512;
@@ -136,9 +144,7 @@ public void concurrentGPUForecastTest(String modelId) throws SQLException, Inter
136144
long startTime = System.currentTimeMillis();
137145
concurrentInference(
138146
statement,
139-
String.format(
140-
"SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time), predict_length=>%d",
141-
modelId, predictLength),
147+
String.format(selectSql, modelId, predictLength),
142148
threadCnt,
143149
loop,
144150
predictLength);
@@ -158,7 +164,7 @@ private void checkModelOnSpecifiedDevice(Statement statement, String modelId, St
158164
for (int retry = 0; retry < 200; retry++) {
159165
Set<String> foundDevices = new HashSet<>();
160166
try (final ResultSet resultSet =
161-
statement.executeQuery(String.format("SHOW LOADED MODELS '%s'", device))) {
167+
statement.executeQuery(String.format("SHOW LOADED MODELS '%s'", device))) {
162168
while (resultSet.next()) {
163169
String deviceId = resultSet.getString("DeviceId");
164170
String loadedModelId = resultSet.getString("ModelId");

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/AggregationUtil.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919

2020
package org.apache.iotdb.db.queryengine.execution.operator;
2121

22-
import org.apache.iotdb.commons.udf.builtin.BuiltinAggregationFunction;
2322
import org.apache.iotdb.db.queryengine.execution.aggregation.TreeAggregator;
2423
import org.apache.iotdb.db.queryengine.execution.aggregation.timerangeiterator.ITimeRangeIterator;
2524
import org.apache.iotdb.db.queryengine.execution.aggregation.timerangeiterator.SingleTimeWindowIterator;
@@ -29,6 +28,7 @@
2928
import org.apache.iotdb.db.queryengine.plan.analyze.TypeProvider;
3029
import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.AggregationDescriptor;
3130
import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.GroupByTimeParameter;
31+
import org.apache.iotdb.db.queryengine.plan.udf.BuiltinAggregationFunction;
3232
import org.apache.iotdb.db.queryengine.statistics.StatisticsManager;
3333

3434
import org.apache.tsfile.block.column.Column;

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ExpressionAnalyzer.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@
2424
import org.apache.iotdb.commons.path.PartialPath;
2525
import org.apache.iotdb.commons.path.PathPatternTree;
2626
import org.apache.iotdb.commons.schema.column.ColumnHeader;
27-
import org.apache.iotdb.commons.udf.builtin.BuiltinScalarFunction;
28-
import org.apache.iotdb.commons.udf.builtin.BuiltinTimeSeriesGeneratingFunction;
2927
import org.apache.iotdb.db.exception.sql.SemanticException;
3028
import org.apache.iotdb.db.queryengine.common.MPPQueryContext;
3129
import org.apache.iotdb.db.queryengine.common.schematree.ISchemaTree;
@@ -56,6 +54,8 @@
5654
import org.apache.iotdb.db.queryengine.plan.expression.visitor.cartesian.ConcatDeviceAndBindSchemaForPredicateVisitor;
5755
import org.apache.iotdb.db.queryengine.plan.expression.visitor.cartesian.ConcatExpressionWithSuffixPathsVisitor;
5856
import org.apache.iotdb.db.queryengine.plan.statement.component.ResultColumn;
57+
import org.apache.iotdb.db.queryengine.plan.udf.BuiltinScalarFunction;
58+
import org.apache.iotdb.db.queryengine.plan.udf.BuiltinTimeSeriesGeneratingFunction;
5959
import org.apache.iotdb.db.utils.constant.SqlConstant;
6060

6161
import org.apache.tsfile.common.constant.TsFileConstant;

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ShowFunctionsTask.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,6 @@
2424
import org.apache.iotdb.commons.schema.column.ColumnHeaderConstant;
2525
import org.apache.iotdb.commons.udf.UDFInformation;
2626
import org.apache.iotdb.commons.udf.UDFType;
27-
import org.apache.iotdb.commons.udf.builtin.BuiltinAggregationFunction;
28-
import org.apache.iotdb.commons.udf.builtin.BuiltinScalarFunction;
29-
import org.apache.iotdb.commons.udf.builtin.BuiltinTimeSeriesGeneratingFunction;
3027
import org.apache.iotdb.commons.udf.builtin.relational.TableBuiltinAggregationFunction;
3128
import org.apache.iotdb.commons.udf.builtin.relational.TableBuiltinScalarFunction;
3229
import org.apache.iotdb.db.queryengine.common.header.DatasetHeader;
@@ -35,6 +32,9 @@
3532
import org.apache.iotdb.db.queryengine.plan.execution.config.IConfigTask;
3633
import org.apache.iotdb.db.queryengine.plan.execution.config.executor.IConfigTaskExecutor;
3734
import org.apache.iotdb.db.queryengine.plan.relational.function.TableBuiltinTableFunction;
35+
import org.apache.iotdb.db.queryengine.plan.udf.BuiltinAggregationFunction;
36+
import org.apache.iotdb.db.queryengine.plan.udf.BuiltinScalarFunction;
37+
import org.apache.iotdb.db.queryengine.plan.udf.BuiltinTimeSeriesGeneratingFunction;
3838
import org.apache.iotdb.db.queryengine.plan.udf.TableUDFUtils;
3939
import org.apache.iotdb.db.queryengine.plan.udf.TreeUDFUtils;
4040
import org.apache.iotdb.rpc.TSStatusCode;

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/expression/ExpressionFactory.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
import org.apache.iotdb.commons.exception.IllegalPathException;
2424
import org.apache.iotdb.commons.path.MeasurementPath;
2525
import org.apache.iotdb.commons.path.PartialPath;
26-
import org.apache.iotdb.commons.udf.builtin.BuiltinTimeSeriesGeneratingFunction;
2726
import org.apache.iotdb.db.queryengine.plan.expression.binary.AdditionExpression;
2827
import org.apache.iotdb.db.queryengine.plan.expression.binary.EqualToExpression;
2928
import org.apache.iotdb.db.queryengine.plan.expression.binary.GreaterEqualExpression;
@@ -46,6 +45,7 @@
4645
import org.apache.iotdb.db.queryengine.plan.expression.unary.LogicNotExpression;
4746
import org.apache.iotdb.db.queryengine.plan.expression.unary.RegularExpression;
4847
import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.GroupByTimeParameter;
48+
import org.apache.iotdb.db.queryengine.plan.udf.BuiltinTimeSeriesGeneratingFunction;
4949

5050
import org.apache.tsfile.enums.TSDataType;
5151
import org.apache.tsfile.utils.TimeDuration;

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/expression/multi/FunctionExpression.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@
2121

2222
import org.apache.iotdb.commons.conf.IoTDBConstant;
2323
import org.apache.iotdb.commons.path.PartialPath;
24-
import org.apache.iotdb.commons.udf.builtin.BuiltinAggregationFunction;
25-
import org.apache.iotdb.commons.udf.builtin.BuiltinScalarFunction;
2624
import org.apache.iotdb.db.queryengine.common.NodeRef;
2725
import org.apache.iotdb.db.queryengine.execution.MemoryEstimationHelper;
2826
import org.apache.iotdb.db.queryengine.plan.expression.Expression;
@@ -31,6 +29,8 @@
3129
import org.apache.iotdb.db.queryengine.plan.expression.multi.builtin.BuiltInScalarFunctionHelperFactory;
3230
import org.apache.iotdb.db.queryengine.plan.expression.visitor.ExpressionVisitor;
3331
import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.InputLocation;
32+
import org.apache.iotdb.db.queryengine.plan.udf.BuiltinAggregationFunction;
33+
import org.apache.iotdb.db.queryengine.plan.udf.BuiltinScalarFunction;
3434
import org.apache.iotdb.db.queryengine.plan.udf.TreeUDFUtils;
3535
import org.apache.iotdb.db.queryengine.transformation.dag.memory.LayerMemoryAssigner;
3636
import org.apache.iotdb.db.queryengine.transformation.dag.udf.UDTFExecutor;

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/optimization/AggregationPushDown.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import org.apache.iotdb.commons.path.MeasurementPath;
2626
import org.apache.iotdb.commons.path.PartialPath;
2727
import org.apache.iotdb.commons.schema.column.ColumnHeaderConstant;
28-
import org.apache.iotdb.commons.udf.builtin.BuiltinAggregationFunction;
2928
import org.apache.iotdb.db.queryengine.common.MPPQueryContext;
3029
import org.apache.iotdb.db.queryengine.execution.MemoryEstimationHelper;
3130
import org.apache.iotdb.db.queryengine.plan.analyze.Analysis;
@@ -57,6 +56,7 @@
5756
import org.apache.iotdb.db.queryengine.plan.statement.StatementType;
5857
import org.apache.iotdb.db.queryengine.plan.statement.component.Ordering;
5958
import org.apache.iotdb.db.queryengine.plan.statement.crud.QueryStatement;
59+
import org.apache.iotdb.db.queryengine.plan.udf.BuiltinAggregationFunction;
6060
import org.apache.iotdb.db.schemaengine.schemaregion.utils.MetaUtils;
6161
import org.apache.iotdb.db.utils.SchemaUtils;
6262

iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/BuiltinAggregationFunction.java renamed to iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/BuiltinAggregationFunction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
* under the License.
1818
*/
1919

20-
package org.apache.iotdb.commons.udf.builtin;
20+
package org.apache.iotdb.db.queryengine.plan.udf;
2121

2222
import java.util.Arrays;
2323
import java.util.HashSet;

0 commit comments

Comments
 (0)