Skip to content

Commit 00fc297

Browse files
committed
Propogate context parameter.
1 parent 5585f96 commit 00fc297

20 files changed

+102
-108
lines changed

enginetest/enginetests.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6043,15 +6043,16 @@ func findRole(toUser string, roles []*mysql_db.RoleEdge) *mysql_db.RoleEdge {
60436043
}
60446044

60456045
func TestBlobs(t *testing.T, h Harness) {
6046+
ctx := sql.NewEmptyContext()
60466047
h.Setup(setup.MydbData, setup.BlobData, setup.MytableData)
60476048

60486049
// By default, strict_mysql_compatibility is disabled, but these tests require it to be enabled.
6049-
err := sql.SystemVariables.SetGlobal("strict_mysql_compatibility", int8(1))
6050+
err := sql.SystemVariables.SetGlobal(ctx, "strict_mysql_compatibility", int8(1))
60506051
require.NoError(t, err)
60516052
for _, tt := range queries.BlobErrors {
60526053
runQueryErrorTest(t, h, tt)
60536054
}
6054-
err = sql.SystemVariables.SetGlobal("strict_mysql_compatibility", int8(0))
6055+
err = sql.SystemVariables.SetGlobal(ctx, "strict_mysql_compatibility", int8(0))
60556056
require.NoError(t, err)
60566057

60576058
e := mustNewEngine(t, h)

enginetest/histogram_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ func TestMultiDist(t *testing.T) {
236236
// the stats join algo to simulate a join estimate, and (4) compare the
237237
// estimate to the actual result set count.
238238
func runStatsSuite(t *testing.T, tests []statsTest, rowCnt, bucketCnt int, debug bool) {
239+
ctx := sql.NewEmptyContext()
239240
for i, tt := range tests {
240241
t.Run(fmt.Sprintf("%s: , rows: %d, buckets: %d", tt.name, rowCnt, bucketCnt), func(t *testing.T) {
241242
db := memory.NewDatabase(fmt.Sprintf("test%d", i))
@@ -275,7 +276,7 @@ func runStatsSuite(t *testing.T, tests []statsTest, rowCnt, bucketCnt int, debug
275276
rStat.Hist = append(rStat.Hist, b.(*stats.Bucket))
276277
}
277278

278-
res, err := stats.Join(stats.UpdateCounts(lStat), stats.UpdateCounts(rStat), 1, debug)
279+
res, err := stats.Join(ctx, stats.UpdateCounts(lStat), stats.UpdateCounts(rStat), 1, debug)
279280
require.NoError(t, err)
280281
if debug {
281282
log.Printf("join %s\n", res.Histogram().DebugString())

processlist_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ func TestSlowQueryTracking(t *testing.T) {
242242
require.NoError(t, err)
243243

244244
// Change @@long_query_time so we don't have to wait for 10 seconds
245-
require.NoError(t, sql.SystemVariables.SetGlobal("long_query_time", 1))
245+
require.NoError(t, sql.SystemVariables.SetGlobal(ctx, "long_query_time", 1))
246246
time.Sleep(1_500 * time.Millisecond)
247247
p.EndQuery(ctx)
248248

sql/analyzer/costed_index_scan.go

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,7 @@ func (c *indexCoster) cost(f indexFilter, stat sql.Statistic, idx sql.Index) err
462462

463463
switch f := f.(type) {
464464
case *iScanAnd:
465-
newHist, newFds, filters, prefix, err = c.costIndexScanAnd(f, stat, stat.Histogram(), ordinals, idx)
465+
newHist, newFds, filters, prefix, err = c.costIndexScanAnd(c.ctx, f, stat, stat.Histogram(), ordinals, idx)
466466
if err != nil {
467467
return err
468468
}
@@ -1199,7 +1199,7 @@ func ordinalsForStat(stat sql.Statistic) map[string]int {
11991199
// updated statistic, the subset of applicable filters, the maximum prefix
12001200
// key created by a subset of equality filters (from conjunction only),
12011201
// or an error if applicable.
1202-
func (c *indexCoster) costIndexScanAnd(filter *iScanAnd, s sql.Statistic, buckets []sql.HistogramBucket, ordinals map[string]int, idx sql.Index) ([]sql.HistogramBucket, *sql.FuncDepSet, sql.FastIntSet, int, error) {
1202+
func (c *indexCoster) costIndexScanAnd(ctx *sql.Context, filter *iScanAnd, s sql.Statistic, buckets []sql.HistogramBucket, ordinals map[string]int, idx sql.Index) ([]sql.HistogramBucket, *sql.FuncDepSet, sql.FastIntSet, int, error) {
12031203
// first step finds the conjunctions that match index prefix columns.
12041204
// we divide into eqFilters and rangeFilters
12051205

@@ -1214,7 +1214,7 @@ func (c *indexCoster) costIndexScanAnd(filter *iScanAnd, s sql.Statistic, bucket
12141214
}
12151215
// if valid, INTERSECT
12161216
if ok {
1217-
ret, err = stats.Intersect(ret, childStat, s.Types())
1217+
ret, err = stats.Intersect(c.ctx, ret, childStat, s.Types())
12181218
if err != nil {
12191219
return nil, nil, sql.FastIntSet{}, 0, err
12201220
}
@@ -1227,7 +1227,7 @@ func (c *indexCoster) costIndexScanAnd(filter *iScanAnd, s sql.Statistic, bucket
12271227
for _, c := range s.Columns() {
12281228
if colFilters, ok := filter.leafChildren[c]; ok {
12291229
for _, f := range colFilters {
1230-
conj.add(f)
1230+
conj.add(ctx, f)
12311231
}
12321232
}
12331233
}
@@ -1253,15 +1253,15 @@ func (c *indexCoster) costIndexScanOr(filter *iScanOr, s sql.Statistic, buckets
12531253
for _, child := range filter.children {
12541254
switch child := child.(type) {
12551255
case *iScanAnd:
1256-
childBuckets, _, ids, _, err := c.costIndexScanAnd(child, s, buckets, ordinals, idx)
1256+
childBuckets, _, ids, _, err := c.costIndexScanAnd(c.ctx, child, s, buckets, ordinals, idx)
12571257
if err != nil {
12581258
return nil, nil, false, err
12591259
}
12601260
if ids.Len() != 1 || !ids.Contains(int(child.Id())) {
12611261
// scan option missed some filters
12621262
return nil, nil, false, nil
12631263
}
1264-
ret, err = stats.Union(buckets, childBuckets, s.Types())
1264+
ret, err = stats.Union(c.ctx, buckets, childBuckets, s.Types())
12651265
if err != nil {
12661266
return nil, nil, false, err
12671267
}
@@ -1275,7 +1275,7 @@ func (c *indexCoster) costIndexScanOr(filter *iScanOr, s sql.Statistic, buckets
12751275
if !ok {
12761276
return nil, nil, false, nil
12771277
}
1278-
ret, err = stats.Union(ret, childBuckets, s.Types())
1278+
ret, err = stats.Union(c.ctx, ret, childBuckets, s.Types())
12791279
if err != nil {
12801280
return nil, nil, false, err
12811281
}
@@ -1339,7 +1339,7 @@ func (c *indexCoster) costIndexScanLeaf(filter *iScanLeaf, s sql.Statistic, buck
13391339
return buckets, stat.FuncDeps(), ok, 0, err
13401340
default:
13411341
conj := newConjCollector(s, buckets, ordinals)
1342-
conj.add(filter)
1342+
conj.add(c.ctx, filter)
13431343
var conjFDs *sql.FuncDepSet
13441344
if idx.IsUnique() {
13451345
conjFDs = conj.getFds()
@@ -1670,19 +1670,19 @@ type conjCollector struct {
16701670
isFalse bool
16711671
}
16721672

1673-
func (c *conjCollector) add(f *iScanLeaf) error {
1673+
func (c *conjCollector) add(ctx *sql.Context, f *iScanLeaf) error {
16741674
c.applied.Add(int(f.Id()))
16751675
var err error
16761676
switch f.Op() {
16771677
case IndexScanOpNullSafeEq:
1678-
err = c.addEq(f.gf.Name(), f.litValue, true)
1678+
err = c.addEq(ctx, f.gf.Name(), f.litValue, true)
16791679
case IndexScanOpEq:
1680-
err = c.addEq(f.gf.Name(), f.litValue, false)
1680+
err = c.addEq(ctx, f.gf.Name(), f.litValue, false)
16811681
case IndexScanOpInSet:
16821682
// TODO cost UNION of equals
1683-
err = c.addEq(f.gf.Name(), f.setValues[0], false)
1683+
err = c.addEq(ctx, f.gf.Name(), f.setValues[0], false)
16841684
default:
1685-
err = c.addIneq(f.Op(), f.gf.Name(), f.litValue)
1685+
err = c.addIneq(ctx, f.Op(), f.gf.Name(), f.litValue)
16861686
}
16871687
return err
16881688
}
@@ -1695,7 +1695,7 @@ func (c *conjCollector) getFds() *sql.FuncDepSet {
16951695
return sql.NewLookupFDs(c.stat.FuncDeps(), c.stat.ColSet(), sql.ColSet{}, constCols, nil)
16961696
}
16971697

1698-
func (c *conjCollector) addEq(col string, val interface{}, nullSafe bool) error {
1698+
func (c *conjCollector) addEq(ctx *sql.Context, col string, val interface{}, nullSafe bool) error {
16991699
// make constant
17001700
ord := c.ordinals[col]
17011701
if c.constant.Contains(ord + 1) {
@@ -1722,20 +1722,20 @@ func (c *conjCollector) addEq(col string, val interface{}, nullSafe bool) error
17221722

17231723
// truncate buckets
17241724
var err error
1725-
c.hist, err = stats.PrefixKey(c.stat.Histogram(), c.stat.Types(), c.eqVals[:ord+1])
1725+
c.hist, err = stats.PrefixKey(ctx, c.stat.Histogram(), c.stat.Types(), c.eqVals[:ord+1])
17261726
if err != nil {
17271727
return err
17281728
}
17291729
}
17301730
return nil
17311731
}
17321732

1733-
func (c *conjCollector) addIneq(op IndexScanOp, col string, val interface{}) error {
1733+
func (c *conjCollector) addIneq(ctx *sql.Context, op IndexScanOp, col string, val interface{}) error {
17341734
ord := c.ordinals[col]
17351735
if ord > 0 {
17361736
return nil
17371737
}
1738-
err := c.cmpFirstCol(op, val)
1738+
err := c.cmpFirstCol(ctx, op, val)
17391739
if err != nil {
17401740
return err
17411741
}
@@ -1744,7 +1744,7 @@ func (c *conjCollector) addIneq(op IndexScanOp, col string, val interface{}) err
17441744

17451745
// cmpFirstCol checks whether we should try to range truncate the first
17461746
// column in the index
1747-
func (c *conjCollector) cmpFirstCol(op IndexScanOp, val interface{}) error {
1747+
func (c *conjCollector) cmpFirstCol(ctx *sql.Context, op IndexScanOp, val interface{}) error {
17481748
// check if first col already constant
17491749
// otherwise attempt to truncate histogram
17501750
var err error
@@ -1754,15 +1754,15 @@ func (c *conjCollector) cmpFirstCol(op IndexScanOp, val interface{}) error {
17541754
switch op {
17551755
case IndexScanOpNotEq:
17561756
// todo notEq
1757-
c.hist, err = stats.PrefixGt(c.hist, c.stat.Types(), val)
1757+
c.hist, err = stats.PrefixGt(ctx, c.hist, c.stat.Types(), val)
17581758
case IndexScanOpGt:
1759-
c.hist, err = stats.PrefixGt(c.hist, c.stat.Types(), val)
1759+
c.hist, err = stats.PrefixGt(ctx, c.hist, c.stat.Types(), val)
17601760
case IndexScanOpGte:
1761-
c.hist, err = stats.PrefixGte(c.hist, c.stat.Types(), val)
1761+
c.hist, err = stats.PrefixGte(ctx, c.hist, c.stat.Types(), val)
17621762
case IndexScanOpLt:
1763-
c.hist, err = stats.PrefixLt(c.hist, c.stat.Types(), val)
1763+
c.hist, err = stats.PrefixLt(ctx, c.hist, c.stat.Types(), val)
17641764
case IndexScanOpLte:
1765-
c.hist, err = stats.PrefixLte(c.hist, c.stat.Types(), val)
1765+
c.hist, err = stats.PrefixLte(ctx, c.hist, c.stat.Types(), val)
17661766
case IndexScanOpIsNull:
17671767
c.hist, err = stats.PrefixIsNull(c.hist)
17681768
case IndexScanOpIsNotNull:

sql/analyzer/indexed_joins.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ func replanJoin(ctx *sql.Context, n *plan.JoinNode, a *Analyzer, scope *plan.Sco
185185
return nil, err
186186
}
187187

188-
memo.CardMemoGroups(m.Root())
188+
memo.CardMemoGroups(ctx, m.Root())
189189

190190
err = addCrossHashJoins(m)
191191
if err != nil {

sql/analyzer/validation_rules_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ func TestValidateGroupByErr(t *testing.T) {
152152
plan.NewResolvedTable(child, nil, nil),
153153
)
154154

155-
err = sql.SystemVariables.SetGlobal("sql_mode", "NO_ENGINE_SUBSTITUTION,ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES")
155+
err = sql.SystemVariables.SetGlobal(ctx, "sql_mode", "NO_ENGINE_SUBSTITUTION,ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES")
156156
require.NoError(err)
157157
_, _, err = vr.Apply(ctx, nil, p, nil, DefaultRuleSelector, nil)
158158
require.Error(err)

sql/core.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,7 @@ type SystemVariableRegistry interface {
414414
// GetGlobal returns the current global value of the system variable with the given name
415415
GetGlobal(name string) (SystemVariable, interface{}, bool)
416416
// SetGlobal sets the global value of the system variable with the given name
417-
SetGlobal(name string, val interface{}) error
417+
SetGlobal(ctx *Context, name string, val interface{}) error
418418
// GetAllGlobalVariables returns a copy of all global variable values.
419419
GetAllGlobalVariables() map[string]interface{}
420420
}
@@ -480,7 +480,7 @@ type MysqlSystemVariable struct {
480480
// the global context and in a particular session. They should never
481481
// block. NotifyChanged is not called when a new system variable is
482482
// registered.
483-
NotifyChanged func(SystemVariableScope, SystemVarValue) error
483+
NotifyChanged func(*Context, SystemVariableScope, SystemVarValue) error
484484
// ValueFunction defines an optional function that is executed to provide
485485
// the value of this system variable whenever it is requested. System variables
486486
// that provide a ValueFunction should also set Dynamic to false, since they
@@ -528,7 +528,7 @@ func (m *MysqlSystemVariable) InitValue(ctx *Context, val any, global bool) (Sys
528528
scope = GetMysqlScope(SystemVariableScope_Global)
529529
}
530530
if m.NotifyChanged != nil {
531-
err = m.NotifyChanged(scope, svv)
531+
err = m.NotifyChanged(ctx, scope, svv)
532532
if err != nil {
533533
return SystemVarValue{}, err
534534
}
@@ -595,7 +595,7 @@ func GetMysqlScope(t MysqlSVScopeType) *MysqlScope {
595595
func (m *MysqlScope) SetValue(ctx *Context, name string, val any) error {
596596
switch m.Type {
597597
case SystemVariableScope_Global:
598-
err := SystemVariables.SetGlobal(name, val)
598+
err := SystemVariables.SetGlobal(ctx, name, val)
599599
if err != nil {
600600
return err
601601
}
@@ -613,7 +613,7 @@ func (m *MysqlScope) SetValue(ctx *Context, name string, val any) error {
613613
if err != nil {
614614
return err
615615
}
616-
err = SystemVariables.SetGlobal(name, val)
616+
err = SystemVariables.SetGlobal(ctx, name, val)
617617
if err != nil {
618618
return err
619619
}

sql/expression/arithmetic.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ func (a *Arithmetic) convertLeftRight(ctx *sql.Context, left interface{}, right
359359
if types.IsInteger(typ) || types.IsFloat(typ) || types.IsTime(typ) {
360360
left = convertValueToType(ctx, typ, left, lIsTimeType)
361361
} else {
362-
left = convertToDecimalValue(left, lIsTimeType)
362+
left = convertToDecimalValue(ctx, left, lIsTimeType)
363363
}
364364
}
365365

@@ -370,7 +370,7 @@ func (a *Arithmetic) convertLeftRight(ctx *sql.Context, left interface{}, right
370370
if types.IsInteger(typ) || types.IsFloat(typ) || types.IsTime(typ) {
371371
right = convertValueToType(ctx, typ, right, rIsTimeType)
372372
} else {
373-
right = convertToDecimalValue(right, rIsTimeType)
373+
right = convertToDecimalValue(ctx, right, rIsTimeType)
374374
}
375375
}
376376

sql/expression/comparison.go

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ func (c *comparison) Compare(ctx *sql.Context, row sql.Row) (int, error) {
173173
}
174174
}
175175
if compareType == nil {
176-
left, right, compareType, err = c.castLeftAndRight(left, right)
176+
left, right, compareType, err = c.castLeftAndRight(ctx, left, right)
177177
if err != nil {
178178
return 0, err
179179
}
@@ -201,15 +201,15 @@ func (c *comparison) evalLeftAndRight(ctx *sql.Context, row sql.Row) (interface{
201201
return left, right, nil
202202
}
203203

204-
func (c *comparison) castLeftAndRight(left, right interface{}) (interface{}, interface{}, sql.Type, error) {
204+
func (c *comparison) castLeftAndRight(ctx *sql.Context, left, right interface{}) (interface{}, interface{}, sql.Type, error) {
205205
leftType := c.Left().Type()
206206
rightType := c.Right().Type()
207207
if types.IsTuple(leftType) && types.IsTuple(rightType) {
208208
return left, right, c.Left().Type(), nil
209209
}
210210

211211
if types.IsTime(leftType) || types.IsTime(rightType) {
212-
l, r, err := convertLeftAndRight(left, right, ConvertToDatetime)
212+
l, r, err := convertLeftAndRight(ctx, left, right, ConvertToDatetime)
213213
if err != nil {
214214
return nil, nil, nil, err
215215
}
@@ -223,7 +223,7 @@ func (c *comparison) castLeftAndRight(left, right interface{}) (interface{}, int
223223
}
224224

225225
if types.IsBinaryType(leftType) || types.IsBinaryType(rightType) {
226-
l, r, err := convertLeftAndRight(left, right, ConvertToBinary)
226+
l, r, err := convertLeftAndRight(ctx, left, right, ConvertToBinary)
227227
if err != nil {
228228
return nil, nil, nil, err
229229
}
@@ -233,7 +233,7 @@ func (c *comparison) castLeftAndRight(left, right interface{}) (interface{}, int
233233
if types.IsNumber(leftType) || types.IsNumber(rightType) {
234234
if types.IsDecimal(leftType) || types.IsDecimal(rightType) {
235235
//TODO: We need to set to the actual DECIMAL type
236-
l, r, err := convertLeftAndRight(left, right, ConvertToDecimal)
236+
l, r, err := convertLeftAndRight(ctx, left, right, ConvertToDecimal)
237237
if err != nil {
238238
return nil, nil, nil, err
239239
}
@@ -246,7 +246,7 @@ func (c *comparison) castLeftAndRight(left, right interface{}) (interface{}, int
246246
}
247247

248248
if types.IsFloat(leftType) || types.IsFloat(rightType) {
249-
l, r, err := convertLeftAndRight(left, right, ConvertToDouble)
249+
l, r, err := convertLeftAndRight(ctx, left, right, ConvertToDouble)
250250
if err != nil {
251251
return nil, nil, nil, err
252252
}
@@ -255,7 +255,7 @@ func (c *comparison) castLeftAndRight(left, right interface{}) (interface{}, int
255255
}
256256

257257
if types.IsSigned(leftType) && types.IsSigned(rightType) {
258-
l, r, err := convertLeftAndRight(left, right, ConvertToSigned)
258+
l, r, err := convertLeftAndRight(ctx, left, right, ConvertToSigned)
259259
if err != nil {
260260
return nil, nil, nil, err
261261
}
@@ -264,33 +264,31 @@ func (c *comparison) castLeftAndRight(left, right interface{}) (interface{}, int
264264
}
265265

266266
if types.IsUnsigned(leftType) && types.IsUnsigned(rightType) {
267-
l, r, err := convertLeftAndRight(left, right, ConvertToUnsigned)
267+
l, r, err := convertLeftAndRight(ctx, left, right, ConvertToUnsigned)
268268
if err != nil {
269269
return nil, nil, nil, err
270270
}
271271

272272
return l, r, types.Uint64, nil
273273
}
274274

275-
l, r, err := convertLeftAndRight(left, right, ConvertToDouble)
275+
l, r, err := convertLeftAndRight(ctx, left, right, ConvertToDouble)
276276
if err != nil {
277277
return nil, nil, nil, err
278278
}
279279

280280
return l, r, types.Float64, nil
281281
}
282282

283-
left, right, err := convertLeftAndRight(left, right, ConvertToChar)
283+
left, right, err := convertLeftAndRight(ctx, left, right, ConvertToChar)
284284
if err != nil {
285285
return nil, nil, nil, err
286286
}
287287

288288
return left, right, types.LongText, nil
289289
}
290290

291-
func convertLeftAndRight(left, right interface{}, convertTo string) (interface{}, interface{}, error) {
292-
// TODO: Add context parameter
293-
ctx := sql.NewEmptyContext()
291+
func convertLeftAndRight(ctx *sql.Context, left, right interface{}, convertTo string) (interface{}, interface{}, error) {
294292
l, err := convertValue(ctx, left, convertTo, nil, 0, 0)
295293
if err != nil {
296294
return nil, nil, err
@@ -441,7 +439,7 @@ func (e *NullSafeEquals) Compare(ctx *sql.Context, row sql.Row) (int, error) {
441439
}
442440

443441
var compareType sql.Type
444-
left, right, compareType, err = e.castLeftAndRight(left, right)
442+
left, right, compareType, err = e.castLeftAndRight(ctx, left, right)
445443
if err != nil {
446444
return 0, err
447445
}

0 commit comments

Comments
 (0)