Skip to content

Commit d233128

Browse files
committed
Add support for join filter pullups
This commit adds functionality to the optimizer to pull up simple join predicates. Currently this will only pullup simple predicate expressions where keys are compared against constant values.
1 parent 3b76fa6 commit d233128

File tree

2 files changed

+150
-0
lines changed

2 files changed

+150
-0
lines changed

compiler/optimizer/optimizer.go

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ func (o *Optimizer) Optimize(main *dag.Main) error {
144144
seq = mergeFilters(seq)
145145
seq = mergeValuesOps(seq)
146146
inlineRecordExprSpreads(seq)
147+
seq = joinFilterPullup(seq)
147148
seq = removePassOps(seq)
148149
seq = replaceSortAndHeadOrTailWithTop(seq)
149150
o.optimizeParallels(seq)
@@ -506,6 +507,103 @@ func inlineRecordExprSpreads(v any) {
506507
})
507508
}
508509

510+
func joinFilterPullup(seq dag.Seq) dag.Seq {
511+
seq = mergeFilters(seq)
512+
for i := 0; i <= len(seq)-3; i++ {
513+
fork, isfork := seq[i].(*dag.Fork)
514+
leftAlias, rightAlias, isjoin := isJoin(seq[i+1])
515+
filter, isfilter := seq[i+2].(*dag.Filter)
516+
if !isfork || !isjoin || !isfilter {
517+
continue
518+
}
519+
if len(fork.Paths) != 2 {
520+
panic(seq[i])
521+
}
522+
var remaining []dag.Expr
523+
for _, e := range splitPredicate(filter.Expr) {
524+
if pullup, ok := pullupExpr(leftAlias, e); ok {
525+
fork.Paths[0] = append(fork.Paths[0], dag.NewFilter(pullup))
526+
continue
527+
}
528+
if pullup, ok := pullupExpr(rightAlias, e); ok {
529+
fork.Paths[1] = append(fork.Paths[1], dag.NewFilter(pullup))
530+
continue
531+
}
532+
remaining = append(remaining, e)
533+
}
534+
if len(remaining) == 0 {
535+
// Filter has been fully pulled up and can be removed.
536+
seq.Delete(i+2, i+3)
537+
} else {
538+
out := remaining[0]
539+
for _, e := range remaining[1:] {
540+
out = dag.NewBinaryExpr("and", e, out)
541+
}
542+
seq[i+2] = dag.NewFilter(out)
543+
}
544+
fork.Paths[0] = joinFilterPullup(fork.Paths[0])
545+
fork.Paths[1] = joinFilterPullup(fork.Paths[1])
546+
}
547+
return seq
548+
}
549+
550+
func isJoin(op dag.Op) (string, string, bool) {
551+
switch op := op.(type) {
552+
case *dag.HashJoin:
553+
return op.LeftAlias, op.RightAlias, true
554+
case *dag.Join:
555+
return op.LeftAlias, op.RightAlias, true
556+
default:
557+
return "", "", false
558+
}
559+
}
560+
561+
func splitPredicate(e dag.Expr) []dag.Expr {
562+
if b, ok := e.(*dag.BinaryExpr); ok && b.Op == "and" {
563+
return append(splitPredicate(b.LHS), splitPredicate(b.RHS)...)
564+
}
565+
return []dag.Expr{e}
566+
}
567+
568+
func pullupExpr(alias string, expr dag.Expr) (dag.Expr, bool) {
569+
e, ok := expr.(*dag.BinaryExpr)
570+
if !ok {
571+
return nil, false
572+
}
573+
if e.Op == "and" {
574+
lhs, lok := pullupExpr(alias, e.LHS)
575+
rhs, rok := pullupExpr(alias, e.RHS)
576+
if !lok || !rok {
577+
return nil, false
578+
}
579+
return dag.NewBinaryExpr("and", lhs, rhs), true
580+
}
581+
if e.Op == "or" {
582+
lhs, lok := pullupExpr(alias, e.LHS)
583+
rhs, rok := pullupExpr(alias, e.RHS)
584+
if !lok || !rok {
585+
return nil, false
586+
}
587+
return dag.NewBinaryExpr("or", lhs, rhs), true
588+
589+
}
590+
var literal *dag.Literal
591+
var this *dag.This
592+
for _, e := range []dag.Expr{e.RHS, e.LHS} {
593+
if l, ok := e.(*dag.Literal); ok && literal == nil {
594+
literal = l
595+
continue
596+
}
597+
if t, ok := e.(*dag.This); ok && this == nil && len(t.Path) > 1 && t.Path[0] == alias {
598+
this = t
599+
continue
600+
}
601+
return nil, false
602+
}
603+
path := slices.Clone(this.Path[1:])
604+
return dag.NewBinaryExpr(e.Op, dag.NewThis(path), literal), true
605+
}
606+
509607
func liftFilterOps(seq dag.Seq) dag.Seq {
510608
walkT(reflect.ValueOf(&seq), func(seq dag.Seq) dag.Seq {
511609
for i := len(seq) - 2; i >= 0; i-- {
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
script: |
2+
super compile -C -O "
3+
select *
4+
from (values (1,'badgers')) as teams(id,team)
5+
inner join (values (1,1,'blake','P')) as players(id,team_id,player,position) on teams.id = players.team_id
6+
where players.position = 'P' and 'badgers' = team
7+
"
8+
echo // ===
9+
super compile -C -O "
10+
select *
11+
from (values (1)) as t1(a1), (values (1)) as t2(a2), (values (1)) as t3(a3)
12+
where a1 == 1 and a3 == 1 and (1 == a2 or a2 == 2)
13+
"
14+
15+
outputs:
16+
- name: stdout
17+
data: |
18+
null
19+
| fork
20+
(
21+
values {id:1,team:"badgers"}
22+
| where team=="badgers"
23+
)
24+
(
25+
values {id:1,team_id:1,player:"blake",position:"P"}
26+
| where position=="P"
27+
)
28+
| inner hashjoin as {left,right} on id==team_id
29+
| values {team:left.team,id:right.id,team_id:right.team_id,player:right.player,position:right.position}
30+
| output main
31+
// ===
32+
null
33+
| fork
34+
(
35+
values {a1:1}
36+
| where a1==1
37+
)
38+
(
39+
fork
40+
(
41+
values {a2:1}
42+
| where a2==1 or a2==2
43+
)
44+
(
45+
values {a3:1}
46+
| where a3==1
47+
)
48+
| cross join as {left,right}
49+
)
50+
| cross join as {left,right}
51+
| values {a1:left.a1,a2:right.left.a2,a3:right.right.a3}
52+
| output main

0 commit comments

Comments
 (0)