Skip to content

Commit 4da262e

Browse files
yunbow30944 authored and CRZbulabula committed
[AINode] Fix bug of sundial and forecast udf (#16768)
(cherry picked from commit 2b47be7)
1 parent 26f0628 commit 4da262e

File tree

3 files changed

+279
-4
lines changed

3 files changed

+279
-4
lines changed

iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -645,9 +645,10 @@ def prepare_inputs_for_generation(
645645
position_ids = attention_mask.long().cumsum(-1) - 1
646646
position_ids.masked_fill_(attention_mask == 0, 1)
647647
if past_key_values:
648-
position_ids = position_ids[
649-
:, -(input_ids.shape[1] // self.config.input_token_len) :
650-
]
648+
token_num = (
649+
input_ids.shape[1] + self.config.input_token_len - 1
650+
) // self.config.input_token_len
651+
position_ids = position_ids[:, -token_num:]
651652

652653
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
653654
if inputs_embeds is not None and past_key_values is None:

iotdb-core/ainode/iotdb/ainode/core/model/timerxl/modeling_timer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -613,7 +613,11 @@ def prepare_inputs_for_generation(
613613
if attention_mask is not None and attention_mask.shape[1] > (
614614
input_ids.shape[1] // self.config.input_token_len
615615
):
616-
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
616+
input_ids = input_ids[
617+
:,
618+
-(attention_mask.shape[1] - past_length)
619+
* self.config.input_token_len :,
620+
]
617621
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
618622
# input_ids based on the past_length.
619623
elif past_length < (input_ids.shape[1] // self.config.input_token_len):
Lines changed: 270 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,270 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.iotdb.db.queryengine.plan.udf;
21+
22+
import org.apache.iotdb.ainode.rpc.thrift.TForecastResp;
23+
import org.apache.iotdb.common.rpc.thrift.TEndPoint;
24+
import org.apache.iotdb.commons.exception.IoTDBRuntimeException;
25+
import org.apache.iotdb.db.protocol.client.ainode.AINodeClient;
26+
import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager;
27+
import org.apache.iotdb.db.queryengine.plan.analyze.IModelFetcher;
28+
import org.apache.iotdb.db.queryengine.plan.analyze.ModelFetcher;
29+
import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model.ModelInferenceDescriptor;
30+
import org.apache.iotdb.rpc.TSStatusCode;
31+
import org.apache.iotdb.udf.api.UDTF;
32+
import org.apache.iotdb.udf.api.access.Row;
33+
import org.apache.iotdb.udf.api.collector.PointCollector;
34+
import org.apache.iotdb.udf.api.customizer.config.UDTFConfigurations;
35+
import org.apache.iotdb.udf.api.customizer.parameter.UDFParameters;
36+
import org.apache.iotdb.udf.api.customizer.strategy.RowByRowAccessStrategy;
37+
import org.apache.iotdb.udf.api.type.Type;
38+
39+
import org.apache.tsfile.enums.TSDataType;
40+
import org.apache.tsfile.read.common.block.TsBlock;
41+
import org.apache.tsfile.read.common.block.TsBlockBuilder;
42+
import org.apache.tsfile.read.common.block.column.TsBlockSerde;
43+
44+
import java.io.IOException;
45+
import java.nio.ByteBuffer;
46+
import java.util.ArrayList;
47+
import java.util.Arrays;
48+
import java.util.HashSet;
49+
import java.util.LinkedList;
50+
import java.util.List;
51+
import java.util.Map;
52+
import java.util.Set;
53+
import java.util.stream.Collectors;
54+
55+
public class UDTFForecast implements UDTF {
56+
private static final TsBlockSerde serde = new TsBlockSerde();
57+
private static final AINodeClientManager CLIENT_MANAGER = AINodeClientManager.getInstance();
58+
private TEndPoint targetAINode = new TEndPoint("127.0.0.1", 10810);
59+
private String model_id;
60+
private int maxInputLength;
61+
private int outputLength;
62+
private long outputStartTime;
63+
private long outputInterval;
64+
private boolean keepInput;
65+
Map<String, String> options;
66+
List<Type> types;
67+
private LinkedList<Row> inputRows;
68+
private TsBlockBuilder inputTsBlockBuilder;
69+
private final IModelFetcher modelFetcher = ModelFetcher.getInstance();
70+
71+
private static final Set<Type> ALLOWED_INPUT_TYPES = new HashSet<>();
72+
73+
static {
74+
ALLOWED_INPUT_TYPES.add(Type.INT32);
75+
ALLOWED_INPUT_TYPES.add(Type.INT64);
76+
ALLOWED_INPUT_TYPES.add(Type.FLOAT);
77+
ALLOWED_INPUT_TYPES.add(Type.DOUBLE);
78+
}
79+
80+
private static final String MODEL_ID_PARAMETER_NAME = "MODEL_ID";
81+
private static final String OUTPUT_LENGTH_PARAMETER_NAME = "OUTPUT_LENGTH";
82+
private static final int DEFAULT_OUTPUT_LENGTH = 96;
83+
private static final String OUTPUT_START_TIME = "OUTPUT_START_TIME";
84+
public static final long DEFAULT_OUTPUT_START_TIME = Long.MIN_VALUE;
85+
private static final String OUTPUT_INTERVAL = "OUTPUT_INTERVAL";
86+
public static final long DEFAULT_OUTPUT_INTERVAL = 0L;
87+
private static final String KEEP_INPUT_PARAMETER_NAME = "PRESERVE_INPUT";
88+
private static final Boolean DEFAULT_KEEP_INPUT = Boolean.FALSE;
89+
private static final String OPTIONS_PARAMETER_NAME = "MODEL_OPTIONS";
90+
private static final String DEFAULT_OPTIONS = "";
91+
92+
private void checkType() {
93+
for (Type type : this.types) {
94+
if (!ALLOWED_INPUT_TYPES.contains(type)) {
95+
throw new IllegalArgumentException(
96+
String.format(
97+
"Input data type %s is not supported, only %s are allowed.",
98+
type, ALLOWED_INPUT_TYPES));
99+
}
100+
}
101+
}
102+
103+
@Override
104+
public void beforeStart(UDFParameters parameters, UDTFConfigurations configurations)
105+
throws Exception {
106+
this.types = parameters.getDataTypes();
107+
checkType();
108+
configurations.setAccessStrategy(new RowByRowAccessStrategy()).setOutputDataType(Type.DOUBLE);
109+
110+
this.model_id = parameters.getString(MODEL_ID_PARAMETER_NAME);
111+
if (this.model_id == null || this.model_id.isEmpty()) {
112+
throw new IllegalArgumentException(
113+
"MODEL_ID parameter must be provided and cannot be empty.");
114+
}
115+
ModelInferenceDescriptor descriptor = modelFetcher.fetchModel(this.model_id);
116+
this.targetAINode = descriptor.getTargetAINode();
117+
118+
this.outputInterval = parameters.getLongOrDefault(OUTPUT_INTERVAL, DEFAULT_OUTPUT_INTERVAL);
119+
this.outputLength =
120+
parameters.getIntOrDefault(OUTPUT_LENGTH_PARAMETER_NAME, DEFAULT_OUTPUT_LENGTH);
121+
this.outputStartTime =
122+
parameters.getLongOrDefault(OUTPUT_START_TIME, DEFAULT_OUTPUT_START_TIME);
123+
this.keepInput = parameters.getBooleanOrDefault(KEEP_INPUT_PARAMETER_NAME, DEFAULT_KEEP_INPUT);
124+
this.options =
125+
Arrays.stream(
126+
parameters.getStringOrDefault(OPTIONS_PARAMETER_NAME, DEFAULT_OPTIONS).split(","))
127+
.map(s -> s.split("="))
128+
.filter(arr -> arr.length == 2 && !arr[0].isEmpty()) // 防御性检查
129+
.collect(
130+
Collectors.toMap(
131+
arr -> arr[0].trim(), arr -> arr[1].trim(), (v1, v2) -> v2 // 如果 key 重复,保留后一个
132+
));
133+
this.inputRows = new LinkedList<>();
134+
List<TSDataType> tsDataTypeList = new ArrayList<>(this.types.size() - 1);
135+
for (int i = 0; i < this.types.size(); i++) {
136+
tsDataTypeList.add(TSDataType.DOUBLE);
137+
}
138+
this.inputTsBlockBuilder = new TsBlockBuilder(tsDataTypeList);
139+
}
140+
141+
private void setByType(Row row, PointCollector collector) throws IOException {
142+
for (int i = 0; i < row.size(); i++) {
143+
switch (this.types.get(i)) {
144+
case INT32:
145+
collector.putInt(row.getTime(), row.getInt(i));
146+
break;
147+
case INT64:
148+
collector.putLong(row.getTime(), row.getLong(i));
149+
break;
150+
case FLOAT:
151+
collector.putFloat(row.getTime(), row.getFloat(i));
152+
break;
153+
case DOUBLE:
154+
collector.putDouble(row.getTime(), row.getDouble(i));
155+
break;
156+
default:
157+
throw new IllegalArgumentException(
158+
String.format("Unsupported data type %s", this.types.get(i + 1)));
159+
}
160+
}
161+
}
162+
163+
private void setByType(Row row, TsBlockBuilder tsBlockBuilder) throws IOException {
164+
for (int i = 0; i < row.size(); i++) {
165+
if (row.isNull(i)) {
166+
tsBlockBuilder.getColumnBuilder(i).appendNull();
167+
continue;
168+
}
169+
switch (this.types.get(i)) {
170+
case INT32:
171+
tsBlockBuilder.getColumnBuilder(i).writeInt(row.getInt(i));
172+
break;
173+
case INT64:
174+
tsBlockBuilder.getColumnBuilder(i).writeLong(row.getLong(i));
175+
break;
176+
case FLOAT:
177+
tsBlockBuilder.getColumnBuilder(i).writeFloat(row.getFloat(i));
178+
break;
179+
case DOUBLE:
180+
tsBlockBuilder.getColumnBuilder(i).writeDouble(row.getDouble(i));
181+
break;
182+
default:
183+
throw new IllegalArgumentException(
184+
String.format("Unsupported data type %s", this.types.get(i + 1)));
185+
}
186+
}
187+
}
188+
189+
@Override
190+
public void transform(Row row, PointCollector collector) throws Exception {
191+
if (this.keepInput) {
192+
setByType(row, collector);
193+
}
194+
195+
if (maxInputLength != 0 && inputRows.size() >= maxInputLength) {
196+
// If the input rows exceed the maximum length, remove the oldest row
197+
inputRows.removeFirst();
198+
}
199+
inputRows.add(row);
200+
}
201+
202+
private TsBlock forecast() throws Exception {
203+
// Build the input data which will be sent to AINode
204+
while (!inputRows.isEmpty()) {
205+
Row row = inputRows.removeFirst();
206+
inputTsBlockBuilder.getTimeColumnBuilder().writeLong(row.getTime());
207+
setByType(row, inputTsBlockBuilder);
208+
inputTsBlockBuilder.declarePosition();
209+
}
210+
211+
TsBlock inputData = inputTsBlockBuilder.build();
212+
213+
TForecastResp resp;
214+
try (AINodeClient client = CLIENT_MANAGER.borrowClient(targetAINode)) {
215+
resp = client.forecast(model_id, inputData, outputLength, options);
216+
} catch (Exception e) {
217+
throw new IoTDBRuntimeException(
218+
e.getMessage(), TSStatusCode.CAN_NOT_CONNECT_AINODE.getStatusCode());
219+
}
220+
221+
if (resp.getStatus().getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
222+
throw new IoTDBRuntimeException(
223+
String.format(
224+
"Forecast failed due to %d %s",
225+
resp.getStatus().getCode(), resp.getStatus().getMessage()),
226+
resp.getStatus().getCode());
227+
}
228+
return serde.deserialize(ByteBuffer.wrap(resp.getForecastResult()));
229+
}
230+
231+
@Override
232+
public void terminate(PointCollector collector) throws Exception {
233+
long inputStartTime = inputRows.get(0).getTime();
234+
long inputEndTime = inputRows.get(inputRows.size() - 1).getTime();
235+
if (inputStartTime > inputEndTime) {
236+
throw new IllegalArgumentException(
237+
String.format(
238+
"input end time should never less than start time, start time is %s, end time is %s",
239+
inputStartTime, inputEndTime));
240+
}
241+
long interval = this.outputInterval;
242+
if (outputInterval <= 0) {
243+
interval = (inputEndTime - inputStartTime) / (inputRows.size() - 1);
244+
}
245+
long outputTime =
246+
(this.outputStartTime == Long.MIN_VALUE) ? inputEndTime + interval : this.outputStartTime;
247+
long[] outputTimes = new long[this.outputLength];
248+
for (int i = 0; i < this.outputLength; i++) {
249+
outputTimes[i] = outputTime + interval * i;
250+
}
251+
252+
TsBlock forecastResult = forecast();
253+
if (forecastResult.getPositionCount() != this.outputLength) {
254+
throw new IllegalArgumentException(
255+
String.format(
256+
"The forecast result length %d does not match the expected output length %d",
257+
forecastResult.getPositionCount(), this.outputLength));
258+
}
259+
if (forecastResult.getValueColumnCount() != 1) {
260+
throw new IllegalArgumentException(
261+
String.format(
262+
"The forecast result should have only one value column, but got %d",
263+
forecastResult.getValueColumnCount()));
264+
}
265+
266+
for (int i = 0; i < forecastResult.getPositionCount(); i++) {
267+
collector.putDouble(outputTimes[i], forecastResult.getValueColumns()[0].getDouble(i));
268+
}
269+
}
270+
}

0 commit comments

Comments
 (0)