Skip to content

Commit dd1d9e7

Browse files
committed
Convert appropriate cross joins to hash joins
This commit adds functionality to the optimizer that will rewrite cross joins with equality comparisons filters into hash joins where appropriate.
1 parent 6d4d27a commit dd1d9e7

File tree

3 files changed

+172
-10
lines changed

3 files changed

+172
-10
lines changed

compiler/optimizer/join.go

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,85 @@ func replaceJoinWithHashJoin(seq dag.Seq) {
2727
})
2828
}
2929

30+
func liftFilterConvertCrossJoin(seq dag.Seq) dag.Seq {
31+
var filter *dag.FilterOp
32+
var i int
33+
for i = range len(seq) - 3 {
34+
_, isfork := seq[i].(*dag.ForkOp)
35+
_, _, _, isjoin := isJoin(seq[i+1])
36+
f, isfilter := seq[i+2].(*dag.FilterOp)
37+
if isfork && isjoin && isfilter {
38+
filter = f
39+
}
40+
}
41+
if filter == nil {
42+
return seq
43+
}
44+
var exprs []dag.Expr
45+
for _, e := range splitPredicate(filter.Expr) {
46+
b, ok := e.(*dag.BinaryExpr)
47+
if !ok || b.Op != "==" {
48+
continue
49+
}
50+
if ok := convertCrossJoinToHashJoin(seq[i:], b.LHS, b.RHS); !ok {
51+
exprs = append(exprs, e)
52+
}
53+
}
54+
if len(exprs) == 0 {
55+
seq.Delete(i+2, i+3)
56+
} else {
57+
seq[i+2] = dag.NewFilterOp(buildConjunction(exprs))
58+
}
59+
return seq
60+
}
61+
62+
func convertCrossJoinToHashJoin(seq dag.Seq, lhs, rhs dag.Expr) bool {
63+
fork, isfork := seq[0].(*dag.ForkOp)
64+
leftAlias, rightAlias, style, isjoin := isJoin(seq[1])
65+
if !isfork || !isjoin {
66+
return false
67+
}
68+
if len(fork.Paths) != 2 {
69+
panic(fork)
70+
}
71+
lhsFirst, lok := firstThisPathComponent(lhs)
72+
rhsFirst, rok := firstThisPathComponent(rhs)
73+
if !lok || !rok {
74+
return false
75+
}
76+
lhs, rhs = dag.CopyExpr(lhs), dag.CopyExpr(rhs)
77+
stripFirstThisPathComponent(lhs)
78+
stripFirstThisPathComponent(rhs)
79+
if lhsFirst == rhsFirst {
80+
if lhsFirst == leftAlias {
81+
return convertCrossJoinToHashJoin(fork.Paths[0], lhs, rhs)
82+
}
83+
if lhsFirst == rightAlias {
84+
return convertCrossJoinToHashJoin(fork.Paths[1], lhs, rhs)
85+
}
86+
return false
87+
}
88+
if style != "cross" {
89+
return false
90+
}
91+
if lhsFirst != leftAlias {
92+
lhsFirst, rhsFirst = rhsFirst, lhsFirst
93+
lhs, rhs = rhs, lhs
94+
}
95+
if lhsFirst != leftAlias || rhsFirst != rightAlias {
96+
return false
97+
}
98+
seq[1] = &dag.HashJoinOp{
99+
Kind: "HashJoinOp",
100+
Style: "inner",
101+
LeftAlias: leftAlias,
102+
RightAlias: rightAlias,
103+
LeftKey: lhs,
104+
RightKey: rhs,
105+
}
106+
return true
107+
}
108+
30109
func equiJoinKeyExprs(e dag.Expr, leftAlias, rightAlias string) (left, right dag.Expr, ok bool) {
31110
b, ok := e.(*dag.BinaryExpr)
32111
if !ok || b.Op != "==" {

compiler/optimizer/optimizer.go

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ func (o *Optimizer) Optimize(main *dag.Main) error {
144144
seq = mergeFilters(seq)
145145
seq = mergeValuesOps(seq)
146146
inlineRecordExprSpreads(seq)
147+
seq = liftFilterConvertCrossJoin(seq)
147148
seq = joinFilterPullup(seq)
148149
seq = removePassOps(seq)
149150
seq = replaceSortAndHeadOrTailWithTop(seq)
@@ -511,7 +512,7 @@ func joinFilterPullup(seq dag.Seq) dag.Seq {
511512
seq = mergeFilters(seq)
512513
for i := 0; i <= len(seq)-3; i++ {
513514
fork, isfork := seq[i].(*dag.ForkOp)
514-
leftAlias, rightAlias, isjoin := isJoin(seq[i+1])
515+
leftAlias, rightAlias, _, isjoin := isJoin(seq[i+1])
515516
filter, isfilter := seq[i+2].(*dag.FilterOp)
516517
if !isfork || !isjoin || !isfilter {
517518
continue
@@ -535,27 +536,31 @@ func joinFilterPullup(seq dag.Seq) dag.Seq {
535536
// Filter has been fully pulled up and can be removed.
536537
seq.Delete(i+2, i+3)
537538
} else {
538-
out := remaining[0]
539-
for _, e := range remaining[1:] {
540-
out = dag.NewBinaryExpr("and", e, out)
541-
}
542-
seq[i+2] = dag.NewFilterOp(out)
539+
seq[i+2] = dag.NewFilterOp(buildConjunction(remaining))
543540
}
544541
fork.Paths[0] = joinFilterPullup(fork.Paths[0])
545542
fork.Paths[1] = joinFilterPullup(fork.Paths[1])
546543
}
547544
return seq
548545
}
549546

550-
func isJoin(op dag.Op) (string, string, bool) {
547+
func isJoin(op dag.Op) (string, string, string, bool) {
551548
switch op := op.(type) {
552549
case *dag.HashJoinOp:
553-
return op.LeftAlias, op.RightAlias, true
550+
return op.LeftAlias, op.RightAlias, "inner", true
554551
case *dag.JoinOp:
555-
return op.LeftAlias, op.RightAlias, true
552+
return op.LeftAlias, op.RightAlias, op.Style, true
556553
default:
557-
return "", "", false
554+
return "", "", "", false
555+
}
556+
}
557+
558+
func buildConjunction(exprs []dag.Expr) dag.Expr {
559+
out := exprs[0]
560+
for _, e := range exprs[1:] {
561+
out = dag.NewBinaryExpr("and", e, out)
558562
}
563+
return out
559564
}
560565

561566
func splitPredicate(e dag.Expr) []dag.Expr {
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
script: |
2+
super compile -O -C "
3+
fork
4+
(
5+
values {a1:1}
6+
)
7+
(
8+
fork
9+
(
10+
values {a2:1}
11+
)
12+
(
13+
values {a3:1}
14+
)
15+
| cross join as {left,right}
16+
)
17+
| cross join as {left,right}
18+
| where right.right.a3 == right.left.a2+1
19+
| where right.left.a2 == left.a1
20+
"
21+
echo // ===
22+
super compile -O -C "
23+
fork
24+
(
25+
fork
26+
(
27+
values {a2:1}
28+
)
29+
(
30+
values {a3:1}
31+
)
32+
| cross join as {left,right}
33+
)
34+
(
35+
values {a1:1}
36+
)
37+
| cross join as {left,right}
38+
| where left.left.a2 == left.right.a3 and right.a1 == left.left.a2
39+
"
40+
41+
outputs:
42+
- name: stdout
43+
data: |
44+
null
45+
| fork
46+
(
47+
values {a1:1}
48+
)
49+
(
50+
fork
51+
(
52+
values {a2:1}
53+
)
54+
(
55+
values {a3:1}
56+
)
57+
| inner hashjoin as {left,right} on a2+1==a3
58+
)
59+
| inner hashjoin as {left,right} on a1==left.a2
60+
| output main
61+
// ===
62+
null
63+
| fork
64+
(
65+
fork
66+
(
67+
values {a2:1}
68+
)
69+
(
70+
values {a3:1}
71+
)
72+
| inner hashjoin as {left,right} on a2==a3
73+
)
74+
(
75+
values {a1:1}
76+
)
77+
| inner hashjoin as {left,right} on left.a2==a1
78+
| output main

0 commit comments

Comments
 (0)