Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions compiler/optimizer/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,85 @@ func replaceJoinWithHashJoin(seq dag.Seq) {
})
}

// liftFilterConvertCrossJoin scans seq for runs of the form
// fork | join | filter and tries to turn equality conjuncts of the filter
// into hash join keys via convertCrossJoinToHashJoin.
//
// The filter predicate is split with splitPredicate.  Every conjunct that
// is an "==" comparison and is accepted by convertCrossJoinToHashJoin is
// removed from the filter; every other conjunct (non-equality predicates
// and equalities that could not be converted) is preserved.  If nothing
// remains the filter op is deleted from seq, otherwise it is replaced by a
// filter over the conjunction of the remaining predicates.
//
// The (possibly shortened) sequence is returned.
func liftFilterConvertCrossJoin(seq dag.Seq) dag.Seq {
	// i <= len(seq)-3 so that a fork|join|filter run ending at the last
	// element of seq is also considered.  len(seq) is re-evaluated on every
	// iteration because a filter may be deleted below.
	for i := 0; i <= len(seq)-3; i++ {
		_, isfork := seq[i].(*dag.ForkOp)
		_, _, _, isjoin := isJoin(seq[i+1])
		filter, isfilter := seq[i+2].(*dag.FilterOp)
		if !isfork || !isjoin || !isfilter {
			continue
		}
		var remaining []dag.Expr
		for _, e := range splitPredicate(filter.Expr) {
			b, ok := e.(*dag.BinaryExpr)
			if ok && b.Op == "==" && convertCrossJoinToHashJoin(seq[i:], b.LHS, b.RHS) {
				// The equality now lives in a hash join as its key pair.
				continue
			}
			// Keep anything that was not absorbed by a join so no
			// predicate is silently dropped.
			remaining = append(remaining, e)
		}
		if len(remaining) == 0 {
			// Filter has been fully absorbed and can be removed.
			seq.Delete(i+2, i+3)
		} else {
			seq[i+2] = dag.NewFilterOp(buildConjunction(remaining))
		}
	}
	return seq
}

// convertCrossJoinToHashJoin attempts to express the equality lhs == rhs as
// the key pair of a hash join located at the head of seq.  It reports
// whether a join in seq (or in a nested fork path) was rewritten.
//
// seq is expected to begin with a fork op immediately followed by a join
// op; if seq is shorter than two ops or does not have that shape, false is
// returned and seq is left untouched.  The function panics if the fork
// does not have exactly two paths.
//
// The leading path component of each side (as reported by
// firstThisPathComponent) is compared with the join's aliases:
//
//   - If both sides start with the same alias, the predicate concerns only
//     one input of the join, so the search continues recursively in the
//     corresponding fork path with that leading component stripped.
//   - Otherwise the join must have style "cross" and one side must start
//     with the left alias and the other with the right alias.  In that
//     case seq[1] is replaced with an inner HashJoinOp keyed on the
//     stripped expressions.
//
// lhs and rhs themselves are not modified; copies are stripped instead.
func convertCrossJoinToHashJoin(seq dag.Seq, lhs, rhs dag.Expr) bool {
	// Recursion descends into arbitrary fork paths, which may hold fewer
	// than two ops (e.g. a lone values op), so guard before indexing.
	if len(seq) < 2 {
		return false
	}
	fork, isfork := seq[0].(*dag.ForkOp)
	leftAlias, rightAlias, style, isjoin := isJoin(seq[1])
	if !isfork || !isjoin {
		return false
	}
	if len(fork.Paths) != 2 {
		panic(fork)
	}
	lhsFirst, lok := firstThisPathComponent(lhs)
	rhsFirst, rok := firstThisPathComponent(rhs)
	if !lok || !rok {
		return false
	}
	// Work on copies so the caller's filter expression stays intact when
	// the conversion is rejected further down.
	lhs, rhs = dag.CopyExpr(lhs), dag.CopyExpr(rhs)
	stripFirstThisPathComponent(lhs)
	stripFirstThisPathComponent(rhs)
	if lhsFirst == rhsFirst {
		// Both sides refer to the same join input: push the predicate
		// down into that input's path.
		switch lhsFirst {
		case leftAlias:
			return convertCrossJoinToHashJoin(fork.Paths[0], lhs, rhs)
		case rightAlias:
			return convertCrossJoinToHashJoin(fork.Paths[1], lhs, rhs)
		}
		return false
	}
	if style != "cross" {
		return false
	}
	if lhsFirst != leftAlias {
		// Normalize so that lhs belongs to the left input.
		lhsFirst, rhsFirst = rhsFirst, lhsFirst
		lhs, rhs = rhs, lhs
	}
	if lhsFirst != leftAlias || rhsFirst != rightAlias {
		return false
	}
	seq[1] = &dag.HashJoinOp{
		Kind:       "HashJoinOp",
		Style:      "inner",
		LeftAlias:  leftAlias,
		RightAlias: rightAlias,
		LeftKey:    lhs,
		RightKey:   rhs,
	}
	return true
}

func equiJoinKeyExprs(e dag.Expr, leftAlias, rightAlias string) (left, right dag.Expr, ok bool) {
b, ok := e.(*dag.BinaryExpr)
if !ok || b.Op != "==" {
Expand Down
25 changes: 15 additions & 10 deletions compiler/optimizer/optimizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ func (o *Optimizer) Optimize(main *dag.Main) error {
seq = mergeFilters(seq)
seq = mergeValuesOps(seq)
inlineRecordExprSpreads(seq)
seq = liftFilterConvertCrossJoin(seq)
seq = joinFilterPullup(seq)
seq = removePassOps(seq)
seq = replaceSortAndHeadOrTailWithTop(seq)
Expand Down Expand Up @@ -511,7 +512,7 @@ func joinFilterPullup(seq dag.Seq) dag.Seq {
seq = mergeFilters(seq)
for i := 0; i <= len(seq)-3; i++ {
fork, isfork := seq[i].(*dag.ForkOp)
leftAlias, rightAlias, isjoin := isJoin(seq[i+1])
leftAlias, rightAlias, _, isjoin := isJoin(seq[i+1])
filter, isfilter := seq[i+2].(*dag.FilterOp)
if !isfork || !isjoin || !isfilter {
continue
Expand All @@ -535,27 +536,31 @@ func joinFilterPullup(seq dag.Seq) dag.Seq {
// Filter has been fully pulled up and can be removed.
seq.Delete(i+2, i+3)
} else {
out := remaining[0]
for _, e := range remaining[1:] {
out = dag.NewBinaryExpr("and", e, out)
}
seq[i+2] = dag.NewFilterOp(out)
seq[i+2] = dag.NewFilterOp(buildConjunction(remaining))
}
fork.Paths[0] = joinFilterPullup(fork.Paths[0])
fork.Paths[1] = joinFilterPullup(fork.Paths[1])
}
return seq
}

func isJoin(op dag.Op) (string, string, bool) {
func isJoin(op dag.Op) (string, string, string, bool) {
switch op := op.(type) {
case *dag.HashJoinOp:
return op.LeftAlias, op.RightAlias, true
return op.LeftAlias, op.RightAlias, "inner", true
case *dag.JoinOp:
return op.LeftAlias, op.RightAlias, true
return op.LeftAlias, op.RightAlias, op.Style, true
default:
return "", "", false
return "", "", "", false
}
}

// buildConjunction folds exprs into a single expression by chaining them
// with "and".  Each subsequent element becomes the left operand of a new
// binary expression whose right operand is the conjunction built so far.
// exprs must contain at least one element; a single element is returned
// unchanged.
func buildConjunction(exprs []dag.Expr) dag.Expr {
	conj := exprs[0]
	for idx := 1; idx < len(exprs); idx++ {
		conj = dag.NewBinaryExpr("and", exprs[idx], conj)
	}
	return conj
}

func splitPredicate(e dag.Expr) []dag.Expr {
Expand Down
78 changes: 78 additions & 0 deletions compiler/ztests/join-convert-cross.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
script: |
super compile -O -C "
fork
(
values {a1:1}
)
(
fork
(
values {a2:1}
)
(
values {a3:1}
)
| cross join as {left,right}
)
| cross join as {left,right}
| where right.right.a3 == right.left.a2+1
| where right.left.a2 == left.a1
"
echo // ===
super compile -O -C "
fork
(
fork
(
values {a2:1}
)
(
values {a3:1}
)
| cross join as {left,right}
)
(
values {a1:1}
)
| cross join as {left,right}
| where left.left.a2 == left.right.a3 and right.a1 == left.left.a2
"

outputs:
- name: stdout
data: |
null
| fork
(
values {a1:1}
)
(
fork
(
values {a2:1}
)
(
values {a3:1}
)
| inner hashjoin as {left,right} on a2+1==a3
)
| inner hashjoin as {left,right} on a1==left.a2
| output main
// ===
null
| fork
(
fork
(
values {a2:1}
)
(
values {a3:1}
)
| inner hashjoin as {left,right} on a2==a3
)
(
values {a1:1}
)
| inner hashjoin as {left,right} on left.a2==a1
| output main
Loading