Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion iotdb-core/ainode/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
/iotdb/thrift/
/iotdb/tsfile/
/iotdb/utils/
/iotdb/__init__.py
/iotdb/Session.py
/iotdb/SessionPool.py
/iotdb/table_session.py
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ public static class ForecastTableFunctionHandle implements TableFunctionHandle {
long outputInterval;
boolean keepInput;
Map<String, String> options;
List<Type> types;
List<Type> inputColumnTypes;
List<Type> predicatedColumnTypes;

public ForecastTableFunctionHandle() {}

Expand All @@ -97,7 +98,8 @@ public ForecastTableFunctionHandle(
long outputStartTime,
long outputInterval,
TEndPoint targetAINode,
List<Type> types) {
List<Type> inputColumnTypes,
List<Type> predicatedColumnTypes) {
this.keepInput = keepInput;
this.maxInputLength = maxInputLength;
this.modelId = modelId;
Expand All @@ -106,7 +108,8 @@ public ForecastTableFunctionHandle(
this.outputStartTime = outputStartTime;
this.outputInterval = outputInterval;
this.targetAINode = targetAINode;
this.types = types;
this.inputColumnTypes = inputColumnTypes;
this.predicatedColumnTypes = predicatedColumnTypes;
}

@Override
Expand All @@ -122,8 +125,12 @@ public byte[] serialize() {
ReadWriteIOUtils.write(outputInterval, outputStream);
ReadWriteIOUtils.write(keepInput, outputStream);
ReadWriteIOUtils.write(options, outputStream);
ReadWriteIOUtils.write(types.size(), outputStream);
for (Type type : types) {
ReadWriteIOUtils.write(inputColumnTypes.size(), outputStream);
for (Type type : inputColumnTypes) {
ReadWriteIOUtils.write(type.getType(), outputStream);
}
ReadWriteIOUtils.write(predicatedColumnTypes.size(), outputStream);
for (Type type : predicatedColumnTypes) {
ReadWriteIOUtils.write(type.getType(), outputStream);
}
outputStream.flush();
Expand All @@ -149,9 +156,14 @@ public void deserialize(byte[] bytes) {
this.keepInput = ReadWriteIOUtils.readBoolean(buffer);
this.options = ReadWriteIOUtils.readMap(buffer);
int size = ReadWriteIOUtils.readInt(buffer);
this.types = new ArrayList<>(size);
this.inputColumnTypes = new ArrayList<>(size);
for (int i = 0; i < size; i++) {
inputColumnTypes.add(Type.valueOf(ReadWriteIOUtils.readByte(buffer)));
}
size = ReadWriteIOUtils.readInt(buffer);
this.predicatedColumnTypes = new ArrayList<>(size);
for (int i = 0; i < size; i++) {
types.add(Type.valueOf(ReadWriteIOUtils.readByte(buffer)));
predicatedColumnTypes.add(Type.valueOf(ReadWriteIOUtils.readByte(buffer)));
}
}

Expand All @@ -172,7 +184,8 @@ public boolean equals(Object o) {
&& Objects.equals(targetAINode, that.targetAINode)
&& Objects.equals(modelId, that.modelId)
&& Objects.equals(options, that.options)
&& Objects.equals(types, that.types);
&& Objects.equals(inputColumnTypes, that.inputColumnTypes)
&& Objects.equals(predicatedColumnTypes, that.predicatedColumnTypes);
}

@Override
Expand All @@ -186,7 +199,8 @@ public int hashCode() {
outputInterval,
keepInput,
options,
types);
inputColumnTypes,
predicatedColumnTypes);
}
}

Expand Down Expand Up @@ -319,20 +333,29 @@ public TableFunctionAnalysis analyze(Map<String, Argument> arguments) {
DescribedSchema.Builder properColumnSchemaBuilder =
new DescribedSchema.Builder().addField(timeColumn, Type.TIMESTAMP);

List<Type> inputColumnTypes = new ArrayList<>();
List<Type> predicatedColumnTypes = new ArrayList<>();
List<Optional<String>> allInputColumnsName = input.getFieldNames();
List<Type> allInputColumnsType = input.getFieldTypes();
for (int i = 0, size = allInputColumnsName.size(); i < size; i++) {
Optional<String> fieldName = allInputColumnsName.get(i);
// All input value columns are required for model forecasting
if (!fieldName.isPresent()
|| !excludedColumns.contains(fieldName.get().toLowerCase(Locale.ENGLISH))) {
inputColumnTypes.add(allInputColumnsType.get(i));
requiredIndexList.add(i);
}
}
if (predicatedColumns.isEmpty()) {
// predicated columns by default include all columns from input table except for timecol and
// partition by columns
// predicated columns by default include all columns from input table except for
// timecol and partition by columns
for (int i = 0, size = allInputColumnsName.size(); i < size; i++) {
Optional<String> fieldName = allInputColumnsName.get(i);
if (!fieldName.isPresent()
|| !excludedColumns.contains(fieldName.get().toLowerCase(Locale.ENGLISH))) {
Type columnType = allInputColumnsType.get(i);
predicatedColumnTypes.add(columnType);
checkType(columnType, fieldName.orElse(""));
requiredIndexList.add(i);
properColumnSchemaBuilder.addField(fieldName, columnType);
}
}
Expand All @@ -347,7 +370,7 @@ public TableFunctionAnalysis analyze(Map<String, Argument> arguments) {
inputColumnIndexMap.put(fieldName.get().toLowerCase(Locale.ENGLISH), i);
}

Set<Integer> requiredIndexSet = new HashSet<>(predictedColumnsArray.length);
Set<Integer> predicatedIndexSet = new HashSet<>(predictedColumnsArray.length);
// columns need to be predicated
for (String outputColumn : predictedColumnsArray) {
String lowerCaseOutputColumn = outputColumn.toLowerCase(Locale.ENGLISH);
Expand All @@ -360,14 +383,13 @@ public TableFunctionAnalysis analyze(Map<String, Argument> arguments) {
throw new SemanticException(
String.format("Column %s don't exist in input", outputColumn));
}
if (!requiredIndexSet.add(inputColumnIndex)) {
if (!predicatedIndexSet.add(inputColumnIndex)) {
throw new SemanticException(String.format("Duplicate column %s", outputColumn));
}

Type columnType = allInputColumnsType.get(inputColumnIndex);
predicatedColumnTypes.add(columnType);
checkType(columnType, outputColumn);
requiredIndexList.add(inputColumnIndex);
properColumnSchemaBuilder.addField(outputColumn, columnType);
}
}
Expand All @@ -392,6 +414,7 @@ public TableFunctionAnalysis analyze(Map<String, Argument> arguments) {
outputStartTime,
outputInterval,
targetAINode,
inputColumnTypes,
predicatedColumnTypes);

// outputColumnSchema
Expand Down Expand Up @@ -469,8 +492,9 @@ private static class ForecastDataProcessor implements TableFunctionDataProcessor
private final boolean keepInput;
private final Map<String, String> options;
private final LinkedList<Record> inputRecords;
private final List<ResultColumnAppender> resultColumnAppenderList;
private final TsBlockBuilder inputTsBlockBuilder;
private final List<ResultColumnAppender> inputColumnAppenderList;
private final List<ResultColumnAppender> resultColumnAppenderList;

public ForecastDataProcessor(ForecastTableFunctionHandle functionHandle) {
this.targetAINode = functionHandle.targetAINode;
Expand All @@ -482,14 +506,22 @@ public ForecastDataProcessor(ForecastTableFunctionHandle functionHandle) {
this.keepInput = functionHandle.keepInput;
this.options = functionHandle.options;
this.inputRecords = new LinkedList<>();
this.resultColumnAppenderList = new ArrayList<>(functionHandle.types.size());
List<TSDataType> tsDataTypeList = new ArrayList<>(functionHandle.types.size());
for (Type type : functionHandle.types) {
List<TSDataType> inputTsDataTypeList =
new ArrayList<>(functionHandle.inputColumnTypes.size());
for (Type type : functionHandle.inputColumnTypes) {
// AINode currently only accept double input
inputTsDataTypeList.add(TSDataType.DOUBLE);
}
this.inputTsBlockBuilder = new TsBlockBuilder(inputTsDataTypeList);
this.inputColumnAppenderList = new ArrayList<>(functionHandle.inputColumnTypes.size());
for (Type type : functionHandle.inputColumnTypes) {
Comment on lines +511 to +517
Copy link

Copilot AI Nov 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Variable 'Type type' is never read.

Suggested change
for (Type type : functionHandle.inputColumnTypes) {
// AINode currently only accept double input
inputTsDataTypeList.add(TSDataType.DOUBLE);
}
this.inputTsBlockBuilder = new TsBlockBuilder(inputTsDataTypeList);
this.inputColumnAppenderList = new ArrayList<>(functionHandle.inputColumnTypes.size());
for (Type type : functionHandle.inputColumnTypes) {
for (int i = 0; i < functionHandle.inputColumnTypes.size(); i++) {
// AINode currently only accept double input
inputTsDataTypeList.add(TSDataType.DOUBLE);
}
this.inputTsBlockBuilder = new TsBlockBuilder(inputTsDataTypeList);
this.inputColumnAppenderList = new ArrayList<>(functionHandle.inputColumnTypes.size());
for (int i = 0; i < functionHandle.inputColumnTypes.size(); i++) {

Copilot uses AI. Check for mistakes.
Copy link

Copilot AI Nov 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Variable 'Type type' is never read.

Suggested change
for (Type type : functionHandle.inputColumnTypes) {
for (int i = 0; i < functionHandle.inputColumnTypes.size(); i++) {

Copilot uses AI. Check for mistakes.
// AINode currently only accept double input
inputColumnAppenderList.add(createResultColumnAppender(Type.DOUBLE));
}
this.resultColumnAppenderList = new ArrayList<>(functionHandle.predicatedColumnTypes.size());
for (Type type : functionHandle.predicatedColumnTypes) {
resultColumnAppenderList.add(createResultColumnAppender(type));
// ainode currently only accept double input
tsDataTypeList.add(TSDataType.DOUBLE);
}
this.inputTsBlockBuilder = new TsBlockBuilder(tsDataTypeList);
}

private static ResultColumnAppender createResultColumnAppender(Type type) {
Expand Down Expand Up @@ -613,7 +645,7 @@ private TsBlock forecast() {
// need to transform other types to DOUBLE
inputTsBlockBuilder
.getColumnBuilder(i - 1)
.writeDouble(resultColumnAppenderList.get(i - 1).getDouble(row, i));
.writeDouble(inputColumnAppenderList.get(i - 1).getDouble(row, i));
}
}
inputTsBlockBuilder.declarePosition();
Expand All @@ -634,15 +666,7 @@ private TsBlock forecast() {
throw new IoTDBRuntimeException(message, resp.getStatus().getCode());
}

TsBlock res = SERDE.deserialize(ByteBuffer.wrap(resp.getForecastResult()));
if (res.getValueColumnCount() != inputData.getValueColumnCount()) {
throw new IoTDBRuntimeException(
String.format(
"Model %s output %s columns, doesn't equal to specified %s",
modelId, res.getValueColumnCount(), inputData.getValueColumnCount()),
TSStatusCode.INTERNAL_SERVER_ERROR.getStatusCode());
}
return res;
return SERDE.deserialize(ByteBuffer.wrap(resp.getForecastResult()));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,8 @@ public void beforeStart(UDFParameters parameters, UDTFConfigurations configurati
Arrays.stream(
parameters.getStringOrDefault(OPTIONS_PARAMETER_NAME, DEFAULT_OPTIONS).split(","))
.map(s -> s.split("="))
.filter(arr -> arr.length == 2 && !arr[0].isEmpty()) // 防御性检查
.collect(
Collectors.toMap(
arr -> arr[0].trim(), arr -> arr[1].trim(), (v1, v2) -> v2 // 如果 key 重复,保留后一个
));
.filter(arr -> arr.length == 2 && !arr[0].isEmpty())
.collect(Collectors.toMap(arr -> arr[0].trim(), arr -> arr[1].trim(), (v1, v2) -> v2));
this.inputRows = new LinkedList<>();
List<TSDataType> tsDataTypeList = new ArrayList<>(this.types.size() - 1);
for (int i = 0; i < this.types.size(); i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ public void testForecastFunction() {
DEFAULT_OUTPUT_START_TIME,
DEFAULT_OUTPUT_INTERVAL,
new TEndPoint("127.0.0.1", 10810),
Collections.singletonList(DOUBLE),
Collections.singletonList(DOUBLE)));
// Verify full LogicalPlan
// Output - TableFunctionProcessor - TableScan
Expand Down Expand Up @@ -440,6 +441,7 @@ public void testForecastFunctionWithNoLowerCase() {
DEFAULT_OUTPUT_START_TIME,
DEFAULT_OUTPUT_INTERVAL,
new TEndPoint("127.0.0.1", 10810),
Collections.singletonList(DOUBLE),
Collections.singletonList(DOUBLE)));
// Verify full LogicalPlan
// Output - TableFunctionProcessor - TableScan
Expand Down
Loading