diff --git a/enginetest/join_planning_tests.go b/enginetest/join_planning_tests.go index 8d99d1e3fb..21dac2f88a 100644 --- a/enginetest/join_planning_tests.go +++ b/enginetest/join_planning_tests.go @@ -43,11 +43,13 @@ type JoinPlanTest struct { skipOld bool } -var JoinPlanningTests = []struct { +type joinPlanScript struct { name string setup []string tests []JoinPlanTest -}{ +} + +var JoinPlanningTests = []joinPlanScript{ { name: "filter pushdown through join uppercase name", setup: []string{ @@ -78,6 +80,33 @@ var JoinPlanningTests = []struct { }, }, }, + { + name: "block merge join", + setup: []string{ + "CREATE table xy (x int primary key, y int, unique index y_idx(y));", + "CREATE table ab (a int primary key, b int);", + "insert into xy values (1,0), (2,1), (0,2), (3,3);", + "insert into ab values (0,2), (1,2), (2,2), (3,1);", + `analyze table xy update histogram on x using data '{"row_count":1000}'`, + `analyze table ab update histogram on a using data '{"row_count":1000}'`, + }, + tests: []JoinPlanTest{ + { + q: "select /*+ JOIN_ORDER(ab, xy) MERGE_JOIN(ab, xy)*/ * from ab join xy on y = a order by 1, 3", + types: []plan.JoinType{plan.JoinTypeMerge}, + exp: []sql.Row{{0, 2, 1, 0}, {1, 2, 2, 1}, {2, 2, 0, 2}, {3, 1, 3, 3}}, + }, + { + q: "set @@SESSION.disable_merge_join = 1", + exp: []sql.Row{{}}, + }, + { + q: "select /*+ JOIN_ORDER(ab, xy) MERGE_JOIN(ab, xy)*/ * from ab join xy on y = a order by 1, 3", + types: []plan.JoinType{plan.JoinTypeLookup}, + exp: []sql.Row{{0, 2, 1, 0}, {1, 2, 2, 1}, {2, 2, 0, 2}, {3, 1, 3, 3}}, + }, + }, + }, { name: "merge join unary index", setup: []string{ @@ -1725,8 +1754,17 @@ join uv d on d.u = c.x`, } func TestJoinPlanning(t *testing.T, harness Harness) { - for _, tt := range JoinPlanningTests { + runJoinPlanningTests(t, harness, JoinPlanningTests) +} + +func runJoinPlanningTests(t *testing.T, harness Harness, tests []joinPlanScript) { + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + if sh, ok := harness.(SkippingHarness); ok { + if sh.SkipQueryTest(tt.name) { + t.Skip(tt.name) + } + } harness.Setup([]setup.SetupScript{setup.MydbData[0], tt.setup}) e := mustNewEngine(t, harness) defer e.Close() @@ -1750,7 +1788,6 @@ func TestJoinPlanning(t *testing.T, harness Harness) { }) } } - func evalJoinTypeTest(t *testing.T, harness Harness, e QueryEngine, query string, types []plan.JoinType, skipOld bool) { t.Run(query+" join types", func(t *testing.T) { if skipOld { diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index bd1f10eafa..e59f680386 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -111,7 +111,11 @@ func TestLateralJoin(t *testing.T) { // TestJoinPlanning runs join-specific tests for merge func TestJoinPlanning(t *testing.T) { - enginetest.TestJoinPlanning(t, enginetest.NewDefaultMemoryHarness()) + harness := enginetest.NewDefaultMemoryHarness() + if harness.IsUsingServer() { + harness.QueriesToSkip("block merge join") + } + enginetest.TestJoinPlanning(t, harness) } // TestJoinOps runs join-specific tests for merge diff --git a/sql/analyzer/indexed_joins.go b/sql/analyzer/indexed_joins.go index f53d9b4199..c91c720523 100644 --- a/sql/analyzer/indexed_joins.go +++ b/sql/analyzer/indexed_joins.go @@ -200,6 +200,7 @@ func replanJoin(ctx *sql.Context, n *plan.JoinNode, a *Analyzer, scope *plan.Sco return nil, err } + m.SetDefaultHints() hints := memo.ExtractJoinHint(n) for _, h := range hints { // this should probably happen earlier, but the root is not diff --git a/sql/memo/hinttype_string.go b/sql/memo/hinttype_string.go index 13579a90f9..ea3c9892da 100644 --- a/sql/memo/hinttype_string.go +++ b/sql/memo/hinttype_string.go @@ -11,20 +11,21 @@ func _() { _ = x[HintTypeUnknown-0] _ = x[HintTypeJoinOrder-1] _ = x[HintTypeJoinFixedOrder-2] - _ = x[HintTypeMergeJoin-3] - _ = x[HintTypeLookupJoin-4] - _ = x[HintTypeHashJoin-5] - _ = x[HintTypeSemiJoin-6] - _ = x[HintTypeAntiJoin-7] - _ = x[HintTypeInnerJoin-8] - _ = x[HintTypeLeftOuterLookupJoin-9] - _ = x[HintTypeNoIndexConditionPushDown-10] - _ = x[HintTypeLeftDeep-11] + _ = x[HintTypeNoMergeJoin-3] + _ = x[HintTypeMergeJoin-4] + _ = x[HintTypeLookupJoin-5] + _ = x[HintTypeHashJoin-6] + _ = x[HintTypeSemiJoin-7] + _ = x[HintTypeAntiJoin-8] + _ = x[HintTypeInnerJoin-9] + _ = x[HintTypeLeftOuterLookupJoin-10] + _ = x[HintTypeNoIndexConditionPushDown-11] + _ = x[HintTypeLeftDeep-12] } -const _HintType_name = "JOIN_ORDERJOIN_FIXED_ORDERMERGE_JOINLOOKUP_JOINHASH_JOINSEMI_JOINANTI_JOININNER_JOINLEFT_OUTER_LOOKUP_JOINNO_ICPLEFT_DEEP" +const _HintType_name = "JOIN_ORDERJOIN_FIXED_ORDERNO_MERGE_JOINMERGE_JOINLOOKUP_JOINHASH_JOINSEMI_JOINANTI_JOININNER_JOINLEFT_OUTER_LOOKUP_JOINNO_ICPLEFT_DEEP" -var _HintType_index = [...]uint8{0, 0, 10, 26, 36, 47, 56, 65, 74, 84, 106, 112, 121} +var _HintType_index = [...]uint8{0, 0, 10, 26, 39, 49, 60, 69, 78, 87, 97, 119, 125, 134} func (i HintType) String() string { if i >= HintType(len(_HintType_index)-1) { diff --git a/sql/memo/memo.go b/sql/memo/memo.go index 4dd401815e..d498dc6ec2 100644 --- a/sql/memo/memo.go +++ b/sql/memo/memo.go @@ -82,6 +82,12 @@ func (m *Memo) StatsProvider() sql.StatsProvider { return m.statsProv } +func (m *Memo) SetDefaultHints() { + if val, _ := m.Ctx.GetSessionVariable(m.Ctx, sql.DisableMergeJoin); val.(int8) != 0 { + m.ApplyHint(Hint{Typ: HintTypeNoMergeJoin}) + } +} + // newExprGroup creates a new logical expression group to encapsulate the // action of a SQL clause. // TODO: this is supposed to deduplicate logically equivalent table scans @@ -459,6 +465,11 @@ func (m *Memo) optimizeMemoGroup(grp *ExprGroup) error { // rather than a local property. func (m *Memo) updateBest(grp *ExprGroup, n RelExpr, cost float64) { if !m.hints.isEmpty() { + for _, block := range m.hints.block { + if !block.isOk(n) { + return + } + } if m.hints.satisfiedBy(n) { if !grp.HintOk { grp.Best = n @@ -514,17 +525,32 @@ func getProjectColset(p *Project) sql.ColSet { func (m *Memo) ApplyHint(hint Hint) { switch hint.Typ { case HintTypeJoinOrder: - m.WithJoinOrder(hint.Args) + m.SetJoinOrder(hint.Args) case HintTypeJoinFixedOrder: + case HintTypeNoMergeJoin: + m.SetBlockOp(func(n RelExpr) bool { + switch n := n.(type) { + case JoinRel: + jp := n.JoinPrivate() + if !jp.Left.Best.Group().HintOk || !jp.Right.Best.Group().HintOk { + // equiv closures can generate child plans that bypass hints + return false + } + if jp.Op.IsMerge() { + return false + } + } + return true + }) case HintTypeInnerJoin, HintTypeMergeJoin, HintTypeLookupJoin, HintTypeHashJoin, HintTypeSemiJoin, HintTypeAntiJoin, HintTypeLeftOuterLookupJoin: - m.WithJoinOp(hint.Typ, hint.Args[0], hint.Args[1]) + m.SetJoinOp(hint.Typ, hint.Args[0], hint.Args[1]) case HintTypeLeftDeep: m.hints.leftDeep = true default: } } -func (m *Memo) WithJoinOrder(tables []string) { +func (m *Memo) SetJoinOrder(tables []string) { // order maps groupId -> table dependencies order := make(map[sql.TableId]uint64) for i, t := range tables { @@ -542,7 +568,11 @@ func (m *Memo) WithJoinOrder(tables []string) { } } -func (m *Memo) WithJoinOp(op HintType, left, right string) { +func (m *Memo) SetBlockOp(cb func(n RelExpr) bool) { + m.hints.block = append(m.hints.block, joinBlockHint{cb: cb}) +} + +func (m *Memo) SetJoinOp(op HintType, left, right string) { var lTab, rTab sql.TableId for _, n := range m.root.RelProps.TableIdNodes() { if strings.EqualFold(left, n.Name()) { diff --git a/sql/memo/select_hints.go b/sql/memo/select_hints.go index f22926f853..6b41d168d6 100644 --- a/sql/memo/select_hints.go +++ b/sql/memo/select_hints.go @@ -32,6 +32,7 @@ const ( HintTypeUnknown HintType = iota // HintTypeJoinOrder // JOIN_ORDER HintTypeJoinFixedOrder // JOIN_FIXED_ORDER + HintTypeNoMergeJoin // NO_MERGE_JOIN HintTypeMergeJoin // MERGE_JOIN HintTypeLookupJoin // LOOKUP_JOIN HintTypeHashJoin // HASH_JOIN @@ -81,6 +82,8 @@ func newHint(joinTyp string, args []string) Hint { typ = HintTypeNoIndexConditionPushDown case "left_deep": typ = HintTypeLeftDeep + case "no_merge_join": + typ = HintTypeNoMergeJoin default: typ = HintTypeUnknown } @@ -111,6 +114,8 @@ func (h Hint) valid() bool { return len(h.Args) == 0 case HintTypeLeftDeep: return len(h.Args) == 0 + case HintTypeNoMergeJoin: + return true case HintTypeUnknown: return false default: @@ -367,11 +372,20 @@ func (o joinOpHint) typeMatches(n RelExpr) bool { return true } +type joinBlockHint struct { + cb func(n RelExpr) bool +} + +func (o joinBlockHint) isOk(n RelExpr) bool { + return o.cb(n) +} + // joinHints wraps a collection of join hints. The memo // interfaces with this object during costing. type joinHints struct { ops []joinOpHint order *joinOrderHint + block []joinBlockHint leftDeep bool } diff --git a/sql/memo/select_hints_test.go b/sql/memo/select_hints_test.go index 43211a40b3..45ba45de32 100644 --- a/sql/memo/select_hints_test.go +++ b/sql/memo/select_hints_test.go @@ -238,7 +238,7 @@ func TestOrderHintBuilding(t *testing.T) { t.Run(tt.name, func(t *testing.T) { j := NewJoinOrderBuilder(NewMemo(newContext(pro), nil, nil, 0, NewDefaultCoster(), nil)) j.ReorderJoin(tt.plan) - j.m.WithJoinOrder(tt.hint) + j.m.SetJoinOrder(tt.hint) if tt.invalid { require.Equal(t, j.m.hints.order, (*joinOrderHint)(nil)) } else { diff --git a/sql/statistics.go b/sql/statistics.go index fef5b6f27b..1dd283834b 100644 --- a/sql/statistics.go +++ b/sql/statistics.go @@ -21,6 +21,8 @@ import ( "time" ) +const DisableMergeJoin = "disable_merge_join" + // StatisticsTable is a table that can provide information about its number of rows and other facts to improve query // planning performance. type StatisticsTable interface { diff --git a/sql/variables/system_variables.go b/sql/variables/system_variables.go index 60fdc00a31..2e96bf9840 100644 --- a/sql/variables/system_variables.go +++ b/sql/variables/system_variables.go @@ -1049,6 +1049,14 @@ var systemVars = map[string]sql.SystemVariable{ Type: types.NewSystemBoolType("inmemory_joins"), Default: int8(0), }, + "disable_merge_join": &sql.MysqlSystemVariable{ + Name: sql.DisableMergeJoin, + Scope: sql.GetMysqlScope(sql.SystemVariableScope_Both), + Dynamic: true, + SetVarHintApplies: false, + Type: types.NewSystemBoolType(sql.DisableMergeJoin), + Default: int8(0), + }, "innodb_autoinc_lock_mode": &sql.MysqlSystemVariable{ Name: "innodb_autoinc_lock_mode", Scope: sql.GetMysqlScope(sql.SystemVariableScope_Global),