Skip to content

Commit 5352dc4

Browse files
committed
optimizer: Convert cross joins to hash joins
This commit adds functionality to the optimizer that rewrites cross joins followed by equality-comparison filters into hash joins where appropriate.
1 parent b166423 commit 5352dc4

File tree

3 files changed

+170
-10
lines changed

3 files changed

+170
-10
lines changed

compiler/optimizer/join.go

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,83 @@ func replaceJoinWithHashJoin(seq dag.Seq) {
2727
})
2828
}
2929

30+
func liftFilterConvertCrossJoin(seq dag.Seq) dag.Seq {
31+
var filter *dag.FilterOp
32+
var i int
33+
for i = range len(seq) - 3 {
34+
var isfilter bool
35+
_, isfork := seq[i].(*dag.ForkOp)
36+
_, _, _, isjoin := isJoin(seq[i+1])
37+
filter, isfilter = seq[i+2].(*dag.FilterOp)
38+
if isfork && isjoin && isfilter {
39+
break
40+
}
41+
}
42+
if filter == nil {
43+
return seq
44+
}
45+
in := splitPredicate(filter.Expr)
46+
var exprs []dag.Expr
47+
for _, e := range in {
48+
if b, ok := e.(*dag.BinaryExpr); !ok || b.Op != "==" || !convertCrossJoinToHashJoin(seq[i:], b.LHS, b.RHS) {
49+
exprs = append(exprs, e)
50+
}
51+
}
52+
if len(exprs) == 0 {
53+
seq.Delete(i+2, i+3)
54+
} else if len(exprs) != len(in) {
55+
seq[i+2] = dag.NewFilterOp(buildConjunction(exprs))
56+
}
57+
return seq
58+
}
59+
60+
func convertCrossJoinToHashJoin(seq dag.Seq, lhs, rhs dag.Expr) bool {
61+
fork, isfork := seq[0].(*dag.ForkOp)
62+
leftAlias, rightAlias, style, isjoin := isJoin(seq[1])
63+
if !isfork || !isjoin {
64+
return false
65+
}
66+
if len(fork.Paths) != 2 {
67+
panic(fork)
68+
}
69+
lhsFirst, lok := firstThisPathComponent(lhs)
70+
rhsFirst, rok := firstThisPathComponent(rhs)
71+
if !lok || !rok {
72+
return false
73+
}
74+
lhs, rhs = dag.CopyExpr(lhs), dag.CopyExpr(rhs)
75+
stripFirstThisPathComponent(lhs)
76+
stripFirstThisPathComponent(rhs)
77+
if lhsFirst == rhsFirst {
78+
if lhsFirst == leftAlias {
79+
return convertCrossJoinToHashJoin(fork.Paths[0], lhs, rhs)
80+
}
81+
if lhsFirst == rightAlias {
82+
return convertCrossJoinToHashJoin(fork.Paths[1], lhs, rhs)
83+
}
84+
return false
85+
}
86+
if style != "cross" {
87+
return false
88+
}
89+
if lhsFirst != leftAlias {
90+
lhsFirst, rhsFirst = rhsFirst, lhsFirst
91+
lhs, rhs = rhs, lhs
92+
}
93+
if lhsFirst != leftAlias || rhsFirst != rightAlias {
94+
return false
95+
}
96+
seq[1] = &dag.HashJoinOp{
97+
Kind: "HashJoinOp",
98+
Style: "inner",
99+
LeftAlias: leftAlias,
100+
RightAlias: rightAlias,
101+
LeftKey: lhs,
102+
RightKey: rhs,
103+
}
104+
return true
105+
}
106+
30107
func equiJoinKeyExprs(e dag.Expr, leftAlias, rightAlias string) (left, right dag.Expr, ok bool) {
31108
b, ok := e.(*dag.BinaryExpr)
32109
if !ok || b.Op != "==" {

compiler/optimizer/optimizer.go

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ func (o *Optimizer) Optimize(main *dag.Main) error {
144144
seq = mergeFilters(seq)
145145
seq = mergeValuesOps(seq)
146146
inlineRecordExprSpreads(seq)
147+
seq = liftFilterConvertCrossJoin(seq)
147148
seq = joinFilterPullup(seq)
148149
seq = removePassOps(seq)
149150
seq = replaceSortAndHeadOrTailWithTop(seq)
@@ -511,7 +512,7 @@ func joinFilterPullup(seq dag.Seq) dag.Seq {
511512
seq = mergeFilters(seq)
512513
for i := 0; i <= len(seq)-3; i++ {
513514
fork, isfork := seq[i].(*dag.ForkOp)
514-
leftAlias, rightAlias, isjoin := isJoin(seq[i+1])
515+
leftAlias, rightAlias, _, isjoin := isJoin(seq[i+1])
515516
filter, isfilter := seq[i+2].(*dag.FilterOp)
516517
if !isfork || !isjoin || !isfilter {
517518
continue
@@ -535,27 +536,31 @@ func joinFilterPullup(seq dag.Seq) dag.Seq {
535536
// Filter has been fully pulled up and can be removed.
536537
seq.Delete(i+2, i+3)
537538
} else {
538-
out := remaining[0]
539-
for _, e := range remaining[1:] {
540-
out = dag.NewBinaryExpr("and", e, out)
541-
}
542-
seq[i+2] = dag.NewFilterOp(out)
539+
seq[i+2] = dag.NewFilterOp(buildConjunction(remaining))
543540
}
544541
fork.Paths[0] = joinFilterPullup(fork.Paths[0])
545542
fork.Paths[1] = joinFilterPullup(fork.Paths[1])
546543
}
547544
return seq
548545
}
549546

550-
func isJoin(op dag.Op) (string, string, bool) {
547+
// isJoin reports whether op is a *dag.HashJoinOp or a *dag.JoinOp. For either
// one it returns the op's left alias, right alias, and a join style, plus
// true: a HashJoinOp is always reported with style "inner", while a JoinOp
// reports its own Style field. For any other op it returns empty strings and
// false.
func isJoin(op dag.Op) (string, string, string, bool) {
551548
switch op := op.(type) {
552549
case *dag.HashJoinOp:
553-
return op.LeftAlias, op.RightAlias, true
550+
return op.LeftAlias, op.RightAlias, "inner", true
554551
case *dag.JoinOp:
555-
return op.LeftAlias, op.RightAlias, true
552+
return op.LeftAlias, op.RightAlias, op.Style, true
556553
default:
557-
return "", "", false
554+
return "", "", "", false
555+
}
556+
}
557+
558+
func buildConjunction(exprs []dag.Expr) dag.Expr {
559+
out := exprs[0]
560+
for _, e := range exprs[1:] {
561+
out = dag.NewBinaryExpr("and", e, out)
558562
}
563+
return out
559564
}
560565

561566
func splitPredicate(e dag.Expr) []dag.Expr {
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
script: |
2+
super compile -O -C "
3+
fork
4+
(
5+
values {a1:1}
6+
)
7+
(
8+
fork
9+
(
10+
values {a2:1}
11+
)
12+
(
13+
values {a3:1}
14+
)
15+
| cross join as {left,right}
16+
)
17+
| cross join as {left,right}
18+
| where right.right.a3 == right.left.a2+1
19+
| where right.left.a2 == left.a1
20+
"
21+
echo // ===
22+
super compile -O -C "
23+
fork
24+
(
25+
fork
26+
(
27+
values {a2:1}
28+
)
29+
(
30+
values {a3:1}
31+
)
32+
| cross join as {left,right}
33+
)
34+
(
35+
values {a1:1}
36+
)
37+
| cross join as {left,right}
38+
| where left.left.a2 == left.right.a3 and right.a1 == left.left.a2
39+
"
40+
41+
outputs:
42+
- name: stdout
43+
data: |
44+
null
45+
| fork
46+
(
47+
values {a1:1}
48+
)
49+
(
50+
fork
51+
(
52+
values {a2:1}
53+
)
54+
(
55+
values {a3:1}
56+
)
57+
| inner hashjoin as {left,right} on a2+1==a3
58+
)
59+
| inner hashjoin as {left,right} on a1==left.a2
60+
| output main
61+
// ===
62+
null
63+
| fork
64+
(
65+
fork
66+
(
67+
values {a2:1}
68+
)
69+
(
70+
values {a3:1}
71+
)
72+
| inner hashjoin as {left,right} on a2==a3
73+
)
74+
(
75+
values {a1:1}
76+
)
77+
| inner hashjoin as {left,right} on left.a2==a1
78+
| output main

0 commit comments

Comments
 (0)