Skip to content

Commit 421a68e

Browse files
committed
Convert appropriate cross joins to hash joins
This commit adds functionality to the optimizer that rewrites cross joins with equality comparisons filters into hash joins where appropriate.
1 parent b166423 commit 421a68e

File tree

3 files changed

+170
-10
lines changed

3 files changed

+170
-10
lines changed

compiler/optimizer/join.go

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,83 @@ func replaceJoinWithHashJoin(seq dag.Seq) {
2727
})
2828
}
2929

30+
func liftFilterConvertCrossJoin(seq dag.Seq) dag.Seq {
31+
var filter *dag.FilterOp
32+
var i int
33+
for i = range len(seq) - 3 {
34+
var isfilter bool
35+
_, isfork := seq[i].(*dag.ForkOp)
36+
_, _, _, isjoin := isJoin(seq[i+1])
37+
filter, isfilter = seq[i+2].(*dag.FilterOp)
38+
if isfork && isjoin && isfilter {
39+
break
40+
}
41+
}
42+
if filter == nil {
43+
return seq
44+
}
45+
in := splitPredicate(filter.Expr)
46+
var exprs []dag.Expr
47+
for _, e := range in {
48+
if b, ok := e.(*dag.BinaryExpr); !ok || b.Op != "==" || !convertCrossJoinToHashJoin(seq[i:], b.LHS, b.RHS) {
49+
exprs = append(exprs, e)
50+
}
51+
}
52+
if len(exprs) == 0 {
53+
seq.Delete(i+2, i+3)
54+
} else if len(exprs) != len(in) {
55+
seq[i+2] = dag.NewFilterOp(buildConjunction(exprs))
56+
}
57+
return seq
58+
}
59+
60+
func convertCrossJoinToHashJoin(seq dag.Seq, lhs, rhs dag.Expr) bool {
61+
fork, isfork := seq[0].(*dag.ForkOp)
62+
leftAlias, rightAlias, style, isjoin := isJoin(seq[1])
63+
if !isfork || !isjoin {
64+
return false
65+
}
66+
if len(fork.Paths) != 2 {
67+
panic(fork)
68+
}
69+
lhsFirst, lok := firstThisPathComponent(lhs)
70+
rhsFirst, rok := firstThisPathComponent(rhs)
71+
if !lok || !rok {
72+
return false
73+
}
74+
lhs, rhs = dag.CopyExpr(lhs), dag.CopyExpr(rhs)
75+
stripFirstThisPathComponent(lhs)
76+
stripFirstThisPathComponent(rhs)
77+
if lhsFirst == rhsFirst {
78+
if lhsFirst == leftAlias {
79+
return convertCrossJoinToHashJoin(fork.Paths[0], lhs, rhs)
80+
}
81+
if lhsFirst == rightAlias {
82+
return convertCrossJoinToHashJoin(fork.Paths[1], lhs, rhs)
83+
}
84+
return false
85+
}
86+
if style != "cross" {
87+
return false
88+
}
89+
if lhsFirst != leftAlias {
90+
lhsFirst, rhsFirst = rhsFirst, lhsFirst
91+
lhs, rhs = rhs, lhs
92+
}
93+
if lhsFirst != leftAlias || rhsFirst != rightAlias {
94+
return false
95+
}
96+
seq[1] = &dag.HashJoinOp{
97+
Kind: "HashJoinOp",
98+
Style: "inner",
99+
LeftAlias: leftAlias,
100+
RightAlias: rightAlias,
101+
LeftKey: lhs,
102+
RightKey: rhs,
103+
}
104+
return true
105+
}
106+
30107
func equiJoinKeyExprs(e dag.Expr, leftAlias, rightAlias string) (left, right dag.Expr, ok bool) {
31108
b, ok := e.(*dag.BinaryExpr)
32109
if !ok || b.Op != "==" {

compiler/optimizer/optimizer.go

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ func (o *Optimizer) Optimize(main *dag.Main) error {
144144
seq = mergeFilters(seq)
145145
seq = mergeValuesOps(seq)
146146
inlineRecordExprSpreads(seq)
147+
seq = liftFilterConvertCrossJoin(seq)
147148
seq = joinFilterPullup(seq)
148149
seq = removePassOps(seq)
149150
seq = replaceSortAndHeadOrTailWithTop(seq)
@@ -511,7 +512,7 @@ func joinFilterPullup(seq dag.Seq) dag.Seq {
511512
seq = mergeFilters(seq)
512513
for i := 0; i <= len(seq)-3; i++ {
513514
fork, isfork := seq[i].(*dag.ForkOp)
514-
leftAlias, rightAlias, isjoin := isJoin(seq[i+1])
515+
leftAlias, rightAlias, _, isjoin := isJoin(seq[i+1])
515516
filter, isfilter := seq[i+2].(*dag.FilterOp)
516517
if !isfork || !isjoin || !isfilter {
517518
continue
@@ -535,27 +536,31 @@ func joinFilterPullup(seq dag.Seq) dag.Seq {
535536
// Filter has been fully pulled up and can be removed.
536537
seq.Delete(i+2, i+3)
537538
} else {
538-
out := remaining[0]
539-
for _, e := range remaining[1:] {
540-
out = dag.NewBinaryExpr("and", e, out)
541-
}
542-
seq[i+2] = dag.NewFilterOp(out)
539+
seq[i+2] = dag.NewFilterOp(buildConjunction(remaining))
543540
}
544541
fork.Paths[0] = joinFilterPullup(fork.Paths[0])
545542
fork.Paths[1] = joinFilterPullup(fork.Paths[1])
546543
}
547544
return seq
548545
}
549546

550-
func isJoin(op dag.Op) (string, string, bool) {
547+
func isJoin(op dag.Op) (string, string, string, bool) {
551548
switch op := op.(type) {
552549
case *dag.HashJoinOp:
553-
return op.LeftAlias, op.RightAlias, true
550+
return op.LeftAlias, op.RightAlias, "inner", true
554551
case *dag.JoinOp:
555-
return op.LeftAlias, op.RightAlias, true
552+
return op.LeftAlias, op.RightAlias, op.Style, true
556553
default:
557-
return "", "", false
554+
return "", "", "", false
555+
}
556+
}
557+
558+
func buildConjunction(exprs []dag.Expr) dag.Expr {
559+
out := exprs[0]
560+
for _, e := range exprs[1:] {
561+
out = dag.NewBinaryExpr("and", e, out)
558562
}
563+
return out
559564
}
560565

561566
func splitPredicate(e dag.Expr) []dag.Expr {
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
script: |
2+
super compile -O -C "
3+
fork
4+
(
5+
values {a1:1}
6+
)
7+
(
8+
fork
9+
(
10+
values {a2:1}
11+
)
12+
(
13+
values {a3:1}
14+
)
15+
| cross join as {left,right}
16+
)
17+
| cross join as {left,right}
18+
| where right.right.a3 == right.left.a2+1
19+
| where right.left.a2 == left.a1
20+
"
21+
echo // ===
22+
super compile -O -C "
23+
fork
24+
(
25+
fork
26+
(
27+
values {a2:1}
28+
)
29+
(
30+
values {a3:1}
31+
)
32+
| cross join as {left,right}
33+
)
34+
(
35+
values {a1:1}
36+
)
37+
| cross join as {left,right}
38+
| where left.left.a2 == left.right.a3 and right.a1 == left.left.a2
39+
"
40+
41+
outputs:
42+
- name: stdout
43+
data: |
44+
null
45+
| fork
46+
(
47+
values {a1:1}
48+
)
49+
(
50+
fork
51+
(
52+
values {a2:1}
53+
)
54+
(
55+
values {a3:1}
56+
)
57+
| inner hashjoin as {left,right} on a2+1==a3
58+
)
59+
| inner hashjoin as {left,right} on a1==left.a2
60+
| output main
61+
// ===
62+
null
63+
| fork
64+
(
65+
fork
66+
(
67+
values {a2:1}
68+
)
69+
(
70+
values {a3:1}
71+
)
72+
| inner hashjoin as {left,right} on a2==a3
73+
)
74+
(
75+
values {a1:1}
76+
)
77+
| inner hashjoin as {left,right} on left.a2==a1
78+
| output main

0 commit comments

Comments
 (0)