Skip to content

Commit ea6a2d0

Browse files
authored
Support built-in forecast table function for table model
1 parent 8439e54 commit ea6a2d0

File tree

3 files changed

+47
-9
lines changed

3 files changed

+47
-9
lines changed

iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/relational/table/specification/ScalarParameterSpecification.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ public static final class Builder {
6363
private Type type;
6464
private boolean required = true;
6565
private Object defaultValue;
66-
private List<Function<Object, String>> checkers = new ArrayList<>();
66+
private final List<Function<Object, String>> checkers = new ArrayList<>();
6767

6868
private Builder() {}
6969

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/TableBuiltinTableFunction.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.apache.iotdb.commons.udf.builtin.relational.tvf.SessionTableFunction;
2626
import org.apache.iotdb.commons.udf.builtin.relational.tvf.TumbleTableFunction;
2727
import org.apache.iotdb.commons.udf.builtin.relational.tvf.VariationTableFunction;
28+
import org.apache.iotdb.db.queryengine.plan.relational.function.tvf.ForecastTableFunction;
2829
import org.apache.iotdb.udf.api.relational.TableFunction;
2930

3031
import java.util.Arrays;
@@ -38,8 +39,8 @@ public enum TableBuiltinTableFunction {
3839
CUMULATE("cumulate"),
3940
SESSION("session"),
4041
VARIATION("variation"),
41-
CAPACITY("capacity");
42-
// FORECAST("forecast");
42+
CAPACITY("capacity"),
43+
FORECAST("forecast");
4344

4445
private final String functionName;
4546

@@ -79,8 +80,8 @@ public static TableFunction getBuiltinTableFunction(String functionName) {
7980
return new VariationTableFunction();
8081
case "capacity":
8182
return new CapacityTableFunction();
82-
// case "forecast":
83-
// return new ForecastTableFunction();
83+
case "forecast":
84+
return new ForecastTableFunction();
8485
default:
8586
throw new UnsupportedOperationException("Unsupported table function: " + functionName);
8687
}

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ private static class ForecastTableFunctionHandle implements TableFunctionHandle
7878
String modelId;
7979
int maxInputLength;
8080
int outputLength;
81+
long outputStartTime;
82+
long outputInterval;
8183
boolean keepInput;
8284
Map<String, String> options;
8385
List<Type> types;
@@ -90,13 +92,17 @@ public ForecastTableFunctionHandle(
9092
String modelId,
9193
Map<String, String> options,
9294
int outputLength,
95+
long outputStartTime,
96+
long outputInterval,
9397
TEndPoint targetAINode,
9498
List<Type> types) {
9599
this.keepInput = keepInput;
96100
this.maxInputLength = maxInputLength;
97101
this.modelId = modelId;
98102
this.options = options;
99103
this.outputLength = outputLength;
104+
this.outputStartTime = outputStartTime;
105+
this.outputInterval = outputInterval;
100106
this.targetAINode = targetAINode;
101107
this.types = types;
102108
}
@@ -110,6 +116,8 @@ public byte[] serialize() {
110116
ReadWriteIOUtils.write(modelId, outputStream);
111117
ReadWriteIOUtils.write(maxInputLength, outputStream);
112118
ReadWriteIOUtils.write(outputLength, outputStream);
119+
ReadWriteIOUtils.write(outputStartTime, outputStream);
120+
ReadWriteIOUtils.write(outputInterval, outputStream);
113121
ReadWriteIOUtils.write(keepInput, outputStream);
114122
ReadWriteIOUtils.write(options, outputStream);
115123
ReadWriteIOUtils.write(types.size(), outputStream);
@@ -134,6 +142,8 @@ public void deserialize(byte[] bytes) {
134142
this.modelId = ReadWriteIOUtils.readString(buffer);
135143
this.maxInputLength = ReadWriteIOUtils.readInt(buffer);
136144
this.outputLength = ReadWriteIOUtils.readInt(buffer);
145+
this.outputStartTime = ReadWriteIOUtils.readLong(buffer);
146+
this.outputInterval = ReadWriteIOUtils.readLong(buffer);
137147
this.keepInput = ReadWriteIOUtils.readBoolean(buffer);
138148
this.options = ReadWriteIOUtils.readMap(buffer);
139149
int size = ReadWriteIOUtils.readInt(buffer);
@@ -152,6 +162,10 @@ public void deserialize(byte[] bytes) {
152162
private static final int DEFAULT_OUTPUT_LENGTH = 96;
153163
private static final String PREDICATED_COLUMNS_PARAMETER_NAME = "PREDICATED_COLUMNS";
154164
private static final String DEFAULT_PREDICATED_COLUMNS = "";
165+
private static final String OUTPUT_START_TIME = "OUTPUT_START_TIME";
166+
private static final long DEFAULT_OUTPUT_START_TIME = Long.MIN_VALUE;
167+
private static final String OUTPUT_INTERVAL = "OUTPUT_INTERVAL";
168+
private static final long DEFAULT_OUTPUT_INTERVAL = 0L;
155169
private static final String TIMECOL_PARAMETER_NAME = "TIMECOL";
156170
private static final String DEFAULT_TIME_COL = "time";
157171
private static final String KEEP_INPUT_PARAMETER_NAME = "KEEP_INPUT";
@@ -184,6 +198,16 @@ public List<ParameterSpecification> getArgumentsSpecifications() {
184198
.type(Type.INT32)
185199
.defaultValue(DEFAULT_OUTPUT_LENGTH)
186200
.build(),
201+
ScalarParameterSpecification.builder()
202+
.name(OUTPUT_START_TIME)
203+
.type(Type.TIMESTAMP)
204+
.defaultValue(DEFAULT_OUTPUT_START_TIME)
205+
.build(),
206+
ScalarParameterSpecification.builder()
207+
.name(OUTPUT_INTERVAL)
208+
.type(Type.INT64)
209+
.defaultValue(DEFAULT_OUTPUT_INTERVAL)
210+
.build(),
187211
ScalarParameterSpecification.builder()
188212
.name(PREDICATED_COLUMNS_PARAMETER_NAME)
189213
.type(Type.STRING)
@@ -307,6 +331,8 @@ public TableFunctionAnalysis analyze(Map<String, Argument> arguments) {
307331
properColumnSchemaBuilder.addField(IS_INPUT_COLUMN_NAME, Type.BOOLEAN);
308332
}
309333

334+
long outputStartTime = (long) ((ScalarArgument) arguments.get(OUTPUT_START_TIME)).getValue();
335+
long outputInterval = (long) ((ScalarArgument) arguments.get(OUTPUT_INTERVAL)).getValue();
310336
String options = (String) ((ScalarArgument) arguments.get(OPTIONS_PARAMETER_NAME)).getValue();
311337

312338
ForecastTableFunctionHandle functionHandle =
@@ -316,6 +342,8 @@ public TableFunctionAnalysis analyze(Map<String, Argument> arguments) {
316342
modelId,
317343
parseOptions(options),
318344
outputLength,
345+
outputStartTime,
346+
outputInterval,
319347
targetAINode,
320348
predicatedColumnTypes);
321349

@@ -389,6 +417,8 @@ private static class ForecastDataProcessor implements TableFunctionDataProcessor
389417
private final String modelId;
390418
private final int maxInputLength;
391419
private final int outputLength;
420+
private final long outputStartTime;
421+
private final long outputInterval;
392422
private final boolean keepInput;
393423
private final Map<String, String> options;
394424
private final LinkedList<Record> inputRecords;
@@ -400,6 +430,8 @@ public ForecastDataProcessor(ForecastTableFunctionHandle functionHandle) {
400430
this.modelId = functionHandle.modelId;
401431
this.maxInputLength = functionHandle.maxInputLength;
402432
this.outputLength = functionHandle.outputLength;
433+
this.outputStartTime = functionHandle.outputStartTime;
434+
this.outputInterval = functionHandle.outputInterval;
403435
this.keepInput = functionHandle.keepInput;
404436
this.options = functionHandle.options;
405437
this.inputRecords = new LinkedList<>();
@@ -467,11 +499,16 @@ public void finish(
467499
int columnSize = properColumnBuilders.size();
468500

469501
// time column
470-
long startTime = inputRecords.getFirst().getLong(0);
471-
long endTime = inputRecords.getLast().getLong(0);
472-
long interval = (endTime - startTime) / inputRecords.size();
502+
long inputStartTime = inputRecords.getFirst().getLong(0);
503+
long inputEndTime = inputRecords.getLast().getLong(0);
504+
long interval =
505+
outputInterval <= 0
506+
? (inputEndTime - inputStartTime) / inputRecords.size()
507+
: outputInterval;
508+
long outputTime =
509+
(outputStartTime == Long.MIN_VALUE) ? (inputEndTime + interval) : outputStartTime;
473510
for (int i = 0; i < outputLength; i++) {
474-
properColumnBuilders.get(0).writeLong(endTime + interval * (i + 1));
511+
properColumnBuilders.get(0).writeLong(outputTime + interval * i);
475512
}
476513

477514
// predicated columns

0 commit comments

Comments
 (0)