Skip to content

Commit 053c4e3

Browse files
feat: support custom order in insert stmt (#2075)
1 parent b2762f9 commit 053c4e3

File tree

16 files changed

+798
-558
lines changed

16 files changed

+798
-558
lines changed

.github/workflows/cicd.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,11 @@ jobs:
8080
run: |
8181
make test
8282
83+
- name: run sql_router_test
84+
id: sql_router_test
85+
run: |
86+
bash steps/ut.sh sql_router_test 0
87+
8388
- name: run sql_sdk_test
8489
id: sql_sdk_test
8590
run: |
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
/*
2+
* Copyright 2021 4Paradigm
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com._4paradigm.openmldb.common;
18+
19+
import java.io.Serializable;
20+
21+
public class Pair<K, V> implements Serializable {
22+
23+
/**
24+
* Key of this <code>Pair</code>.
25+
*/
26+
private K key;
27+
28+
/**
29+
* Gets the key for this pair.
30+
*
31+
* @return key for this pair
32+
*/
33+
public K getKey() {
34+
return key;
35+
}
36+
37+
/**
38+
* Value of this this <code>Pair</code>.
39+
*/
40+
private V value;
41+
42+
/**
43+
* Gets the value for this pair.
44+
*
45+
* @return value for this pair
46+
*/
47+
public V getValue() {
48+
return value;
49+
}
50+
51+
/**
52+
* Creates a new pair
53+
*
54+
* @param key The key for this pair
55+
* @param value The value to use for this pair
56+
*/
57+
public Pair(K key, V value) {
58+
this.key = key;
59+
this.value = value;
60+
}
61+
62+
/**
63+
* <p><code>String</code> representation of this
64+
* <code>Pair</code>.</p>
65+
*
66+
* <p>The default name/value delimiter '=' is always used.</p>
67+
*
68+
* @return <code>String</code> representation of this <code>Pair</code>
69+
*/
70+
@Override
71+
public String toString() {
72+
return key + "=" + value;
73+
}
74+
75+
/**
76+
* <p>Generate a hash code for this <code>Pair</code>.</p>
77+
*
78+
* <p>The hash code is calculated using both the name and
79+
* the value of the <code>Pair</code>.</p>
80+
*
81+
* @return hash code for this <code>Pair</code>
82+
*/
83+
@Override
84+
public int hashCode() {
85+
// name's hashCode is multiplied by an arbitrary prime number (13)
86+
// in order to make sure there is a difference in the hashCode between
87+
// these two parameters:
88+
// name: a value: aa
89+
// name: aa value: a
90+
return key.hashCode() * 13 + (value == null ? 0 : value.hashCode());
91+
}
92+
93+
/**
94+
* <p>Test this <code>Pair</code> for equality with another
95+
* <code>Object</code>.</p>
96+
*
97+
* <p>If the <code>Object</code> to be tested is not a
98+
* <code>Pair</code> or is <code>null</code>, then this method
99+
* returns <code>false</code>.</p>
100+
*
101+
* <p>Two <code>Pair</code>s are considered equal if and only if
102+
* both the names and values are equal.</p>
103+
*
104+
* @param o the <code>Object</code> to test for
105+
* equality with this <code>Pair</code>
106+
* @return <code>true</code> if the given <code>Object</code> is
107+
* equal to this <code>Pair</code> else <code>false</code>
108+
*/
109+
@Override
110+
public boolean equals(Object o) {
111+
if (this == o) return true;
112+
if (o instanceof Pair) {
113+
Pair pair = (Pair) o;
114+
if (key != null ? !key.equals(pair.key) : pair.key != null) return false;
115+
if (value != null ? !value.equals(pair.value) : pair.value != null) return false;
116+
return true;
117+
}
118+
return false;
119+
}
120+
}

java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/jdbc/SQLInsertMetaData.java

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,11 @@
1616

1717
package com._4paradigm.openmldb.jdbc;
1818

19+
import static com._4paradigm.openmldb.sdk.impl.Util.sqlTypeToString;
20+
1921
import com._4paradigm.openmldb.DataType;
2022
import com._4paradigm.openmldb.Schema;
23+
import com._4paradigm.openmldb.common.Pair;
2124
import com._4paradigm.openmldb.sdk.Common;
2225

2326
import java.sql.ResultSetMetaData;
@@ -28,10 +31,11 @@ public class SQLInsertMetaData implements ResultSetMetaData {
2831

2932
private final List<DataType> schema;
3033
private final Schema realSchema;
31-
private final List<Integer> idx;
34+
private final List<Pair<Long, Integer>> idx;
35+
3236
public SQLInsertMetaData(List<DataType> schema,
3337
Schema realSchema,
34-
List<Integer> idx) {
38+
List<Pair<Long, Integer>> idx) {
3539
this.schema = schema;
3640
this.realSchema = realSchema;
3741
this.idx = idx;
@@ -90,7 +94,7 @@ public boolean isCurrency(int i) throws SQLException {
9094
@Override
9195
public int isNullable(int i) throws SQLException {
9296
check(i);
93-
int index = idx.get(i - 1);
97+
Long index = idx.get(i - 1).getKey();
9498
if (realSchema.IsColumnNotNull(index)) {
9599
return columnNoNulls;
96100
} else {
@@ -119,7 +123,7 @@ public String getColumnLabel(int i) throws SQLException {
119123
@Override
120124
public String getColumnName(int i) throws SQLException {
121125
check(i);
122-
int index = idx.get(i - 1);
126+
Long index = idx.get(i - 1).getKey();
123127
return realSchema.GetColumnName(index);
124128
}
125129

@@ -156,14 +160,13 @@ public String getCatalogName(int i) throws SQLException {
156160
@Override
157161
public int getColumnType(int i) throws SQLException {
158162
check(i);
159-
DataType dataType = schema.get(i - 1);
160-
return Common.type2SqlType(dataType);
163+
Long index = idx.get(i - 1).getKey();
164+
return Common.type2SqlType(realSchema.GetColumnType(index));
161165
}
162166

163167
@Override
164-
@Deprecated
165168
public String getColumnTypeName(int i) throws SQLException {
166-
throw new SQLException("current do not support this method");
169+
return sqlTypeToString(getColumnType(i));
167170
}
168171

169172
@Override

java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/impl/InsertPreparedStatementImpl.java

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import com._4paradigm.openmldb.*;
2020

21+
import com._4paradigm.openmldb.common.Pair;
2122
import com._4paradigm.openmldb.jdbc.SQLInsertMetaData;
2223
import org.slf4j.Logger;
2324
import org.slf4j.LoggerFactory;
@@ -32,6 +33,7 @@
3233
import java.sql.Date;
3334
import java.sql.ResultSet;
3435
import java.util.*;
36+
import java.util.stream.Collectors;
3537

3638
public class InsertPreparedStatementImpl implements PreparedStatement {
3739
public static final Charset CHARSET = StandardCharsets.UTF_8;
@@ -48,7 +50,10 @@ public class InsertPreparedStatementImpl implements PreparedStatement {
4850
private final List<Object> currentDatas;
4951
private final List<DataType> currentDatasType;
5052
private final List<Boolean> hasSet;
51-
private final List<Integer> scehmaIdxs;
53+
// stmt insert idx -> real table schema idx
54+
private final List<Pair<Long, Integer>> schemaIdxes;
55+
// used by building row
56+
private final List<Pair<Long, Integer>> sortedIdxes;
5257

5358
private boolean closed = false;
5459
private boolean closeOnComplete = false;
@@ -63,18 +68,27 @@ public InsertPreparedStatementImpl(String db, String sql, SQLRouter router) thro
6368
this.currentSchema = tempRow.GetSchema();
6469
VectorUint32 idxes = tempRow.GetHoleIdx();
6570

71+
// In stmt order, if no columns in stmt, in schema order
72+
// We'll sort it to schema order later, so needs the map <real_schema_idx, current_data_idx>
73+
schemaIdxes = new ArrayList<>(idxes.size());
74+
// CurrentData and Type order is consistent with insert stmt. We'll do appending in schema order when build
75+
// row.
6676
currentDatas = new ArrayList<>(idxes.size());
6777
currentDatasType = new ArrayList<>(idxes.size());
6878
hasSet = new ArrayList<>(idxes.size());
69-
scehmaIdxs = new ArrayList<>(idxes.size());
79+
7080
for (int i = 0; i < idxes.size(); i++) {
71-
long idx = idxes.get(i);
72-
DataType type = currentSchema.GetColumnType(idx);
81+
Long realIdx = idxes.get(i);
82+
schemaIdxes.add(new Pair<>(realIdx, i));
83+
DataType type = currentSchema.GetColumnType(realIdx);
7384
currentDatasType.add(type);
7485
currentDatas.add(null);
7586
hasSet.add(false);
76-
scehmaIdxs.add(i);
87+
logger.debug("add col {}, {}", currentSchema.GetColumnName(realIdx), type);
7788
}
89+
// SQLInsertRow::AppendXXX order is the schema order(skip the no-hole columns)
90+
sortedIdxes = schemaIdxes.stream().sorted(Comparator.comparing(Pair::getKey))
91+
.collect(Collectors.toList());
7892
}
7993

8094
private SQLInsertRow getSQLInsertRow() throws SQLException {
@@ -118,14 +132,14 @@ private void checkIdx(int i) throws SQLException {
118132
if (i <= 0) {
119133
throw new SQLException("error sqe number");
120134
}
121-
if (i > scehmaIdxs.size()) {
135+
if (i > schemaIdxes.size()) {
122136
throw new SQLException("out of data range");
123137
}
124138
}
125139

126140
private void checkType(int i, DataType type) throws SQLException {
127141
if (currentDatasType.get(i - 1) != type) {
128-
throw new SQLException("data type not match");
142+
throw new SQLException("data type not match, expect " + currentDatasType.get(i - 1) + ", actual " + type);
129143
}
130144
}
131145

@@ -206,7 +220,7 @@ public void setBigDecimal(int i, BigDecimal bigDecimal) throws SQLException {
206220
}
207221

208222
private boolean checkNotAllowNull(int i) {
209-
long idx = this.scehmaIdxs.get(i - 1);
223+
Long idx = this.schemaIdxes.get(i - 1).getKey();
210224
return this.currentSchema.IsColumnNotNull(idx);
211225
}
212226

@@ -300,22 +314,22 @@ public void setObject(int i, Object o, int i1) throws SQLException {
300314

301315
private void buildRow() throws SQLException {
302316
SQLInsertRow currentRow = getSQLInsertRow();
303-
304317
boolean ok = currentRow.Init(stringsLen);
305318
if (!ok) {
306319
throw new SQLException("init row failed");
307320
}
308321

309-
for (int i = 0; i < currentDatasType.size(); i++) {
310-
Object data = currentDatas.get(i);
322+
for (Pair<Long, Integer> sortedIdx : sortedIdxes) {
323+
Integer currentDataIdx = sortedIdx.getValue();
324+
Object data = currentDatas.get(currentDataIdx);
311325
if (data == null) {
312326
ok = currentRow.AppendNULL();
313327
} else {
314-
DataType curType = currentDatasType.get(i);
328+
DataType curType = currentDatasType.get(currentDataIdx);
315329
if (DataType.kTypeBool.equals(curType)) {
316330
ok = currentRow.AppendBool((boolean) data);
317331
} else if (DataType.kTypeDate.equals(curType)) {
318-
java.sql.Date date = (java.sql.Date) data;
332+
Date date = (Date) data;
319333
ok = currentRow.AppendDate(date.getYear() + 1900, date.getMonth() + 1, date.getDate());
320334
} else if (DataType.kTypeDouble.equals(curType)) {
321335
ok = currentRow.AppendDouble((double) data);
@@ -333,7 +347,7 @@ private void buildRow() throws SQLException {
333347
} else if (DataType.kTypeTimestamp.equals(curType)) {
334348
ok = currentRow.AppendTimestamp((long) data);
335349
} else {
336-
throw new SQLException("unkown data type");
350+
throw new SQLException("unknown data type");
337351
}
338352
}
339353
if (!ok) {
@@ -423,9 +437,8 @@ public void setArray(int i, Array array) throws SQLException {
423437
}
424438

425439
@Override
426-
@Deprecated
427440
public ResultSetMetaData getMetaData() throws SQLException {
428-
return new SQLInsertMetaData(this.currentDatasType, this.currentSchema, this.scehmaIdxs);
441+
return new SQLInsertMetaData(this.currentDatasType, this.currentSchema, this.schemaIdxes);
429442
}
430443

431444
@Override

java/openmldb-jdbc/src/test/java/com/_4paradigm/openmldb/jdbc/JDBCDriverTest.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import java.sql.ResultSet;
2424
import java.sql.SQLException;
2525
import java.sql.Statement;
26+
import java.sql.Types;
2627
import java.util.ArrayList;
2728
import java.util.List;
2829
import java.util.stream.IntStream;
@@ -183,7 +184,7 @@ public void testForKafkaConnector() throws SQLException {
183184
String tableName = "kafka_test";
184185
stmt = connection.createStatement();
185186
try {
186-
stmt.execute(String.format("create table if not exists %s(c1 int, c2 string)", tableName));
187+
stmt.execute(String.format("create table if not exists %s(c1 int, c2 string, c3 timestamp)", tableName));
187188
} catch (Exception e) {
188189
Assert.fail();
189190
}
@@ -198,6 +199,15 @@ public void testForKafkaConnector() throws SQLException {
198199
pstmt.setFetchSize(100);
199200

200201
pstmt.addBatch();
202+
insertSql = "INSERT INTO " +
203+
tableName +
204+
"(`c3`,`c2`) VALUES(?,?)";
205+
pstmt = connection.prepareStatement(insertSql);
206+
Assert.assertEquals(pstmt.getMetaData().getColumnCount(), 2);
207+
// index starts from 1
208+
Assert.assertEquals(pstmt.getMetaData().getColumnType(2), Types.VARCHAR);
209+
Assert.assertEquals(pstmt.getMetaData().getColumnName(2), "c2");
210+
201211

202212
try {
203213
stmt = connection.prepareStatement("DELETE FROM " + tableName + " WHERE c1=1");

0 commit comments

Comments
 (0)