Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion iotdb-core/ainode/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
/iotdb/thrift/
/iotdb/tsfile/
/iotdb/utils/
/iotdb/__init__.py
/iotdb/Session.py
/iotdb/SessionPool.py
/iotdb/table_session.py
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ public static class ForecastTableFunctionHandle implements TableFunctionHandle {
long outputInterval;
boolean keepInput;
Map<String, String> options;
List<Type> types;
List<Type> inputColumnTypes;
List<Type> predicatedColumnTypes;

public ForecastTableFunctionHandle() {}

Expand All @@ -97,7 +98,8 @@ public ForecastTableFunctionHandle(
long outputStartTime,
long outputInterval,
TEndPoint targetAINode,
List<Type> types) {
List<Type> inputColumnTypes,
List<Type> predicatedColumnTypes) {
this.keepInput = keepInput;
this.maxInputLength = maxInputLength;
this.modelId = modelId;
Expand All @@ -106,7 +108,8 @@ public ForecastTableFunctionHandle(
this.outputStartTime = outputStartTime;
this.outputInterval = outputInterval;
this.targetAINode = targetAINode;
this.types = types;
this.inputColumnTypes = inputColumnTypes;
this.predicatedColumnTypes = predicatedColumnTypes;
}

@Override
Expand All @@ -122,8 +125,12 @@ public byte[] serialize() {
ReadWriteIOUtils.write(outputInterval, outputStream);
ReadWriteIOUtils.write(keepInput, outputStream);
ReadWriteIOUtils.write(options, outputStream);
ReadWriteIOUtils.write(types.size(), outputStream);
for (Type type : types) {
ReadWriteIOUtils.write(inputColumnTypes.size(), outputStream);
for (Type type : inputColumnTypes) {
ReadWriteIOUtils.write(type.getType(), outputStream);
}
ReadWriteIOUtils.write(predicatedColumnTypes.size(), outputStream);
for (Type type : predicatedColumnTypes) {
ReadWriteIOUtils.write(type.getType(), outputStream);
}
outputStream.flush();
Expand All @@ -149,9 +156,14 @@ public void deserialize(byte[] bytes) {
this.keepInput = ReadWriteIOUtils.readBoolean(buffer);
this.options = ReadWriteIOUtils.readMap(buffer);
int size = ReadWriteIOUtils.readInt(buffer);
this.types = new ArrayList<>(size);
this.inputColumnTypes = new ArrayList<>(size);
for (int i = 0; i < size; i++) {
inputColumnTypes.add(Type.valueOf(ReadWriteIOUtils.readString(buffer)));
}
size = ReadWriteIOUtils.readInt(buffer);
this.predicatedColumnTypes = new ArrayList<>(size);
for (int i = 0; i < size; i++) {
types.add(Type.valueOf(ReadWriteIOUtils.readByte(buffer)));
predicatedColumnTypes.add(Type.valueOf(ReadWriteIOUtils.readByte(buffer)));
}
}

Expand All @@ -172,7 +184,8 @@ public boolean equals(Object o) {
&& Objects.equals(targetAINode, that.targetAINode)
&& Objects.equals(modelId, that.modelId)
&& Objects.equals(options, that.options)
&& Objects.equals(types, that.types);
&& Objects.equals(inputColumnTypes, that.inputColumnTypes)
&& Objects.equals(predicatedColumnTypes, that.predicatedColumnTypes);
}

@Override
Expand All @@ -186,7 +199,8 @@ public int hashCode() {
outputInterval,
keepInput,
options,
types);
inputColumnTypes,
predicatedColumnTypes);
}
}

Expand Down Expand Up @@ -319,20 +333,29 @@ public TableFunctionAnalysis analyze(Map<String, Argument> arguments) {
DescribedSchema.Builder properColumnSchemaBuilder =
new DescribedSchema.Builder().addField(timeColumn, Type.TIMESTAMP);

List<Type> inputColumnTypes = new ArrayList<>();
List<Type> predicatedColumnTypes = new ArrayList<>();
List<Optional<String>> allInputColumnsName = input.getFieldNames();
List<Type> allInputColumnsType = input.getFieldTypes();
for (int i = 0, size = allInputColumnsName.size(); i < size; i++) {
Optional<String> fieldName = allInputColumnsName.get(i);
// All input value columns are required for model forecasting
if (!fieldName.isPresent()
|| !excludedColumns.contains(fieldName.get().toLowerCase(Locale.ENGLISH))) {
inputColumnTypes.add(allInputColumnsType.get(i));
requiredIndexList.add(i);
}
}
if (predicatedColumns.isEmpty()) {
// predicated columns by default include all columns from input table except for timecol and
// partition by columns
// predicated columns by default include all columns from input table except for
// timecol and partition by columns
for (int i = 0, size = allInputColumnsName.size(); i < size; i++) {
Optional<String> fieldName = allInputColumnsName.get(i);
if (!fieldName.isPresent()
|| !excludedColumns.contains(fieldName.get().toLowerCase(Locale.ENGLISH))) {
Type columnType = allInputColumnsType.get(i);
predicatedColumnTypes.add(columnType);
checkType(columnType, fieldName.orElse(""));
requiredIndexList.add(i);
properColumnSchemaBuilder.addField(fieldName, columnType);
}
}
Expand All @@ -347,7 +370,7 @@ public TableFunctionAnalysis analyze(Map<String, Argument> arguments) {
inputColumnIndexMap.put(fieldName.get().toLowerCase(Locale.ENGLISH), i);
}

Set<Integer> requiredIndexSet = new HashSet<>(predictedColumnsArray.length);
Set<Integer> predicatedIndexSet = new HashSet<>(predictedColumnsArray.length);
// columns need to be predicated
for (String outputColumn : predictedColumnsArray) {
String lowerCaseOutputColumn = outputColumn.toLowerCase(Locale.ENGLISH);
Expand All @@ -360,14 +383,13 @@ public TableFunctionAnalysis analyze(Map<String, Argument> arguments) {
throw new SemanticException(
String.format("Column %s don't exist in input", outputColumn));
}
if (!requiredIndexSet.add(inputColumnIndex)) {
if (!predicatedIndexSet.add(inputColumnIndex)) {
throw new SemanticException(String.format("Duplicate column %s", outputColumn));
}

Type columnType = allInputColumnsType.get(inputColumnIndex);
predicatedColumnTypes.add(columnType);
checkType(columnType, outputColumn);
requiredIndexList.add(inputColumnIndex);
properColumnSchemaBuilder.addField(outputColumn, columnType);
}
}
Expand All @@ -392,6 +414,7 @@ public TableFunctionAnalysis analyze(Map<String, Argument> arguments) {
outputStartTime,
outputInterval,
targetAINode,
inputColumnTypes,
predicatedColumnTypes);

// outputColumnSchema
Expand Down Expand Up @@ -469,8 +492,9 @@ private static class ForecastDataProcessor implements TableFunctionDataProcessor
private final boolean keepInput;
private final Map<String, String> options;
private final LinkedList<Record> inputRecords;
private final List<ResultColumnAppender> resultColumnAppenderList;
private final TsBlockBuilder inputTsBlockBuilder;
private final List<ResultColumnAppender> inputColumnAppenderList;
private final List<ResultColumnAppender> resultColumnAppenderList;

public ForecastDataProcessor(ForecastTableFunctionHandle functionHandle) {
this.targetAINode = functionHandle.targetAINode;
Expand All @@ -482,14 +506,22 @@ public ForecastDataProcessor(ForecastTableFunctionHandle functionHandle) {
this.keepInput = functionHandle.keepInput;
this.options = functionHandle.options;
this.inputRecords = new LinkedList<>();
this.resultColumnAppenderList = new ArrayList<>(functionHandle.types.size());
List<TSDataType> tsDataTypeList = new ArrayList<>(functionHandle.types.size());
for (Type type : functionHandle.types) {
List<TSDataType> inputTsDataTypeList =
new ArrayList<>(functionHandle.inputColumnTypes.size());
for (Type type : functionHandle.inputColumnTypes) {
// AINode currently only accept double input
inputTsDataTypeList.add(TSDataType.DOUBLE);
}
this.inputTsBlockBuilder = new TsBlockBuilder(inputTsDataTypeList);
this.inputColumnAppenderList = new ArrayList<>(functionHandle.inputColumnTypes.size());
for (Type type : functionHandle.inputColumnTypes) {
Comment on lines +511 to +517
Copy link

Copilot AI Nov 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Variable 'Type type' is never read.

Suggested change
for (Type type : functionHandle.inputColumnTypes) {
// AINode currently only accept double input
inputTsDataTypeList.add(TSDataType.DOUBLE);
}
this.inputTsBlockBuilder = new TsBlockBuilder(inputTsDataTypeList);
this.inputColumnAppenderList = new ArrayList<>(functionHandle.inputColumnTypes.size());
for (Type type : functionHandle.inputColumnTypes) {
for (int i = 0; i < functionHandle.inputColumnTypes.size(); i++) {
// AINode currently only accept double input
inputTsDataTypeList.add(TSDataType.DOUBLE);
}
this.inputTsBlockBuilder = new TsBlockBuilder(inputTsDataTypeList);
this.inputColumnAppenderList = new ArrayList<>(functionHandle.inputColumnTypes.size());
for (int i = 0; i < functionHandle.inputColumnTypes.size(); i++) {

Copilot uses AI. Check for mistakes.
Copy link

Copilot AI Nov 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Variable 'Type type' is never read.

Suggested change
for (Type type : functionHandle.inputColumnTypes) {
for (int i = 0; i < functionHandle.inputColumnTypes.size(); i++) {

Copilot uses AI. Check for mistakes.
// AINode currently only accept double input
inputColumnAppenderList.add(createResultColumnAppender(Type.DOUBLE));
}
this.resultColumnAppenderList = new ArrayList<>(functionHandle.predicatedColumnTypes.size());
for (Type type : functionHandle.predicatedColumnTypes) {
resultColumnAppenderList.add(createResultColumnAppender(type));
// ainode currently only accept double input
tsDataTypeList.add(TSDataType.DOUBLE);
}
this.inputTsBlockBuilder = new TsBlockBuilder(tsDataTypeList);
}

private static ResultColumnAppender createResultColumnAppender(Type type) {
Expand Down Expand Up @@ -613,7 +645,7 @@ private TsBlock forecast() {
// need to transform other types to DOUBLE
inputTsBlockBuilder
.getColumnBuilder(i - 1)
.writeDouble(resultColumnAppenderList.get(i - 1).getDouble(row, i));
.writeDouble(inputColumnAppenderList.get(i - 1).getDouble(row, i));
}
}
inputTsBlockBuilder.declarePosition();
Expand All @@ -634,15 +666,7 @@ private TsBlock forecast() {
throw new IoTDBRuntimeException(message, resp.getStatus().getCode());
}

TsBlock res = SERDE.deserialize(ByteBuffer.wrap(resp.getForecastResult()));
if (res.getValueColumnCount() != inputData.getValueColumnCount()) {
throw new IoTDBRuntimeException(
String.format(
"Model %s output %s columns, doesn't equal to specified %s",
modelId, res.getValueColumnCount(), inputData.getValueColumnCount()),
TSStatusCode.INTERNAL_SERVER_ERROR.getStatusCode());
}
return res;
return SERDE.deserialize(ByteBuffer.wrap(resp.getForecastResult()));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,8 @@ public void beforeStart(UDFParameters parameters, UDTFConfigurations configurati
Arrays.stream(
parameters.getStringOrDefault(OPTIONS_PARAMETER_NAME, DEFAULT_OPTIONS).split(","))
.map(s -> s.split("="))
.filter(arr -> arr.length == 2 && !arr[0].isEmpty()) // 防御性检查
.collect(
Collectors.toMap(
arr -> arr[0].trim(), arr -> arr[1].trim(), (v1, v2) -> v2 // 如果 key 重复,保留后一个
));
.filter(arr -> arr.length == 2 && !arr[0].isEmpty())
.collect(Collectors.toMap(arr -> arr[0].trim(), arr -> arr[1].trim(), (v1, v2) -> v2));
this.inputRows = new LinkedList<>();
List<TSDataType> tsDataTypeList = new ArrayList<>(this.types.size() - 1);
for (int i = 0; i < this.types.size(); i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ public void testForecastFunction() {
DEFAULT_OUTPUT_START_TIME,
DEFAULT_OUTPUT_INTERVAL,
new TEndPoint("127.0.0.1", 10810),
Collections.singletonList(DOUBLE),
Collections.singletonList(DOUBLE)));
// Verify full LogicalPlan
// Output - TableFunctionProcessor - TableScan
Expand Down Expand Up @@ -440,6 +441,7 @@ public void testForecastFunctionWithNoLowerCase() {
DEFAULT_OUTPUT_START_TIME,
DEFAULT_OUTPUT_INTERVAL,
new TEndPoint("127.0.0.1", 10810),
Collections.singletonList(DOUBLE),
Collections.singletonList(DOUBLE)));
// Verify full LogicalPlan
// Output - TableFunctionProcessor - TableScan
Expand Down
Loading