Skip to content

Commit 1c4493a

Browse files
committed
[CALCITE-5740] Support for AggToSemiJoinRule
1 parent d0c72d1 commit 1c4493a

File tree

4 files changed

+110
-8
lines changed

4 files changed

+110
-8
lines changed

core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,12 @@ private CoreRules() {}
162162
public static final AggregateJoinTransposeRule AGGREGATE_JOIN_TRANSPOSE_EXTENDED =
163163
AggregateJoinTransposeRule.Config.EXTENDED.toRule();
164164

165+
/** Rule that creates a {@link Join#isSemiJoin semi-join} from a
166+
* {@link Aggregate} on top of a {@link Join} with an {@link Aggregate} as its
167+
* right input. */
168+
public static final SemiJoinRule.AggregateToSemiJoinRule AGGREGATE_TO_SEMI_JOIN =
169+
SemiJoinRule.AggregateToSemiJoinRule.AggregateToSemiJoinRuleConfig.DEFAULT.toRule();
170+
165171
/** Rule that pushes an {@link Aggregate}
166172
* past a non-distinct {@link Union}. */
167173
public static final AggregateUnionTransposeRule AGGREGATE_UNION_TRANSPOSE =

core/src/main/java/org/apache/calcite/rel/rules/SemiJoinRule.java

Lines changed: 68 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,15 @@ protected SemiJoinRule(Config config) {
6767
super(config);
6868
}
6969

70-
protected void perform(RelOptRuleCall call, @Nullable Project project,
70+
protected void perform(RelOptRuleCall call, @Nullable RelNode topRel,
7171
Join join, RelNode left, Aggregate aggregate) {
7272
final RelOptCluster cluster = join.getCluster();
7373
final RexBuilder rexBuilder = cluster.getRexBuilder();
74-
if (project != null) {
75-
final ImmutableBitSet bits =
76-
RelOptUtil.InputFinder.bits(project.getProjects(), null);
74+
if (topRel != null) {
75+
final ImmutableBitSet bits = getUsedFields(topRel);
76+
if (bits.isEmpty()) {
77+
return;
78+
}
7779
final ImmutableBitSet rightBits =
7880
ImmutableBitSet.range(left.getRowType().getFieldCount(),
7981
join.getRowType().getFieldCount());
@@ -123,13 +125,72 @@ protected void perform(RelOptRuleCall call, @Nullable Project project,
123125
default:
124126
throw new AssertionError(join.getJoinType());
125127
}
126-
if (project != null) {
127-
relBuilder.project(project.getProjects(), project.getRowType().getFieldNames());
128+
if (topRel != null) {
129+
if (topRel instanceof Project) {
130+
Project topProject = (Project) topRel;
131+
relBuilder.project(topProject.getProjects(), topProject.getRowType().getFieldNames());
132+
} else if (topRel instanceof Aggregate) {
133+
Aggregate topAgg = (Aggregate) topRel;
134+
relBuilder.aggregate(
135+
relBuilder.groupKey(topAgg.getGroupSet(), topAgg.getGroupSets()),
136+
topAgg.getAggCallList());
137+
}
128138
}
129139
final RelNode relNode = relBuilder.build();
130140
call.transformTo(relNode);
131141
}
132142

143+
/** Returns a bit set of the input fields used by a relational expression. */
144+
private static ImmutableBitSet getUsedFields(RelNode rel) {
145+
final RelMetadataQuery mq = rel.getCluster().getMetadataQuery();
146+
return ImmutableBitSet.union(mq.getInputFieldsUsed(rel));
147+
}
148+
149+
/** SemiJoinRule that matches a Aggregate on top of a Join with an Aggregate
150+
* as its right child.
151+
*
152+
* @see CoreRules#AGGREGATE_TO_SEMI_JOIN */
153+
public static class AggregateToSemiJoinRule extends SemiJoinRule {
154+
/** Creates a AggregateToSemiJoinRule. */
155+
protected AggregateToSemiJoinRule(AggregateToSemiJoinRuleConfig config) {
156+
super(config);
157+
}
158+
159+
@Override public void onMatch(RelOptRuleCall call) {
160+
final Aggregate topAgg = call.rel(0);
161+
final Join join = call.rel(1);
162+
final RelNode left = call.rel(2);
163+
final Aggregate rightAgg = call.rel(3);
164+
perform(call, topAgg, join, left, rightAgg);
165+
}
166+
167+
/** Rule configuration. */
168+
@Value.Immutable
169+
public interface AggregateToSemiJoinRuleConfig extends SemiJoinRule.Config {
170+
AggregateToSemiJoinRuleConfig DEFAULT = ImmutableAggregateToSemiJoinRuleConfig.of()
171+
.withDescription("SemiJoinRule:aggregate")
172+
.withOperandFor(Aggregate.class, Join.class, Aggregate.class);
173+
174+
@Override default AggregateToSemiJoinRule toRule() {
175+
return new AggregateToSemiJoinRule(this);
176+
}
177+
178+
/** Defines an operand tree for the given classes. */
179+
default AggregateToSemiJoinRuleConfig withOperandFor(
180+
Class<? extends Aggregate> topAggClass,
181+
Class<? extends Join> joinClass,
182+
Class<? extends Aggregate> rightAggClass) {
183+
return withOperandSupplier(b ->
184+
b.operand(topAggClass).oneInput(b2 ->
185+
b2.operand(joinClass)
186+
.predicate(SemiJoinRule::isJoinTypeSupported).inputs(
187+
b3 -> b3.operand(RelNode.class).anyInputs(),
188+
b4 -> b4.operand(rightAggClass).anyInputs())))
189+
.as(AggregateToSemiJoinRuleConfig.class);
190+
}
191+
}
192+
}
193+
133194
/** SemiJoinRule that matches a Project on top of a Join with an Aggregate
134195
* as its right child.
135196
*
@@ -251,8 +312,7 @@ protected JoinOnUniqueToSemiJoinRule(JoinOnUniqueToSemiJoinRuleConfig config) {
251312
final Join join = call.rel(1);
252313
final RelNode left = call.rel(2);
253314

254-
final ImmutableBitSet bits =
255-
RelOptUtil.InputFinder.bits(project.getProjects(), null);
315+
final ImmutableBitSet bits = getUsedFields(project);
256316
final ImmutableBitSet rightBits =
257317
ImmutableBitSet.range(left.getRowType().getFieldCount(),
258318
join.getRowType().getFieldCount());

core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2098,6 +2098,19 @@ private void checkSemiOrAntiJoinProjectTranspose(JoinRelType type) {
20982098
.check();
20992099
}
21002100

2101+
/** Test case for
2102+
* <a href="https://issues.apache.org/jira/browse/CALCITE-5740">[CALCITE-5740]
2103+
* Support for AggToSemiJoinRule </a>. */
2104+
@Test void testAggregateToSemiJoinRule() {
2105+
final String sql = "select distinct emp.deptno from emp\n"
2106+
+ "join (select distinct mgr from emp) d on emp.deptno = d.mgr";
2107+
sql(sql)
2108+
.withDecorrelate(true)
2109+
.withPreRule(CoreRules.AGGREGATE_PROJECT_MERGE)
2110+
.withRule(CoreRules.AGGREGATE_TO_SEMI_JOIN)
2111+
.check();
2112+
}
2113+
21012114
/** Test case for
21022115
* <a href="https://issues.apache.org/jira/browse/CALCITE-1495">[CALCITE-1495]
21032116
* SemiJoinRule should not apply to RIGHT and FULL JOIN</a>. */

core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1130,6 +1130,29 @@ LogicalProject(MGR=[$0], SUM_SAL=[$2])
11301130
LogicalAggregate(group=[{0, 1}], SUM_SAL=[SUM($2)])
11311131
LogicalProject(MGR=[$3], DEPTNO=[$7], SAL=[$5])
11321132
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
1133+
]]>
1134+
</Resource>
1135+
</TestCase>
1136+
<TestCase name="testAggregateToSemiJoinRule">
1137+
<Resource name="sql">
1138+
<![CDATA[select distinct emp.deptno from emp
1139+
join (select distinct mgr from emp) d on emp.deptno = d.mgr]]>
1140+
</Resource>
1141+
<Resource name="planBefore">
1142+
<![CDATA[
1143+
LogicalAggregate(group=[{7}])
1144+
LogicalJoin(condition=[=($7, $9)], joinType=[inner])
1145+
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
1146+
LogicalAggregate(group=[{3}])
1147+
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
1148+
]]>
1149+
</Resource>
1150+
<Resource name="planAfter">
1151+
<![CDATA[
1152+
LogicalAggregate(group=[{7}])
1153+
LogicalJoin(condition=[=($7, $12)], joinType=[semi])
1154+
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
1155+
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
11331156
]]>
11341157
</Resource>
11351158
</TestCase>

0 commit comments

Comments
 (0)