diff --git a/compiler/optimizer/join.go b/compiler/optimizer/join.go index 902f4c1dc..9be2bcf3b 100644 --- a/compiler/optimizer/join.go +++ b/compiler/optimizer/join.go @@ -27,6 +27,85 @@ func replaceJoinWithHashJoin(seq dag.Seq) { }) } +func liftFilterConvertCrossJoin(seq dag.Seq) dag.Seq { + var filter *dag.FilterOp + var i int + for i = range len(seq) - 3 { + _, isfork := seq[i].(*dag.ForkOp) + _, _, _, isjoin := isJoin(seq[i+1]) + f, isfilter := seq[i+2].(*dag.FilterOp) + if isfork && isjoin && isfilter { + filter = f + } + } + if filter == nil { + return seq + } + var exprs []dag.Expr + for _, e := range splitPredicate(filter.Expr) { + b, ok := e.(*dag.BinaryExpr) + if !ok || b.Op != "==" { + continue + } + if ok := convertCrossJoinToHashJoin(seq[i:], b.LHS, b.RHS); !ok { + exprs = append(exprs, e) + } + } + if len(exprs) == 0 { + seq.Delete(i+2, i+3) + } else { + seq[i+2] = dag.NewFilterOp(buildConjunction(exprs)) + } + return seq +} + +func convertCrossJoinToHashJoin(seq dag.Seq, lhs, rhs dag.Expr) bool { + fork, isfork := seq[0].(*dag.ForkOp) + leftAlias, rightAlias, style, isjoin := isJoin(seq[1]) + if !isfork || !isjoin { + return false + } + if len(fork.Paths) != 2 { + panic(fork) + } + lhsFirst, lok := firstThisPathComponent(lhs) + rhsFirst, rok := firstThisPathComponent(rhs) + if !lok || !rok { + return false + } + lhs, rhs = dag.CopyExpr(lhs), dag.CopyExpr(rhs) + stripFirstThisPathComponent(lhs) + stripFirstThisPathComponent(rhs) + if lhsFirst == rhsFirst { + if lhsFirst == leftAlias { + return convertCrossJoinToHashJoin(fork.Paths[0], lhs, rhs) + } + if lhsFirst == rightAlias { + return convertCrossJoinToHashJoin(fork.Paths[1], lhs, rhs) + } + return false + } + if style != "cross" { + return false + } + if lhsFirst != leftAlias { + lhsFirst, rhsFirst = rhsFirst, lhsFirst + lhs, rhs = rhs, lhs + } + if lhsFirst != leftAlias || rhsFirst != rightAlias { + return false + } + seq[1] = &dag.HashJoinOp{ + Kind: "HashJoinOp", + Style: "inner", + LeftAlias: leftAlias, + RightAlias: rightAlias, + LeftKey: lhs, + RightKey: rhs, + } + return true +} + func equiJoinKeyExprs(e dag.Expr, leftAlias, rightAlias string) (left, right dag.Expr, ok bool) { b, ok := e.(*dag.BinaryExpr) if !ok || b.Op != "==" { diff --git a/compiler/optimizer/optimizer.go b/compiler/optimizer/optimizer.go index 4553f2126..4e3ab1a71 100644 --- a/compiler/optimizer/optimizer.go +++ b/compiler/optimizer/optimizer.go @@ -144,6 +144,7 @@ func (o *Optimizer) Optimize(main *dag.Main) error { seq = mergeFilters(seq) seq = mergeValuesOps(seq) inlineRecordExprSpreads(seq) + seq = liftFilterConvertCrossJoin(seq) seq = joinFilterPullup(seq) seq = removePassOps(seq) seq = replaceSortAndHeadOrTailWithTop(seq) @@ -511,7 +512,7 @@ func joinFilterPullup(seq dag.Seq) dag.Seq { seq = mergeFilters(seq) for i := 0; i <= len(seq)-3; i++ { fork, isfork := seq[i].(*dag.ForkOp) - leftAlias, rightAlias, isjoin := isJoin(seq[i+1]) + leftAlias, rightAlias, _, isjoin := isJoin(seq[i+1]) filter, isfilter := seq[i+2].(*dag.FilterOp) if !isfork || !isjoin || !isfilter { continue @@ -535,11 +536,7 @@ func joinFilterPullup(seq dag.Seq) dag.Seq { // Filter has been fully pulled up and can be removed. seq.Delete(i+2, i+3) } else { - out := remaining[0] - for _, e := range remaining[1:] { - out = dag.NewBinaryExpr("and", e, out) - } - seq[i+2] = dag.NewFilterOp(out) + seq[i+2] = dag.NewFilterOp(buildConjunction(remaining)) } fork.Paths[0] = joinFilterPullup(fork.Paths[0]) fork.Paths[1] = joinFilterPullup(fork.Paths[1]) @@ -547,15 +544,23 @@ func joinFilterPullup(seq dag.Seq) dag.Seq { return seq } -func isJoin(op dag.Op) (string, string, bool) { +func isJoin(op dag.Op) (string, string, string, bool) { switch op := op.(type) { case *dag.HashJoinOp: - return op.LeftAlias, op.RightAlias, true + return op.LeftAlias, op.RightAlias, "inner", true case *dag.JoinOp: - return op.LeftAlias, op.RightAlias, true + return op.LeftAlias, op.RightAlias, op.Style, true default: - return "", "", false + return "", "", "", false + } +} + +func buildConjunction(exprs []dag.Expr) dag.Expr { + out := exprs[0] + for _, e := range exprs[1:] { + out = dag.NewBinaryExpr("and", e, out) } + return out } func splitPredicate(e dag.Expr) []dag.Expr { diff --git a/compiler/ztests/join-convert-cross.yaml b/compiler/ztests/join-convert-cross.yaml new file mode 100644 index 000000000..e4aff2ab0 --- /dev/null +++ b/compiler/ztests/join-convert-cross.yaml @@ -0,0 +1,78 @@ +script: | + super compile -O -C " + fork + ( + values {a1:1} + ) + ( + fork + ( + values {a2:1} + ) + ( + values {a3:1} + ) + | cross join as {left,right} + ) + | cross join as {left,right} + | where right.right.a3 == right.left.a2+1 + | where right.left.a2 == left.a1 + " + echo // === + super compile -O -C " + fork + ( + fork + ( + values {a2:1} + ) + ( + values {a3:1} + ) + | cross join as {left,right} + ) + ( + values {a1:1} + ) + | cross join as {left,right} + | where left.left.a2 == left.right.a3 and right.a1 == left.left.a2 + " + +outputs: + - name: stdout + data: | + null + | fork + ( + values {a1:1} + ) + ( + fork + ( + values {a2:1} + ) + ( + values {a3:1} + ) + | inner hashjoin as {left,right} on a2+1==a3 + ) + | inner hashjoin as {left,right} on a1==left.a2 + | output main + // === + null + | fork + ( + fork + ( + values {a2:1} + ) + ( + values {a3:1} + ) + | inner hashjoin as {left,right} on a2==a3 + ) + ( + values {a1:1} + ) + | inner hashjoin as {left,right} on left.a2==a1 + | output main