@@ -67,13 +67,15 @@ protected SemiJoinRule(Config config) {
6767 super (config );
6868 }
6969
70- protected void perform (RelOptRuleCall call , @ Nullable Project project ,
70+ protected void perform (RelOptRuleCall call , @ Nullable RelNode topRel ,
7171 Join join , RelNode left , Aggregate aggregate ) {
7272 final RelOptCluster cluster = join .getCluster ();
7373 final RexBuilder rexBuilder = cluster .getRexBuilder ();
74- if (project != null ) {
75- final ImmutableBitSet bits =
76- RelOptUtil .InputFinder .bits (project .getProjects (), null );
74+ if (topRel != null ) {
75+ final ImmutableBitSet bits = getUsedFields (topRel );
76+ if (bits .isEmpty ()) {
77+ return ;
78+ }
7779 final ImmutableBitSet rightBits =
7880 ImmutableBitSet .range (left .getRowType ().getFieldCount (),
7981 join .getRowType ().getFieldCount ());
@@ -123,13 +125,72 @@ protected void perform(RelOptRuleCall call, @Nullable Project project,
123125 default :
124126 throw new AssertionError (join .getJoinType ());
125127 }
126- if (project != null ) {
127- relBuilder .project (project .getProjects (), project .getRowType ().getFieldNames ());
128+ if (topRel != null ) {
129+ if (topRel instanceof Project ) {
130+ Project topProject = (Project ) topRel ;
131+ relBuilder .project (topProject .getProjects (), topProject .getRowType ().getFieldNames ());
132+ } else if (topRel instanceof Aggregate ) {
133+ Aggregate topAgg = (Aggregate ) topRel ;
134+ relBuilder .aggregate (
135+ relBuilder .groupKey (topAgg .getGroupSet (), topAgg .getGroupSets ()),
136+ topAgg .getAggCallList ());
137+ }
128138 }
129139 final RelNode relNode = relBuilder .build ();
130140 call .transformTo (relNode );
131141 }
132142
143+ /** Returns a bit set of the input fields used by a relational expression. */
144+ private static ImmutableBitSet getUsedFields (RelNode rel ) {
145+ final RelMetadataQuery mq = rel .getCluster ().getMetadataQuery ();
146+ return ImmutableBitSet .union (mq .getInputFieldsUsed (rel ));
147+ }
148+
149+ /** SemiJoinRule that matches a Aggregate on top of a Join with an Aggregate
150+ * as its right child.
151+ *
152+ * @see CoreRules#AGGREGATE_TO_SEMI_JOIN */
153+ public static class AggregateToSemiJoinRule extends SemiJoinRule {
154+ /** Creates a AggregateToSemiJoinRule. */
155+ protected AggregateToSemiJoinRule (AggregateToSemiJoinRuleConfig config ) {
156+ super (config );
157+ }
158+
159+ @ Override public void onMatch (RelOptRuleCall call ) {
160+ final Aggregate topAgg = call .rel (0 );
161+ final Join join = call .rel (1 );
162+ final RelNode left = call .rel (2 );
163+ final Aggregate rightAgg = call .rel (3 );
164+ perform (call , topAgg , join , left , rightAgg );
165+ }
166+
167+ /** Rule configuration. */
168+ @ Value .Immutable
169+ public interface AggregateToSemiJoinRuleConfig extends SemiJoinRule .Config {
170+ AggregateToSemiJoinRuleConfig DEFAULT = ImmutableAggregateToSemiJoinRuleConfig .of ()
171+ .withDescription ("SemiJoinRule:aggregate" )
172+ .withOperandFor (Aggregate .class , Join .class , Aggregate .class );
173+
174+ @ Override default AggregateToSemiJoinRule toRule () {
175+ return new AggregateToSemiJoinRule (this );
176+ }
177+
178+ /** Defines an operand tree for the given classes. */
179+ default AggregateToSemiJoinRuleConfig withOperandFor (
180+ Class <? extends Aggregate > topAggClass ,
181+ Class <? extends Join > joinClass ,
182+ Class <? extends Aggregate > rightAggClass ) {
183+ return withOperandSupplier (b ->
184+ b .operand (topAggClass ).oneInput (b2 ->
185+ b2 .operand (joinClass )
186+ .predicate (SemiJoinRule ::isJoinTypeSupported ).inputs (
187+ b3 -> b3 .operand (RelNode .class ).anyInputs (),
188+ b4 -> b4 .operand (rightAggClass ).anyInputs ())))
189+ .as (AggregateToSemiJoinRuleConfig .class );
190+ }
191+ }
192+ }
193+
133194 /** SemiJoinRule that matches a Project on top of a Join with an Aggregate
134195 * as its right child.
135196 *
@@ -251,8 +312,7 @@ protected JoinOnUniqueToSemiJoinRule(JoinOnUniqueToSemiJoinRuleConfig config) {
251312 final Join join = call .rel (1 );
252313 final RelNode left = call .rel (2 );
253314
254- final ImmutableBitSet bits =
255- RelOptUtil .InputFinder .bits (project .getProjects (), null );
315+ final ImmutableBitSet bits = getUsedFields (project );
256316 final ImmutableBitSet rightBits =
257317 ImmutableBitSet .range (left .getRowType ().getFieldCount (),
258318 join .getRowType ().getFieldCount ());
0 commit comments