Skip to content

Commit b8e4db4

Browse files
committed
convert field name
1 parent eacbf12 commit b8e4db4

File tree

2 files changed

+118
-42
lines changed

2 files changed

+118
-42
lines changed

core/src/main/java/com/dtstack/flink/sql/exec/FlinkSQLExec.java

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import org.apache.flink.table.api.Table;
77
import org.apache.flink.table.api.TableEnvironment;
88
import org.apache.flink.table.api.TableException;
9+
import org.apache.flink.table.api.ValidationException;
910
import org.apache.flink.table.api.java.StreamTableEnvironment;
1011
import org.apache.flink.table.calcite.FlinkPlannerImpl;
1112
import org.apache.flink.table.plan.logical.LogicalRelNode;
@@ -36,17 +37,23 @@ public static void sqlUpdate(StreamTableEnvironment tableEnv, String stmt) throw
3637
Table queryResult = new Table(tableEnv, new LogicalRelNode(planner.rel(validatedQuery).rel));
3738
String targetTableName = ((SqlIdentifier) ((SqlInsert) insert).getTargetTable()).names.get(0);
3839

40+
Method method = TableEnvironment.class.getDeclaredMethod("getTable", String.class);
41+
method.setAccessible(true);
42+
43+
TableSinkTable targetTable = (TableSinkTable) method.invoke(tableEnv, targetTableName);
44+
String[] fieldNames = targetTable.tableSink().getFieldNames();
45+
46+
Table newTable = null;
47+
3948
try {
40-
Method method = TableEnvironment.class.getDeclaredMethod("getTable", String.class);
41-
method.setAccessible(true);
42-
43-
TableSinkTable targetTable = (TableSinkTable) method.invoke(tableEnv, targetTableName);
44-
String[] fieldNames = targetTable.tableSink().getFieldNames();
45-
Table newTable = queryResult.select(String.join(",", fieldNames));
46-
// insert query result into sink table
47-
tableEnv.insertInto(newTable, targetTableName, tableEnv.queryConfig());
49+
newTable = queryResult.select(String.join(",", fieldNames));
4850
} catch (Exception e) {
49-
throw e;
51+
throw new ValidationException(
52+
"Field name of query result and registered TableSink "+targetTableName +" do not match.\n" +
53+
"Query result schema: " + String.join(",", queryResult.getSchema().getColumnNames()) + "\n" +
54+
"TableSink schema: " + String.join(",", fieldNames));
5055
}
56+
57+
tableEnv.insertInto(newTable, targetTableName, tableEnv.queryConfig());
5158
}
5259
}

core/src/main/java/com/dtstack/flink/sql/side/SideSqlExec.java

Lines changed: 102 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,16 @@
1717
*/
1818

1919

20-
2120
package com.dtstack.flink.sql.side;
2221

23-
import com.dtstack.flink.sql.Main;
2422
import com.dtstack.flink.sql.enums.ECacheType;
2523
import com.dtstack.flink.sql.exec.FlinkSQLExec;
2624
import com.dtstack.flink.sql.parser.CreateTmpTableParser;
2725
import com.dtstack.flink.sql.side.operator.SideAsyncOperator;
2826
import com.dtstack.flink.sql.side.operator.SideWithAllCacheOperator;
2927
import com.dtstack.flink.sql.util.ClassUtil;
3028
import com.dtstack.flink.sql.util.ParseUtils;
29+
import org.apache.calcite.sql.SqlAsOperator;
3130
import org.apache.calcite.sql.SqlBasicCall;
3231
import org.apache.calcite.sql.SqlDataTypeSpec;
3332
import org.apache.calcite.sql.SqlIdentifier;
@@ -37,6 +36,7 @@
3736
import org.apache.calcite.sql.SqlLiteral;
3837
import org.apache.calcite.sql.SqlNode;
3938
import org.apache.calcite.sql.SqlNodeList;
39+
import org.apache.calcite.sql.SqlOperator;
4040
import org.apache.calcite.sql.SqlSelect;
4141
import org.apache.calcite.sql.fun.SqlCase;
4242
import org.apache.calcite.sql.parser.SqlParseException;
@@ -101,8 +101,11 @@ public void exec(String sql, Map<String, SideTableInfo> sideTableMap, StreamTabl
101101

102102
if(preIsSideJoin){
103103
preIsSideJoin = false;
104+
List<String> fieldNames = null;
104105
for(FieldReplaceInfo replaceInfo : replaceInfoList){
105-
replaceFieldName(pollSqlNode, replaceInfo.getMappingTable(), replaceInfo.getTargetTableName(), replaceInfo.getTargetTableAlias());
106+
fieldNames = Lists.newArrayList();
107+
replaceFieldName(pollSqlNode, replaceInfo.getMappingTable(), replaceInfo.getTargetTableName(), replaceInfo.getTargetTableAlias(), fieldNames);
108+
dealMidConvertField(pollSqlNode, fieldNames);
106109
}
107110
}
108111

@@ -126,6 +129,66 @@ public void exec(String sql, Map<String, SideTableInfo> sideTableMap, StreamTabl
126129

127130
}
128131

132+
133+
private void dealMidConvertField(SqlNode pollSqlNode, List<String> field) {
134+
SqlKind sqlKind = pollSqlNode.getKind();
135+
switch (sqlKind) {
136+
case INSERT:
137+
SqlNode source = ((SqlInsert) pollSqlNode).getSource();
138+
dealMidConvertField(source, field);
139+
break;
140+
141+
case AS:
142+
dealMidConvertField(((SqlBasicCall) pollSqlNode).getOperands()[0], field);
143+
break;
144+
145+
case SELECT:
146+
147+
SqlNodeList selectList = ((SqlSelect) pollSqlNode).getSelectList();
148+
149+
selectList.getList().forEach(node -> {
150+
if (node.getKind() == IDENTIFIER) {
151+
SqlIdentifier sqlIdentifier = (SqlIdentifier) node;
152+
if (sqlIdentifier.names.size() == 1) {
153+
return;
154+
}
155+
String name = sqlIdentifier.names.get(1);
156+
if (!name.endsWith("0")) {
157+
field.add(name);
158+
}
159+
160+
}
161+
});
162+
// convert
163+
for (int i = 0; i < selectList.getList().size(); i++) {
164+
SqlNode node = selectList.get(i);
165+
if (node.getKind() == IDENTIFIER) {
166+
SqlIdentifier sqlIdentifier = (SqlIdentifier) node;
167+
if (sqlIdentifier.names.size() == 1) {
168+
return;
169+
}
170+
171+
String name = sqlIdentifier.names.get(1);
172+
if (name.endsWith("0") && !field.contains(name)) {
173+
SqlOperator operator = new SqlAsOperator();
174+
SqlParserPos sqlParserPos = new SqlParserPos(0, 0);
175+
176+
SqlIdentifier sqlIdentifierAlias = new SqlIdentifier(name.substring(0, name.length() - 1), null, sqlParserPos);
177+
SqlNode[] sqlNodes = new SqlNode[2];
178+
sqlNodes[0] = sqlIdentifier;
179+
sqlNodes[1] = sqlIdentifierAlias;
180+
SqlBasicCall sqlBasicCall = new SqlBasicCall(operator, sqlNodes, sqlParserPos);
181+
182+
selectList.set(i, sqlBasicCall);
183+
}
184+
185+
}
186+
}
187+
break;
188+
}
189+
}
190+
191+
129192
public AliasInfo parseASNode(SqlNode sqlNode) throws SqlParseException {
130193
SqlKind sqlKind = sqlNode.getKind();
131194
if(sqlKind != AS){
@@ -164,16 +227,16 @@ public RowTypeInfo buildOutRowTypeInfo(List<FieldInfo> sideJoinFieldInfo, HashBa
164227
}
165228

166229
//需要考虑更多的情况
167-
private void replaceFieldName(SqlNode sqlNode, HashBasedTable<String, String, String> mappingTable, String targetTableName, String tableAlias) {
230+
private void replaceFieldName(SqlNode sqlNode, HashBasedTable<String, String, String> mappingTable, String targetTableName, String tableAlias, List<String> fieldNames) {
168231
SqlKind sqlKind = sqlNode.getKind();
169232
switch (sqlKind) {
170233
case INSERT:
171234
SqlNode sqlSource = ((SqlInsert) sqlNode).getSource();
172-
replaceFieldName(sqlSource, mappingTable, targetTableName, tableAlias);
235+
replaceFieldName(sqlSource, mappingTable, targetTableName, tableAlias, fieldNames);
173236
break;
174237
case AS:
175-
SqlNode asNode = ((SqlBasicCall)sqlNode).getOperands()[0];
176-
replaceFieldName(asNode, mappingTable, targetTableName, tableAlias);
238+
SqlNode asNode = ((SqlBasicCall) sqlNode).getOperands()[0];
239+
replaceFieldName(asNode, mappingTable, targetTableName, tableAlias, fieldNames);
177240
break;
178241
case SELECT:
179242
SqlSelect sqlSelect = (SqlSelect) filterNodeWithTargetName(sqlNode, targetTableName);
@@ -202,7 +265,7 @@ private void replaceFieldName(SqlNode sqlNode, HashBasedTable<String, String, St
202265
continue;
203266
}
204267

205-
SqlNode replaceNode = replaceSelectFieldName(selectNode, mappingTable, tableAlias);
268+
SqlNode replaceNode = replaceSelectFieldName(selectNode, mappingTable, tableAlias, fieldNames);
206269
if(replaceNode == null){
207270
continue;
208271
}
@@ -219,15 +282,15 @@ private void replaceFieldName(SqlNode sqlNode, HashBasedTable<String, String, St
219282
SqlNode[] sqlNodeList = ((SqlBasicCall)whereNode).getOperands();
220283
for(int i =0; i<sqlNodeList.length; i++) {
221284
SqlNode whereSqlNode = sqlNodeList[i];
222-
SqlNode replaceNode = replaceNodeInfo(whereSqlNode, mappingTable, tableAlias);
285+
SqlNode replaceNode = replaceNodeInfo(whereSqlNode, mappingTable, tableAlias, fieldNames);
223286
sqlNodeList[i] = replaceNode;
224287
}
225288
}
226289

227290
if(sqlGroup != null && CollectionUtils.isNotEmpty(sqlGroup.getList())){
228291
for( int i=0; i<sqlGroup.getList().size(); i++){
229292
SqlNode selectNode = sqlGroup.getList().get(i);
230-
SqlNode replaceNode = replaceNodeInfo(selectNode, mappingTable, tableAlias);
293+
SqlNode replaceNode = replaceNodeInfo(selectNode, mappingTable, tableAlias, fieldNames);
231294
sqlGroup.set(i, replaceNode);
232295
}
233296
}
@@ -247,7 +310,7 @@ private void replaceFieldName(SqlNode sqlNode, HashBasedTable<String, String, St
247310
}
248311
}
249312

250-
private SqlNode replaceNodeInfo(SqlNode groupNode, HashBasedTable<String, String, String> mappingTable, String tableAlias){
313+
private SqlNode replaceNodeInfo(SqlNode groupNode, HashBasedTable<String, String, String> mappingTable, String tableAlias, List<String> fieldNames){
251314
if(groupNode.getKind() == IDENTIFIER){
252315
SqlIdentifier sqlIdentifier = (SqlIdentifier) groupNode;
253316
String mappingFieldName = mappingTable.get(sqlIdentifier.getComponent(0).getSimple(), sqlIdentifier.getComponent(1).getSimple());
@@ -257,7 +320,7 @@ private SqlNode replaceNodeInfo(SqlNode groupNode, HashBasedTable<String, String
257320
SqlBasicCall sqlBasicCall = (SqlBasicCall) groupNode;
258321
for(int i=0; i<sqlBasicCall.getOperandList().size(); i++){
259322
SqlNode sqlNode = sqlBasicCall.getOperandList().get(i);
260-
SqlNode replaceNode = replaceSelectFieldName(sqlNode, mappingTable, tableAlias);
323+
SqlNode replaceNode = replaceSelectFieldName(sqlNode, mappingTable, tableAlias, fieldNames);
261324
sqlBasicCall.getOperands()[i] = replaceNode;
262325
}
263326

@@ -267,7 +330,7 @@ private SqlNode replaceNodeInfo(SqlNode groupNode, HashBasedTable<String, String
267330
}
268331
}
269332

270-
public SqlNode filterNodeWithTargetName(SqlNode sqlNode, String targetTableName){
333+
public SqlNode filterNodeWithTargetName(SqlNode sqlNode, String targetTableName) {
271334

272335
SqlKind sqlKind = sqlNode.getKind();
273336
switch (sqlKind){
@@ -304,7 +367,7 @@ public SqlNode filterNodeWithTargetName(SqlNode sqlNode, String targetTableName)
304367
}
305368

306369

307-
public void setLocalSqlPluginPath(String localSqlPluginPath){
370+
public void setLocalSqlPluginPath(String localSqlPluginPath) {
308371
this.localSqlPluginPath = localSqlPluginPath;
309372
}
310373

@@ -348,12 +411,14 @@ private List<SqlNode> replaceSelectStarFieldName(SqlNode selectNode, HashBasedTa
348411
}
349412
}
350413

351-
private SqlNode replaceSelectFieldName(SqlNode selectNode, HashBasedTable<String, String, String> mappingTable, String tableAlias){
352-
if(selectNode.getKind() == AS){
353-
SqlNode leftNode = ((SqlBasicCall)selectNode).getOperands()[0];
354-
SqlNode replaceNode = replaceSelectFieldName(leftNode, mappingTable, tableAlias);
355-
if(replaceNode != null){
356-
((SqlBasicCall)selectNode).getOperands()[0] = replaceNode;
414+
private SqlNode replaceSelectFieldName(SqlNode selectNode, HashBasedTable<String, String, String> mappingTable, String tableAlias, List<String> fieldNames) {
415+
if (selectNode.getKind() == AS) {
416+
SqlNode leftNode = ((SqlBasicCall) selectNode).getOperands()[0];
417+
SqlNode rightNode = ((SqlBasicCall) selectNode).getOperands()[1];
418+
fieldNames.add(rightNode.toString());
419+
SqlNode replaceNode = replaceSelectFieldName(leftNode, mappingTable, tableAlias, fieldNames);
420+
if (replaceNode != null) {
421+
((SqlBasicCall) selectNode).getOperands()[0] = replaceNode;
357422
}
358423

359424
return selectNode;
@@ -419,7 +484,7 @@ private SqlNode replaceSelectFieldName(SqlNode selectNode, HashBasedTable<String
419484
continue;
420485
}
421486

422-
SqlNode replaceNode = replaceSelectFieldName(sqlNode, mappingTable, tableAlias);
487+
SqlNode replaceNode = replaceSelectFieldName(sqlNode, mappingTable, tableAlias, fieldNames);
423488
if(replaceNode == null){
424489
continue;
425490
}
@@ -437,21 +502,21 @@ private SqlNode replaceSelectFieldName(SqlNode selectNode, HashBasedTable<String
437502

438503
for(int i=0; i<whenOperands.size(); i++){
439504
SqlNode oneOperand = whenOperands.get(i);
440-
SqlNode replaceNode = replaceSelectFieldName(oneOperand, mappingTable, tableAlias);
441-
if(replaceNode != null){
505+
SqlNode replaceNode = replaceSelectFieldName(oneOperand, mappingTable, tableAlias, fieldNames);
506+
if (replaceNode != null) {
442507
whenOperands.set(i, replaceNode);
443508
}
444509
}
445510

446511
for(int i=0; i<thenOperands.size(); i++){
447512
SqlNode oneOperand = thenOperands.get(i);
448-
SqlNode replaceNode = replaceSelectFieldName(oneOperand, mappingTable, tableAlias);
449-
if(replaceNode != null){
513+
SqlNode replaceNode = replaceSelectFieldName(oneOperand, mappingTable, tableAlias, fieldNames);
514+
if (replaceNode != null) {
450515
thenOperands.set(i, replaceNode);
451516
}
452517
}
453518

454-
((SqlCase) selectNode).setOperand(3, replaceSelectFieldName(elseNode, mappingTable, tableAlias));
519+
((SqlCase) selectNode).setOperand(3, replaceSelectFieldName(elseNode, mappingTable, tableAlias, fieldNames));
455520
return selectNode;
456521
}else if(selectNode.getKind() == OTHER){
457522
//不处理
@@ -463,17 +528,18 @@ private SqlNode replaceSelectFieldName(SqlNode selectNode, HashBasedTable<String
463528

464529
/**
465530
* Analyzing conditions are very join the dimension tables include all equivalent conditions (i.e., dimension table is the primary key definition
531+
*
466532
* @return
467533
*/
468-
private boolean checkJoinCondition(SqlNode conditionNode, String sideTableAlias, SideTableInfo sideTableInfo){
534+
private boolean checkJoinCondition(SqlNode conditionNode, String sideTableAlias, SideTableInfo sideTableInfo) {
469535
List<String> conditionFields = getConditionFields(conditionNode, sideTableAlias, sideTableInfo);
470536
if(CollectionUtils.isEqualCollection(conditionFields, convertPrimaryAlias(sideTableInfo))){
471537
return true;
472538
}
473539
return false;
474540
}
475541

476-
private List<String> convertPrimaryAlias(SideTableInfo sideTableInfo){
542+
private List<String> convertPrimaryAlias(SideTableInfo sideTableInfo) {
477543
List<String> res = Lists.newArrayList();
478544
sideTableInfo.getPrimaryKeys().forEach(field -> {
479545
res.add(sideTableInfo.getPhysicalFields().getOrDefault(field, field));
@@ -535,8 +601,11 @@ public void registerTmpTable(CreateTmpTableParser.SqlParserResult result,
535601

536602
if(preIsSideJoin){
537603
preIsSideJoin = false;
538-
for(FieldReplaceInfo replaceInfo : replaceInfoList){
539-
replaceFieldName(pollSqlNode, replaceInfo.getMappingTable(), replaceInfo.getTargetTableName(), replaceInfo.getTargetTableAlias());
604+
List<String> fieldNames = null;
605+
for (FieldReplaceInfo replaceInfo : replaceInfoList) {
606+
fieldNames = Lists.newArrayList();
607+
replaceFieldName(pollSqlNode, replaceInfo.getMappingTable(), replaceInfo.getTargetTableName(), replaceInfo.getTargetTableAlias(), fieldNames);
608+
dealMidConvertField(pollSqlNode, fieldNames);
540609
}
541610
}
542611

@@ -572,6 +641,7 @@ public void registerTmpTable(CreateTmpTableParser.SqlParserResult result,
572641
}
573642
}
574643
}
644+
575645
private void joinFun(Object pollObj, Map<String, Table> localTableCache,
576646
Map<String, SideTableInfo> sideTableMap, StreamTableEnvironment tableEnv,
577647
List<FieldReplaceInfo> replaceInfoList) throws Exception{
@@ -655,12 +725,11 @@ private void joinFun(Object pollObj, Map<String, Table> localTableCache,
655725
}
656726
}
657727

658-
private boolean checkFieldsInfo(CreateTmpTableParser.SqlParserResult result, Table table){
728+
private boolean checkFieldsInfo(CreateTmpTableParser.SqlParserResult result, Table table) {
659729
List<String> fieldNames = new LinkedList<>();
660730
String fieldsInfo = result.getFieldsInfoStr();
661731
String[] fields = fieldsInfo.split(",");
662-
for (int i=0; i < fields.length; i++)
663-
{
732+
for (int i = 0; i < fields.length; i++) {
664733
String[] filed = fields[i].split("\\s");
665734
if (filed.length < 2 || fields.length != table.getSchema().getColumnNames().length){
666735
return false;

0 commit comments

Comments
 (0)