diff --git a/iotdb-core/ainode/.gitignore b/iotdb-core/ainode/.gitignore index 8cc2098c3fd8..80221b44dc5a 100644 --- a/iotdb-core/ainode/.gitignore +++ b/iotdb-core/ainode/.gitignore @@ -5,7 +5,6 @@ /iotdb/thrift/ /iotdb/tsfile/ /iotdb/utils/ -/iotdb/__init__.py /iotdb/Session.py /iotdb/SessionPool.py /iotdb/table_session.py diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java index 5521062a24e1..442359a8c94f 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java @@ -84,7 +84,8 @@ public static class ForecastTableFunctionHandle implements TableFunctionHandle { long outputInterval; boolean keepInput; Map options; - List types; + List inputColumnTypes; + List predicatedColumnTypes; public ForecastTableFunctionHandle() {} @@ -97,7 +98,8 @@ public ForecastTableFunctionHandle( long outputStartTime, long outputInterval, TEndPoint targetAINode, - List types) { + List inputColumnTypes, + List predicatedColumnTypes) { this.keepInput = keepInput; this.maxInputLength = maxInputLength; this.modelId = modelId; @@ -106,7 +108,8 @@ public ForecastTableFunctionHandle( this.outputStartTime = outputStartTime; this.outputInterval = outputInterval; this.targetAINode = targetAINode; - this.types = types; + this.inputColumnTypes = inputColumnTypes; + this.predicatedColumnTypes = predicatedColumnTypes; } @Override @@ -122,8 +125,12 @@ public byte[] serialize() { ReadWriteIOUtils.write(outputInterval, outputStream); ReadWriteIOUtils.write(keepInput, outputStream); ReadWriteIOUtils.write(options, outputStream); - ReadWriteIOUtils.write(types.size(), outputStream); - for (Type type : types) { + ReadWriteIOUtils.write(inputColumnTypes.size(), outputStream); + for (Type type : inputColumnTypes) { + ReadWriteIOUtils.write(type.getType(), outputStream); + } + ReadWriteIOUtils.write(predicatedColumnTypes.size(), outputStream); + for (Type type : predicatedColumnTypes) { ReadWriteIOUtils.write(type.getType(), outputStream); } outputStream.flush(); @@ -149,9 +156,14 @@ public void deserialize(byte[] bytes) { this.keepInput = ReadWriteIOUtils.readBoolean(buffer); this.options = ReadWriteIOUtils.readMap(buffer); int size = ReadWriteIOUtils.readInt(buffer); - this.types = new ArrayList<>(size); + this.inputColumnTypes = new ArrayList<>(size); + for (int i = 0; i < size; i++) { + inputColumnTypes.add(Type.valueOf(ReadWriteIOUtils.readByte(buffer))); + } + size = ReadWriteIOUtils.readInt(buffer); + this.predicatedColumnTypes = new ArrayList<>(size); for (int i = 0; i < size; i++) { - types.add(Type.valueOf(ReadWriteIOUtils.readByte(buffer))); + predicatedColumnTypes.add(Type.valueOf(ReadWriteIOUtils.readByte(buffer))); } } @@ -172,7 +184,8 @@ public boolean equals(Object o) { && Objects.equals(targetAINode, that.targetAINode) && Objects.equals(modelId, that.modelId) && Objects.equals(options, that.options) - && Objects.equals(types, that.types); + && Objects.equals(inputColumnTypes, that.inputColumnTypes) + && Objects.equals(predicatedColumnTypes, that.predicatedColumnTypes); } @Override @@ -186,7 +199,8 @@ public int hashCode() { outputInterval, keepInput, options, - types); + inputColumnTypes, + predicatedColumnTypes); } } @@ -319,12 +333,22 @@ public TableFunctionAnalysis analyze(Map arguments) { DescribedSchema.Builder properColumnSchemaBuilder = new DescribedSchema.Builder().addField(timeColumn, Type.TIMESTAMP); + List inputColumnTypes = new ArrayList<>(); List predicatedColumnTypes = new ArrayList<>(); List> allInputColumnsName = input.getFieldNames(); List allInputColumnsType = input.getFieldTypes(); + for (int i = 0, size = allInputColumnsName.size(); i < size; i++) { + Optional fieldName = allInputColumnsName.get(i); + // All input value columns are required for model forecasting + if (!fieldName.isPresent() + || !excludedColumns.contains(fieldName.get().toLowerCase(Locale.ENGLISH))) { + inputColumnTypes.add(allInputColumnsType.get(i)); + requiredIndexList.add(i); + } + } if (predicatedColumns.isEmpty()) { - // predicated columns by default include all columns from input table except for timecol and - // partition by columns + // predicated columns by default include all columns from input table except for + // timecol and partition by columns for (int i = 0, size = allInputColumnsName.size(); i < size; i++) { Optional fieldName = allInputColumnsName.get(i); if (!fieldName.isPresent() @@ -332,7 +356,6 @@ public TableFunctionAnalysis analyze(Map arguments) { Type columnType = allInputColumnsType.get(i); predicatedColumnTypes.add(columnType); checkType(columnType, fieldName.orElse("")); - requiredIndexList.add(i); properColumnSchemaBuilder.addField(fieldName, columnType); } } @@ -347,7 +370,7 @@ public TableFunctionAnalysis analyze(Map arguments) { inputColumnIndexMap.put(fieldName.get().toLowerCase(Locale.ENGLISH), i); } - Set requiredIndexSet = new HashSet<>(predictedColumnsArray.length); + Set predicatedIndexSet = new HashSet<>(predictedColumnsArray.length); // columns need to be predicated for (String outputColumn : predictedColumnsArray) { String lowerCaseOutputColumn = outputColumn.toLowerCase(Locale.ENGLISH); @@ -360,14 +383,13 @@ public TableFunctionAnalysis analyze(Map arguments) { throw new SemanticException( String.format("Column %s don't exist in input", outputColumn)); } - if (!requiredIndexSet.add(inputColumnIndex)) { + if (!predicatedIndexSet.add(inputColumnIndex)) { throw new SemanticException(String.format("Duplicate column %s", outputColumn)); } Type columnType = allInputColumnsType.get(inputColumnIndex); predicatedColumnTypes.add(columnType); checkType(columnType, outputColumn); - requiredIndexList.add(inputColumnIndex); properColumnSchemaBuilder.addField(outputColumn, columnType); } } @@ -392,6 +414,7 @@ public TableFunctionAnalysis analyze(Map arguments) { outputStartTime, outputInterval, targetAINode, + inputColumnTypes, predicatedColumnTypes); // outputColumnSchema @@ -469,8 +492,9 @@ private static class ForecastDataProcessor implements TableFunctionDataProcessor private final boolean keepInput; private final Map options; private final LinkedList inputRecords; - private final List resultColumnAppenderList; private final TsBlockBuilder inputTsBlockBuilder; + private final List inputColumnAppenderList; + private final List resultColumnAppenderList; public ForecastDataProcessor(ForecastTableFunctionHandle functionHandle) { this.targetAINode = functionHandle.targetAINode; @@ -482,14 +506,22 @@ public ForecastDataProcessor(ForecastTableFunctionHandle functionHandle) { this.keepInput = functionHandle.keepInput; this.options = functionHandle.options; this.inputRecords = new LinkedList<>(); - this.resultColumnAppenderList = new ArrayList<>(functionHandle.types.size()); - List tsDataTypeList = new ArrayList<>(functionHandle.types.size()); - for (Type type : functionHandle.types) { + List inputTsDataTypeList = + new ArrayList<>(functionHandle.inputColumnTypes.size()); + for (Type type : functionHandle.inputColumnTypes) { + // AINode currently only accept double input + inputTsDataTypeList.add(TSDataType.DOUBLE); + } + this.inputTsBlockBuilder = new TsBlockBuilder(inputTsDataTypeList); + this.inputColumnAppenderList = new ArrayList<>(functionHandle.inputColumnTypes.size()); + for (Type type : functionHandle.inputColumnTypes) { + // AINode currently only accept double input + inputColumnAppenderList.add(createResultColumnAppender(Type.DOUBLE)); + } + this.resultColumnAppenderList = new ArrayList<>(functionHandle.predicatedColumnTypes.size()); + for (Type type : functionHandle.predicatedColumnTypes) { resultColumnAppenderList.add(createResultColumnAppender(type)); - // ainode currently only accept double input - tsDataTypeList.add(TSDataType.DOUBLE); } - this.inputTsBlockBuilder = new TsBlockBuilder(tsDataTypeList); } private static ResultColumnAppender createResultColumnAppender(Type type) { @@ -613,7 +645,7 @@ private TsBlock forecast() { // need to transform other types to DOUBLE inputTsBlockBuilder .getColumnBuilder(i - 1) - .writeDouble(resultColumnAppenderList.get(i - 1).getDouble(row, i)); + .writeDouble(inputColumnAppenderList.get(i - 1).getDouble(row, i)); } } inputTsBlockBuilder.declarePosition(); @@ -634,15 +666,7 @@ private TsBlock forecast() { throw new IoTDBRuntimeException(message, resp.getStatus().getCode()); } - TsBlock res = SERDE.deserialize(ByteBuffer.wrap(resp.getForecastResult())); - if (res.getValueColumnCount() != inputData.getValueColumnCount()) { - throw new IoTDBRuntimeException( - String.format( - "Model %s output %s columns, doesn't equal to specified %s", - modelId, res.getValueColumnCount(), inputData.getValueColumnCount()), - TSStatusCode.INTERNAL_SERVER_ERROR.getStatusCode()); - } - return res; + return SERDE.deserialize(ByteBuffer.wrap(resp.getForecastResult())); } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java index 22c2bce7b5ee..f1e387c0ff66 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java @@ -128,11 +128,8 @@ public void beforeStart(UDFParameters parameters, UDTFConfigurations configurati Arrays.stream( parameters.getStringOrDefault(OPTIONS_PARAMETER_NAME, DEFAULT_OPTIONS).split(",")) .map(s -> s.split("=")) - .filter(arr -> arr.length == 2 && !arr[0].isEmpty()) // 防御性检查 - .collect( - Collectors.toMap( - arr -> arr[0].trim(), arr -> arr[1].trim(), (v1, v2) -> v2 // 如果 key 重复,保留后一个 - )); + .filter(arr -> arr.length == 2 && !arr[0].isEmpty()) + .collect(Collectors.toMap(arr -> arr[0].trim(), arr -> arr[1].trim(), (v1, v2) -> v2)); this.inputRows = new LinkedList<>(); List tsDataTypeList = new ArrayList<>(this.types.size() - 1); for (int i = 0; i < this.types.size(); i++) { diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TableFunctionTest.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TableFunctionTest.java index e56b48936b96..0e8ffb484f11 100644 --- a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TableFunctionTest.java +++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TableFunctionTest.java @@ -379,6 +379,7 @@ public void testForecastFunction() { DEFAULT_OUTPUT_START_TIME, DEFAULT_OUTPUT_INTERVAL, new TEndPoint("127.0.0.1", 10810), + Collections.singletonList(DOUBLE), Collections.singletonList(DOUBLE))); // Verify full LogicalPlan // Output - TableFunctionProcessor - TableScan @@ -440,6 +441,7 @@ public void testForecastFunctionWithNoLowerCase() { DEFAULT_OUTPUT_START_TIME, DEFAULT_OUTPUT_INTERVAL, new TEndPoint("127.0.0.1", 10810), + Collections.singletonList(DOUBLE), Collections.singletonList(DOUBLE))); // Verify full LogicalPlan // Output - TableFunctionProcessor - TableScan