Skip to content

Commit 2d64111

Browse files
authored
Add support for join filter pullups (#6272)
This commit adds functionality to the optimizer to pull up simple join predicates. Currently this will only pullup simple predicate expressions where keys are compared against constant values.
1 parent 3b76fa6 commit 2d64111

File tree

2 files changed

+150
-0
lines changed

2 files changed

+150
-0
lines changed

compiler/optimizer/optimizer.go

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ func (o *Optimizer) Optimize(main *dag.Main) error {
144144
seq = mergeFilters(seq)
145145
seq = mergeValuesOps(seq)
146146
inlineRecordExprSpreads(seq)
147+
seq = joinFilterPullup(seq)
147148
seq = removePassOps(seq)
148149
seq = replaceSortAndHeadOrTailWithTop(seq)
149150
o.optimizeParallels(seq)
@@ -506,6 +507,103 @@ func inlineRecordExprSpreads(v any) {
506507
})
507508
}
508509

510+
func joinFilterPullup(seq dag.Seq) dag.Seq {
511+
seq = mergeFilters(seq)
512+
for i := 0; i <= len(seq)-3; i++ {
513+
fork, isfork := seq[i].(*dag.Fork)
514+
leftAlias, rightAlias, isjoin := isJoin(seq[i+1])
515+
filter, isfilter := seq[i+2].(*dag.Filter)
516+
if !isfork || !isjoin || !isfilter {
517+
continue
518+
}
519+
if len(fork.Paths) != 2 {
520+
panic(seq[i])
521+
}
522+
var remaining []dag.Expr
523+
for _, e := range splitPredicate(filter.Expr) {
524+
if pullup, ok := pullupExpr(leftAlias, e); ok {
525+
fork.Paths[0] = append(fork.Paths[0], dag.NewFilter(pullup))
526+
continue
527+
}
528+
if pullup, ok := pullupExpr(rightAlias, e); ok {
529+
fork.Paths[1] = append(fork.Paths[1], dag.NewFilter(pullup))
530+
continue
531+
}
532+
remaining = append(remaining, e)
533+
}
534+
if len(remaining) == 0 {
535+
// Filter has been fully pulled up and can be removed.
536+
seq.Delete(i+2, i+3)
537+
} else {
538+
out := remaining[0]
539+
for _, e := range remaining[1:] {
540+
out = dag.NewBinaryExpr("and", e, out)
541+
}
542+
seq[i+2] = dag.NewFilter(out)
543+
}
544+
fork.Paths[0] = joinFilterPullup(fork.Paths[0])
545+
fork.Paths[1] = joinFilterPullup(fork.Paths[1])
546+
}
547+
return seq
548+
}
549+
550+
func isJoin(op dag.Op) (string, string, bool) {
551+
switch op := op.(type) {
552+
case *dag.HashJoin:
553+
return op.LeftAlias, op.RightAlias, true
554+
case *dag.Join:
555+
return op.LeftAlias, op.RightAlias, true
556+
default:
557+
return "", "", false
558+
}
559+
}
560+
561+
func splitPredicate(e dag.Expr) []dag.Expr {
562+
if b, ok := e.(*dag.BinaryExpr); ok && b.Op == "and" {
563+
return append(splitPredicate(b.LHS), splitPredicate(b.RHS)...)
564+
}
565+
return []dag.Expr{e}
566+
}
567+
568+
func pullupExpr(alias string, expr dag.Expr) (dag.Expr, bool) {
569+
e, ok := expr.(*dag.BinaryExpr)
570+
if !ok {
571+
return nil, false
572+
}
573+
if e.Op == "and" {
574+
lhs, lok := pullupExpr(alias, e.LHS)
575+
rhs, rok := pullupExpr(alias, e.RHS)
576+
if !lok || !rok {
577+
return nil, false
578+
}
579+
return dag.NewBinaryExpr("and", lhs, rhs), true
580+
}
581+
if e.Op == "or" {
582+
lhs, lok := pullupExpr(alias, e.LHS)
583+
rhs, rok := pullupExpr(alias, e.RHS)
584+
if !lok || !rok {
585+
return nil, false
586+
}
587+
return dag.NewBinaryExpr("or", lhs, rhs), true
588+
589+
}
590+
var literal *dag.Literal
591+
var this *dag.This
592+
for _, e := range []dag.Expr{e.RHS, e.LHS} {
593+
if l, ok := e.(*dag.Literal); ok && literal == nil {
594+
literal = l
595+
continue
596+
}
597+
if t, ok := e.(*dag.This); ok && this == nil && len(t.Path) > 1 && t.Path[0] == alias {
598+
this = t
599+
continue
600+
}
601+
return nil, false
602+
}
603+
path := slices.Clone(this.Path[1:])
604+
return dag.NewBinaryExpr(e.Op, dag.NewThis(path), literal), true
605+
}
606+
509607
func liftFilterOps(seq dag.Seq) dag.Seq {
510608
walkT(reflect.ValueOf(&seq), func(seq dag.Seq) dag.Seq {
511609
for i := len(seq) - 2; i >= 0; i-- {
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
script: |
2+
super compile -C -O "
3+
select *
4+
from (values (1,'badgers')) as teams(id,team)
5+
inner join (values (1,1,'blake','P')) as players(id,team_id,player,position) on teams.id = players.team_id
6+
where players.position = 'P' and 'badgers' = team
7+
"
8+
echo // ===
9+
super compile -C -O "
10+
select *
11+
from (values (1)) as t1(a1), (values (1)) as t2(a2), (values (1)) as t3(a3)
12+
where a1 == 1 and a3 == 1 and (1 == a2 or a2 == 2)
13+
"
14+
15+
outputs:
16+
- name: stdout
17+
data: |
18+
null
19+
| fork
20+
(
21+
values {id:1,team:"badgers"}
22+
| where team=="badgers"
23+
)
24+
(
25+
values {id:1,team_id:1,player:"blake",position:"P"}
26+
| where position=="P"
27+
)
28+
| inner hashjoin as {left,right} on id==team_id
29+
| values {team:left.team,id:right.id,team_id:right.team_id,player:right.player,position:right.position}
30+
| output main
31+
// ===
32+
null
33+
| fork
34+
(
35+
values {a1:1}
36+
| where a1==1
37+
)
38+
(
39+
fork
40+
(
41+
values {a2:1}
42+
| where a2==1 or a2==2
43+
)
44+
(
45+
values {a3:1}
46+
| where a3==1
47+
)
48+
| cross join as {left,right}
49+
)
50+
| cross join as {left,right}
51+
| values {a1:left.a1,a2:right.left.a2,a3:right.right.a3}
52+
| output main

0 commit comments

Comments
 (0)