Skip to content

Commit 0a2166f

Browse files
authored
UDFT pattern_match: A similarity matching algorithm based on sketch and example
1 parent e09bdbe commit 0a2166f

File tree

12 files changed

+2535
-5
lines changed

12 files changed

+2535
-5
lines changed

integration-test/src/test/java/org/apache/iotdb/db/it/utils/TestUtils.java

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
import java.sql.ResultSetMetaData;
4141
import java.sql.SQLException;
4242
import java.sql.Statement;
43+
import java.sql.Types;
4344
import java.text.DateFormat;
4445
import java.time.ZoneId;
4546
import java.time.ZoneOffset;
@@ -51,6 +52,7 @@
5152
import java.util.List;
5253
import java.util.Map;
5354
import java.util.Objects;
55+
import java.util.Optional;
5456
import java.util.Set;
5557
import java.util.TreeMap;
5658
import java.util.concurrent.TimeUnit;
@@ -233,6 +235,18 @@ public static void tableResultSetEqualTest(
233235
"+00:00");
234236
}
235237

238+
public static void tableResultSetEqualByDataTypeTest(
239+
String sql, String[] expectedHeader, String[] expectedRetArray, String database) {
240+
tableResultSetEqualByDataTypeTest(
241+
sql,
242+
expectedHeader,
243+
expectedRetArray,
244+
SessionConfig.DEFAULT_USER,
245+
SessionConfig.DEFAULT_PASSWORD,
246+
database,
247+
"+00:00");
248+
}
249+
236250
public static void tableResultSetEqualTest(
237251
String sql,
238252
String timeZone,
@@ -298,6 +312,57 @@ public static void tableResultSetEqualTest(
298312
}
299313
}
300314

315+
public static void tableResultSetEqualByDataTypeTest(
316+
String sql,
317+
String[] expectedHeader,
318+
String[] expectedRetArray,
319+
String userName,
320+
String password,
321+
String database,
322+
String timeZone) {
323+
try (Connection connection =
324+
EnvFactory.getEnv().getConnection(userName, password, BaseEnv.TABLE_SQL_DIALECT)) {
325+
connection.setClientInfo("time_zone", timeZone);
326+
try (Statement statement = connection.createStatement()) {
327+
statement.execute("use " + database);
328+
try (ResultSet resultSet = statement.executeQuery(sql)) {
329+
ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
330+
for (int i = 1; i <= resultSetMetaData.getColumnCount(); i++) {
331+
assertEquals(expectedHeader[i - 1], resultSetMetaData.getColumnName(i));
332+
}
333+
assertEquals(expectedHeader.length, resultSetMetaData.getColumnCount());
334+
335+
int cnt = 0;
336+
while (resultSet.next()) {
337+
for (int i = 1; i <= expectedHeader.length; i++) {
338+
if (resultSetMetaData.getColumnType(i) == Types.BOOLEAN) {
339+
assertEquals(
340+
Boolean.valueOf(expectedRetArray[cnt].split(",")[i - 1]),
341+
resultSet.getBoolean(i));
342+
} else if (resultSetMetaData.getColumnType(i) == Types.INTEGER) {
343+
assertEquals(
344+
Optional.of(Integer.valueOf(expectedRetArray[cnt].split(",")[i - 1])),
345+
Optional.of(resultSet.getInt(i)));
346+
} else if (resultSetMetaData.getColumnType(i) == Types.DOUBLE) {
347+
assertEquals(
348+
Double.valueOf(expectedRetArray[cnt].split(",")[i - 1]),
349+
resultSet.getDouble(i),
350+
DELTA);
351+
} else if (resultSetMetaData.getColumnType(i) == Types.VARCHAR) {
352+
assertEquals(expectedRetArray[cnt].split(",")[i - 1], resultSet.getString(i));
353+
}
354+
}
355+
cnt++;
356+
}
357+
assertEquals(expectedRetArray.length, cnt);
358+
}
359+
}
360+
} catch (SQLException e) {
361+
e.printStackTrace();
362+
fail(e.getMessage());
363+
}
364+
}
365+
301366
public static void tableExecuteTest(String sql, String userName, String password) {
302367
try (Connection connection =
303368
EnvFactory.getEnv().getConnection(userName, password, BaseEnv.TABLE_SQL_DIALECT)) {

integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBWindowTVFIT.java

Lines changed: 281 additions & 2 deletions
Large diffs are not rendered by default.

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/TableBuiltinTableFunction.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.apache.iotdb.commons.udf.builtin.relational.tvf.TumbleTableFunction;
2727
import org.apache.iotdb.commons.udf.builtin.relational.tvf.VariationTableFunction;
2828
import org.apache.iotdb.db.queryengine.plan.relational.function.tvf.ForecastTableFunction;
29+
import org.apache.iotdb.db.queryengine.plan.relational.function.tvf.PatternMatchTableFunction;
2930
import org.apache.iotdb.udf.api.relational.TableFunction;
3031

3132
import java.util.Arrays;
@@ -40,7 +41,8 @@ public enum TableBuiltinTableFunction {
4041
SESSION("session"),
4142
VARIATION("variation"),
4243
CAPACITY("capacity"),
43-
FORECAST("forecast");
44+
FORECAST("forecast"),
45+
PATTERN_MATCH("pattern_match");
4446

4547
private final String functionName;
4648

@@ -78,6 +80,8 @@ public static TableFunction getBuiltinTableFunction(String functionName) {
7880
return new SessionTableFunction();
7981
case "variation":
8082
return new VariationTableFunction();
83+
case "pattern_match":
84+
return new PatternMatchTableFunction();
8185
case "capacity":
8286
return new CapacityTableFunction();
8387
case "forecast":
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,266 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.iotdb.db.queryengine.plan.relational.function.tvf;
21+
22+
import org.apache.iotdb.db.queryengine.plan.relational.function.tvf.match.QetchAlgorithm;
23+
import org.apache.iotdb.db.queryengine.plan.relational.function.tvf.match.model.MatchState;
24+
import org.apache.iotdb.db.queryengine.plan.relational.function.tvf.match.model.Point;
25+
import org.apache.iotdb.db.queryengine.plan.relational.function.tvf.match.model.RegexMatchState;
26+
import org.apache.iotdb.db.queryengine.plan.relational.function.tvf.match.model.Section;
27+
import org.apache.iotdb.udf.api.exception.UDFException;
28+
import org.apache.iotdb.udf.api.relational.TableFunction;
29+
import org.apache.iotdb.udf.api.relational.access.Record;
30+
import org.apache.iotdb.udf.api.relational.table.MapTableFunctionHandle;
31+
import org.apache.iotdb.udf.api.relational.table.TableFunctionAnalysis;
32+
import org.apache.iotdb.udf.api.relational.table.TableFunctionHandle;
33+
import org.apache.iotdb.udf.api.relational.table.TableFunctionProcessorProvider;
34+
import org.apache.iotdb.udf.api.relational.table.argument.Argument;
35+
import org.apache.iotdb.udf.api.relational.table.argument.DescribedSchema;
36+
import org.apache.iotdb.udf.api.relational.table.argument.ScalarArgument;
37+
import org.apache.iotdb.udf.api.relational.table.argument.TableArgument;
38+
import org.apache.iotdb.udf.api.relational.table.processor.TableFunctionDataProcessor;
39+
import org.apache.iotdb.udf.api.relational.table.specification.ParameterSpecification;
40+
import org.apache.iotdb.udf.api.relational.table.specification.ScalarParameterSpecification;
41+
import org.apache.iotdb.udf.api.relational.table.specification.TableParameterSpecification;
42+
import org.apache.iotdb.udf.api.type.Type;
43+
44+
import com.google.common.collect.ImmutableSet;
45+
import org.apache.tsfile.block.column.ColumnBuilder;
46+
47+
import java.util.Arrays;
48+
import java.util.List;
49+
import java.util.Map;
50+
51+
import static org.apache.iotdb.commons.udf.builtin.relational.tvf.WindowTVFUtils.findColumnIndex;
52+
53+
public class PatternMatchTableFunction implements TableFunction {
54+
private static final String TBL_PARAM = "DATA";
55+
private static final String TIME_COLUMN = "TIME_COL";
56+
private static final String DATA_COLUMN = "DATA_COL";
57+
private static final String PATTERN_PARAM = "PATTERN";
58+
private static final String SMOOTH_PARAM = "SMOOTH";
59+
private static final String THRESHOLD_PARAM = "THRESHOLD";
60+
private static final String WIDTH_PARAM = "WIDTH";
61+
private static final String HEIGHT_PARAM = "HEIGHT";
62+
private static final String SMOOTH_ON_PATTERN = "SMOOTH_ON_PATTERN";
63+
64+
@Override
65+
public List<ParameterSpecification> getArgumentsSpecifications() {
66+
return Arrays.asList(
67+
TableParameterSpecification.builder().name(TBL_PARAM).passThroughColumns().build(),
68+
ScalarParameterSpecification.builder()
69+
.name(TIME_COLUMN)
70+
.type(Type.STRING)
71+
.defaultValue("time")
72+
.build(),
73+
ScalarParameterSpecification.builder().name(DATA_COLUMN).type(Type.STRING).build(),
74+
ScalarParameterSpecification.builder().name(PATTERN_PARAM).type(Type.STRING).build(),
75+
ScalarParameterSpecification.builder().name(SMOOTH_PARAM).type(Type.DOUBLE).build(),
76+
ScalarParameterSpecification.builder().name(THRESHOLD_PARAM).type(Type.DOUBLE).build(),
77+
ScalarParameterSpecification.builder()
78+
.name(WIDTH_PARAM)
79+
.type(Type.DOUBLE)
80+
.defaultValue(Double.MAX_VALUE)
81+
.build(),
82+
ScalarParameterSpecification.builder()
83+
.name(HEIGHT_PARAM)
84+
.type(Type.DOUBLE)
85+
.defaultValue(Double.MAX_VALUE)
86+
.build(),
87+
ScalarParameterSpecification.builder()
88+
.name(SMOOTH_ON_PATTERN)
89+
.type(Type.BOOLEAN)
90+
.defaultValue(false)
91+
.build());
92+
}
93+
94+
@Override
95+
public TableFunctionAnalysis analyze(Map<String, Argument> arguments) throws UDFException {
96+
// calc the index of the column
97+
TableArgument tableArgument = (TableArgument) arguments.get(TBL_PARAM);
98+
String expectedTimeName = (String) ((ScalarArgument) arguments.get(TIME_COLUMN)).getValue();
99+
String expectedDataName = (String) ((ScalarArgument) arguments.get(DATA_COLUMN)).getValue();
100+
int requiredTimeIndex =
101+
findColumnIndex(tableArgument, expectedTimeName, ImmutableSet.of(Type.TIMESTAMP));
102+
int requiredDataIndex =
103+
findColumnIndex(
104+
tableArgument,
105+
expectedDataName,
106+
ImmutableSet.of(Type.INT32, Type.INT64, Type.FLOAT, Type.DOUBLE));
107+
108+
// outputColumnSchema description
109+
DescribedSchema properColumnSchema =
110+
new DescribedSchema.Builder()
111+
.addField("match_index", Type.INT32)
112+
.addField("similarity", Type.DOUBLE)
113+
.build();
114+
115+
// this is for transferring the parameters to the processor
116+
MapTableFunctionHandle handle =
117+
new MapTableFunctionHandle.Builder()
118+
.addProperty(PATTERN_PARAM, ((ScalarArgument) arguments.get(PATTERN_PARAM)).getValue())
119+
.addProperty(SMOOTH_PARAM, ((ScalarArgument) arguments.get(SMOOTH_PARAM)).getValue())
120+
.addProperty(
121+
THRESHOLD_PARAM, ((ScalarArgument) arguments.get(THRESHOLD_PARAM)).getValue())
122+
.addProperty(WIDTH_PARAM, ((ScalarArgument) arguments.get(WIDTH_PARAM)).getValue())
123+
.addProperty(HEIGHT_PARAM, ((ScalarArgument) arguments.get(HEIGHT_PARAM)).getValue())
124+
.addProperty(
125+
SMOOTH_ON_PATTERN, ((ScalarArgument) arguments.get(SMOOTH_ON_PATTERN)).getValue())
126+
.build();
127+
128+
return TableFunctionAnalysis.builder()
129+
.properColumnSchema(properColumnSchema)
130+
.requireRecordSnapshot(false)
131+
.requiredColumns(
132+
TBL_PARAM,
133+
Arrays.asList(requiredTimeIndex, requiredDataIndex)) // the 0th column is time
134+
.handle(handle)
135+
.build();
136+
}
137+
138+
@Override
139+
public TableFunctionHandle createTableFunctionHandle() {
140+
return new MapTableFunctionHandle();
141+
}
142+
143+
@Override
144+
public TableFunctionProcessorProvider getProcessorProvider(
145+
TableFunctionHandle tableFunctionHandle) {
146+
String pattern =
147+
(String) ((MapTableFunctionHandle) tableFunctionHandle).getProperty(PATTERN_PARAM);
148+
Double smoothValue =
149+
(Double) ((MapTableFunctionHandle) tableFunctionHandle).getProperty(SMOOTH_PARAM);
150+
Double threshold =
151+
(Double) ((MapTableFunctionHandle) tableFunctionHandle).getProperty(THRESHOLD_PARAM);
152+
Double widthLimit =
153+
(Double) ((MapTableFunctionHandle) tableFunctionHandle).getProperty(WIDTH_PARAM);
154+
Double heightLimit =
155+
(Double) ((MapTableFunctionHandle) tableFunctionHandle).getProperty(HEIGHT_PARAM);
156+
boolean isPatternFromOrigin =
157+
(Boolean) ((MapTableFunctionHandle) tableFunctionHandle).getProperty(SMOOTH_ON_PATTERN);
158+
159+
QetchAlgorithm qetchAlgorithm = new QetchAlgorithm();
160+
qetchAlgorithm.setThreshold(threshold);
161+
qetchAlgorithm.setSmoothValue(smoothValue);
162+
qetchAlgorithm.setHeightLimit(heightLimit);
163+
qetchAlgorithm.setWidthLimit(widthLimit);
164+
qetchAlgorithm.setIsPatternFromOrigin(isPatternFromOrigin);
165+
qetchAlgorithm.parsePattern2Automaton(pattern);
166+
167+
return new TableFunctionProcessorProvider() {
168+
@Override
169+
public TableFunctionDataProcessor getDataProcessor() {
170+
return new ShapeMatchDataProcessor(qetchAlgorithm);
171+
}
172+
};
173+
}
174+
175+
private static class ShapeMatchDataProcessor implements TableFunctionDataProcessor {
176+
177+
private final QetchAlgorithm qetchAlgorithm;
178+
179+
public ShapeMatchDataProcessor(QetchAlgorithm qetchAlgorithm) {
180+
this.qetchAlgorithm = qetchAlgorithm;
181+
}
182+
183+
@Override
184+
public void process(
185+
Record input,
186+
List<ColumnBuilder> properColumnBuilders,
187+
ColumnBuilder passThroughIndexBuilder) {
188+
189+
double time = input.getLong(0);
190+
double value = input.getDouble(1);
191+
192+
qetchAlgorithm.addPoint(new Point(time, value, qetchAlgorithm.getPointNum()));
193+
if (qetchAlgorithm.hasMatchResult()) {
194+
outputWindow(
195+
properColumnBuilders, passThroughIndexBuilder, qetchAlgorithm.getMatchResult());
196+
}
197+
if (qetchAlgorithm.hasRegexMatchResult()) {
198+
outputWindowRegex(
199+
properColumnBuilders, passThroughIndexBuilder, qetchAlgorithm.getRegexMatchResult());
200+
}
201+
}
202+
203+
@Override
204+
public void finish(
205+
List<ColumnBuilder> properColumnBuilders, ColumnBuilder passThroughIndexBuilder) {
206+
qetchAlgorithm.closeNowDataSegment();
207+
if (qetchAlgorithm.hasMatchResult()) {
208+
outputWindow(
209+
properColumnBuilders, passThroughIndexBuilder, qetchAlgorithm.getMatchResult());
210+
}
211+
if (qetchAlgorithm.hasRegexMatchResult()) {
212+
outputWindowRegex(
213+
properColumnBuilders, passThroughIndexBuilder, qetchAlgorithm.getRegexMatchResult());
214+
}
215+
}
216+
217+
private void outputWindow(
218+
List<ColumnBuilder> properColumnBuilders,
219+
ColumnBuilder passThroughIndexBuilder,
220+
MatchState matchResult) {
221+
int matchResultID = qetchAlgorithm.getMatchResultID();
222+
for (int i = 0; i < matchResult.getDataSectionList().size(); i++) {
223+
for (int j = i == 0 ? 0 : 1;
224+
j < matchResult.getDataSectionList().get(i).getPoints().size();
225+
j++) {
226+
passThroughIndexBuilder.writeLong(
227+
matchResult.getDataSectionList().get(i).getPoints().get(j).index);
228+
properColumnBuilders.get(0).writeInt(matchResultID);
229+
properColumnBuilders.get(1).writeDouble(matchResult.getMatchValue());
230+
}
231+
}
232+
233+
// after the process, the result of qetchAlgorthm will be empty
234+
qetchAlgorithm.matchResultClear();
235+
if (qetchAlgorithm.checkNextMatchResult()) {
236+
outputWindow(
237+
properColumnBuilders, passThroughIndexBuilder, qetchAlgorithm.getMatchResult());
238+
}
239+
}
240+
241+
private void outputWindowRegex(
242+
List<ColumnBuilder> properColumnBuilders,
243+
ColumnBuilder passThroughIndexBuilder,
244+
RegexMatchState matchResult) {
245+
for (RegexMatchState.PathState pathState : matchResult.getMatchResult()) {
246+
int matchResultID = qetchAlgorithm.getMatchResultID();
247+
int dataSectionIndex = pathState.getDataSectionIndex();
248+
List<Section> dataSectionList = matchResult.getDataSectionList();
249+
for (int i = 0; i <= dataSectionIndex; i++) {
250+
for (int j = i == 0 ? 0 : 1; j < dataSectionList.get(i).getPoints().size(); j++) {
251+
passThroughIndexBuilder.writeLong(dataSectionList.get(i).getPoints().get(j).index);
252+
properColumnBuilders.get(0).writeInt(matchResultID);
253+
properColumnBuilders.get(1).writeDouble(pathState.getMatchValue());
254+
}
255+
}
256+
}
257+
258+
// after the process, the result of qetchAlgorthm will be empty
259+
qetchAlgorithm.matchResultClear();
260+
if (qetchAlgorithm.checkNextRegexMatchResult()) {
261+
outputWindowRegex(
262+
properColumnBuilders, passThroughIndexBuilder, qetchAlgorithm.getRegexMatchResult());
263+
}
264+
}
265+
}
266+
}

0 commit comments

Comments
 (0)