Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -858,12 +858,21 @@ private List<SqlNode> generateGroupList(Builder builder,
+ aggregate.getGroupSet() + ", just possibly a different order";

final List<SqlNode> groupKeys = new ArrayList<>();
final Join aggregateJoinInput =
aggregate.getInput() instanceof Join ? (Join) aggregate.getInput() : null;
final SqlJoin fromJoin =
builder.select.getFrom() instanceof SqlJoin ? (SqlJoin) builder.select.getFrom() : null;
final int leftFieldCount = aggregateJoinInput == null
? -1
: aggregateJoinInput.getLeft().getRowType().getFieldCount();
for (int key : groupList) {
final SqlNode field = builder.context.field(key);
SqlNode field = builder.context.field(key);
field = maybeQualifyJoinKey(field, key, fromJoin, leftFieldCount);
groupKeys.add(field);
}
for (int key : sortedGroupList) {
final SqlNode field = builder.context.field(key);
SqlNode field =
maybeQualifyJoinKey(builder.context.field(key), key, fromJoin, leftFieldCount);
addSelect(selectList, field, aggregate.getRowType());
}
switch (aggregate.getGroupType()) {
Expand All @@ -880,7 +889,8 @@ private List<SqlNode> generateGroupList(Builder builder,
final List<Integer> rollupBits = Aggregate.Group.getRollup(aggregate.groupSets);
final List<SqlNode> rollupKeys = rollupBits
.stream()
.map(bit -> builder.context.field(bit))
.map(bit ->
maybeQualifyJoinKey(builder.context.field(bit), bit, fromJoin, leftFieldCount))
.collect(Collectors.toList());
return ImmutableList.of(
SqlStdOperatorTable.ROLLUP.createCall(SqlParserPos.ZERO, rollupKeys));
Expand All @@ -905,6 +915,63 @@ private List<SqlNode> generateGroupList(Builder builder,
}
}

private SqlNode maybeQualifyJoinKey(SqlNode field, int key,
@Nullable SqlJoin fromJoin, int leftFieldCount) {
if (!isSimpleIdentifier(field) || fromJoin == null) {
return field;
}

final String fieldName = ((SqlIdentifier) field).getSimple();
if (leftFieldCount >= 0) {
final SqlNode side = key < leftFieldCount ? fromJoin.getLeft() : fromJoin.getRight();
return qualifyJoinField(SqlValidatorUtil.alias(side), fieldName, field);
}
return maybeQualifyJoinKeyWithoutInputJoin(field, key, fromJoin, fieldName);
}

private static boolean isSimpleIdentifier(SqlNode node) {
return node instanceof SqlIdentifier
&& ((SqlIdentifier) node).names.size() == 1;
}

private SqlNode maybeQualifyJoinKeyWithoutInputJoin(SqlNode field, int key,
SqlJoin fromJoin, String fieldName) {
if (key != 0) {
return field;
}
final String leftAlias = SqlValidatorUtil.alias(fromJoin.getLeft());
final String rightAlias = SqlValidatorUtil.alias(fromJoin.getRight());
switch (fromJoin.getJoinType()) {
case RIGHT:
return qualifyJoinField(rightAlias, fieldName, field);
case FULL:
if (leftAlias != null && rightAlias != null) {
return SqlStdOperatorTable.COALESCE.createCall(POS,
new SqlIdentifier(ImmutableList.of(leftAlias, fieldName), POS),
new SqlIdentifier(ImmutableList.of(rightAlias, fieldName), POS));
}
return qualifyJoinField(leftAlias != null ? leftAlias : rightAlias, fieldName, field);
case LEFT:
case LEFT_SEMI_JOIN:
case LEFT_ANTI_JOIN:
case INNER:
case CROSS:
case COMMA:
case ASOF:
case LEFT_ASOF:
default:
return qualifyJoinField(leftAlias, fieldName, field);
}
}

private static SqlNode qualifyJoinField(@Nullable String alias,
String fieldName, SqlNode fallback) {
if (alias == null) {
return fallback;
}
return new SqlIdentifier(ImmutableList.of(alias, fieldName), POS);
}

private static SqlNode groupItem(List<SqlNode> groupKeys,
ImmutableBitSet groupSet, ImmutableBitSet wholeGroupSet) {
final List<SqlNode> nodes = groupSet.asList().stream()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,12 @@
import static org.apache.calcite.test.Matchers.isLinux;

import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.CoreMatchers.notNullValue;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.hasToString;
import static org.hamcrest.Matchers.matchesPattern;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;

Expand Down Expand Up @@ -11887,6 +11890,139 @@ public Sql schema(CalciteAssert.SchemaSpec schemaSpec) {
sql(sql).schema(CalciteAssert.SchemaSpec.JDBC_SCOTT).ok(expected);
}

/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-7439">[CALCITE-7439]
* RelToSqlConverter emits ambiguous GROUP BY after LEFT JOIN USING with
* semi-join rewrite.</a>. */
@Test void testPostgresqlRoundTripDistinctLeftJoinInSubqueryWithSemiJoinRules() {
final String query = "WITH product_keys AS (\n"
+ " SELECT p.\"product_id\",\n"
+ " (SELECT MAX(p3.\"product_id\")\n"
+ " FROM \"foodmart\".\"product\" p3\n"
+ " WHERE p3.\"product_id\" = p.\"product_id\") AS \"mx\"\n"
+ " FROM \"foodmart\".\"product\" p\n"
+ ")\n"
+ "SELECT DISTINCT pk.\"product_id\"\n"
+ "FROM product_keys pk\n"
+ "LEFT JOIN \"foodmart\".\"product\" p2 USING (\"product_id\")\n"
+ "WHERE pk.\"product_id\" IN (\n"
+ " SELECT p4.\"product_id\"\n"
+ " FROM \"foodmart\".\"product\" p4\n"
+ ")";

final RuleSet rules =
RuleSets.ofList(CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE,
CoreRules.FILTER_SUB_QUERY_TO_CORRELATE,
CoreRules.JOIN_SUB_QUERY_TO_CORRELATE,
CoreRules.PROJECT_SUB_QUERY_TO_MARK_CORRELATE,
CoreRules.FILTER_SUB_QUERY_TO_MARK_CORRELATE,
CoreRules.MARK_TO_SEMI_OR_ANTI_JOIN_RULE,
CoreRules.PROJECT_TO_SEMI_JOIN);

final String generated = sql(query).withPostgresql().optimize(rules, null).exec();
assertThat(generated, containsString("GROUP BY \"t2\".\"product_id\""));
}

/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-7439">[CALCITE-7439]
* RelToSqlConverter should not emit ambiguous GROUP BY after RIGHT JOIN USING
* with semi-join rewrite.</a>. */
@Test void testPostgresqlRoundTripDistinctRightJoinInSubqueryWithSemiJoinRules() {
final String query = "WITH product_keys AS (\n"
+ " SELECT p.\"product_id\",\n"
+ " (SELECT MAX(p3.\"product_id\")\n"
+ " FROM \"foodmart\".\"product\" p3\n"
+ " WHERE p3.\"product_id\" = p.\"product_id\") AS \"mx\"\n"
+ " FROM \"foodmart\".\"product\" p\n"
+ ")\n"
+ "SELECT DISTINCT pk.\"product_id\"\n"
+ "FROM product_keys pk\n"
+ "RIGHT JOIN \"foodmart\".\"product\" p2 USING (\"product_id\")\n"
+ "WHERE pk.\"product_id\" IN (\n"
+ " SELECT p4.\"product_id\"\n"
+ " FROM \"foodmart\".\"product\" p4\n"
+ ")";

final RuleSet rules =
RuleSets.ofList(CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE,
CoreRules.FILTER_SUB_QUERY_TO_CORRELATE,
CoreRules.JOIN_SUB_QUERY_TO_CORRELATE,
CoreRules.PROJECT_SUB_QUERY_TO_MARK_CORRELATE,
CoreRules.FILTER_SUB_QUERY_TO_MARK_CORRELATE,
CoreRules.MARK_TO_SEMI_OR_ANTI_JOIN_RULE,
CoreRules.PROJECT_TO_SEMI_JOIN);

final String generated = sql(query).withPostgresql().optimize(rules, null).exec();
assertThat(generated, containsString("GROUP BY "));
assertThat(generated, containsString(".\"product_id\""));
assertThat(generated,
matchesPattern("(?s).*GROUP BY\\s+\"[^\"]+\"\\.\"product_id\".*"));
assertThat(generated, not(containsString("GROUP BY \"product_id\"")));
}

/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-7439">[CALCITE-7439]
* RelToSqlConverter should not emit ambiguous GROUP BY after FULL JOIN USING
* with semi-join rewrite.</a>. */
@Test void testPostgresqlRoundTripDistinctFullJoinInSubqueryWithSemiJoinRules() {
final String query = "WITH product_keys AS (\n"
+ " SELECT p.\"product_id\",\n"
+ " (SELECT MAX(p3.\"product_id\")\n"
+ " FROM \"foodmart\".\"product\" p3\n"
+ " WHERE p3.\"product_id\" = p.\"product_id\") AS \"mx\"\n"
+ " FROM \"foodmart\".\"product\" p\n"
+ ")\n"
+ "SELECT DISTINCT pk.\"product_id\"\n"
+ "FROM product_keys pk\n"
+ "FULL JOIN \"foodmart\".\"product\" p2 USING (\"product_id\")\n"
+ "WHERE pk.\"product_id\" IN (\n"
+ " SELECT p4.\"product_id\"\n"
+ " FROM \"foodmart\".\"product\" p4\n"
+ ")";

final RuleSet rules =
RuleSets.ofList(CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE,
CoreRules.FILTER_SUB_QUERY_TO_CORRELATE,
CoreRules.JOIN_SUB_QUERY_TO_CORRELATE,
CoreRules.PROJECT_SUB_QUERY_TO_MARK_CORRELATE,
CoreRules.FILTER_SUB_QUERY_TO_MARK_CORRELATE,
CoreRules.MARK_TO_SEMI_OR_ANTI_JOIN_RULE,
CoreRules.PROJECT_TO_SEMI_JOIN);

final String generated = sql(query).withPostgresql().optimize(rules, null).exec();
assertThat(generated, containsString("GROUP BY "));
assertThat(generated, containsString("GROUP BY COALESCE("));
assertThat(generated,
matchesPattern("(?s).*GROUP BY\\s+COALESCE\\(\"[^\"]+\"\\.\"product_id\",\\s*"
+ "\"[^\"]+\"\\.\"product_id\"\\).*"));
assertThat(generated, not(containsString("GROUP BY \"product_id\"")));
}

@Test void testPostgresqlRoundTripRollupJoinUsingQualifiesGroupKey() {
final String query = "SELECT \"product_id\", COUNT(*)\n"
+ "FROM \"foodmart\".\"product\" p1\n"
+ "LEFT JOIN \"foodmart\".\"product\" p2 USING (\"product_id\")\n"
+ "GROUP BY ROLLUP(\"product_id\")";

final String generated = sql(query).withPostgresql().exec();
assertThat(generated,
matchesPattern("(?s).*GROUP BY\\s+ROLLUP\\(\"[^\"]+\"\\.\"product_id\"\\).*"));
assertThat(generated, not(containsString("GROUP BY ROLLUP(\"product_id\")")));
}

@Test void testPostgresqlRoundTripSingletonCubeJoinUsingQualifiesGroupKey() {
final String query = "SELECT \"product_id\", COUNT(*)\n"
+ "FROM \"foodmart\".\"product\" p1\n"
+ "LEFT JOIN \"foodmart\".\"product\" p2 USING (\"product_id\")\n"
+ "GROUP BY CUBE(\"product_id\")";

final String generated = sql(query).withPostgresql().exec();
assertThat(generated,
matchesPattern("(?s).*GROUP BY\\s+(?:CUBE|ROLLUP)\\(\"[^\"]+\"\\.\"product_id\"\\).*"));
assertThat(generated, not(containsString("GROUP BY CUBE(\"product_id\")")));
assertThat(generated, not(containsString("GROUP BY ROLLUP(\"product_id\")")));
}

@Test void testNotBetween() {
Sql f = fixture().withConvertletTable(new SqlRexConvertletTable() {
@Override public @Nullable SqlRexConvertlet get(SqlCall call) {
Expand Down
Loading