diff --git a/enginetest/evaluation.go b/enginetest/evaluation.go index bfc799d4d2..772ffa7792 100644 --- a/enginetest/evaluation.go +++ b/enginetest/evaluation.go @@ -852,7 +852,7 @@ func WidenRow(t *testing.T, sch sql.Schema, row sql.Row) sql.Row { widened := make(sql.Row, len(row)) for i, v := range row { if i < len(sch) && types.IsJSON(sch[i].Type) { - widened[i] = widenJSONValues(v) + widened[i] = widenJSONValues(t.Context(), v) continue } @@ -893,7 +893,7 @@ func widenValue(t *testing.T, v interface{}) (vw interface{}) { return vw } -func widenJSONValues(val interface{}) sql.JSONWrapper { +func widenJSONValues(ctx context.Context, val interface{}) sql.JSONWrapper { if val == nil { return nil } @@ -907,7 +907,7 @@ func widenJSONValues(val interface{}) sql.JSONWrapper { js = types.MustJSON(str) } - doc, err := js.ToInterface() + doc, err := js.ToInterface(ctx) if err != nil { panic(err) } diff --git a/sql/core.go b/sql/core.go index 235e744d51..6ed2246f83 100644 --- a/sql/core.go +++ b/sql/core.go @@ -15,6 +15,7 @@ package sql import ( + "context" "encoding/json" "fmt" trace2 "runtime/trace" @@ -324,7 +325,7 @@ func ConvertToBool(ctx *Context, v interface{}) (bool, error) { } } -func ConvertToVector(v interface{}) ([]float64, error) { +func ConvertToVector(ctx context.Context, v interface{}) ([]float64, error) { switch b := v.(type) { case []float64: return b, nil @@ -336,7 +337,7 @@ func ConvertToVector(v interface{}) ([]float64, error) { } return convertJsonInterfaceToVector(val) case JSONWrapper: - val, err := b.ToInterface() + val, err := b.ToInterface(ctx) if err != nil { return nil, err } diff --git a/sql/expression/function/aggregation/json_agg.go b/sql/expression/function/aggregation/json_agg.go index 6d4e49b8d6..3f3505c85f 100644 --- a/sql/expression/function/aggregation/json_agg.go +++ b/sql/expression/function/aggregation/json_agg.go @@ -163,7 +163,7 @@ func (j *jsonObjectBuffer) Update(ctx *sql.Context, row sql.Row) error { return err } if js, ok := val.(sql.JSONWrapper); ok { - val, err = js.ToInterface() + val, err = js.ToInterface(ctx) if err != nil { return err } diff --git a/sql/expression/function/aggregation/unary_agg_buffers.go b/sql/expression/function/aggregation/unary_agg_buffers.go index c484b5321a..28b8aeb1ae 100644 --- a/sql/expression/function/aggregation/unary_agg_buffers.go +++ b/sql/expression/function/aggregation/unary_agg_buffers.go @@ -655,7 +655,7 @@ func (j *jsonArrayBuffer) Update(ctx *sql.Context, row sql.Row) error { return err } if js, ok := v.(sql.JSONWrapper); ok { - v, err = js.ToInterface() + v, err = js.ToInterface(ctx) if err != nil { return err } diff --git a/sql/expression/function/aggregation/window_functions.go b/sql/expression/function/aggregation/window_functions.go index 76a3ff5622..c0af914abc 100644 --- a/sql/expression/function/aggregation/window_functions.go +++ b/sql/expression/function/aggregation/window_functions.go @@ -1055,7 +1055,7 @@ func (a *WindowedJSONArrayAgg) aggregateVals(ctx *sql.Context, interval sql.Wind return nil, err } if js, ok := v.(sql.JSONWrapper); ok { - v, err = js.ToInterface() + v, err = js.ToInterface(ctx) if err != nil { return nil, err } @@ -1146,7 +1146,7 @@ func (a *WindowedJSONObjectAgg) aggregateVals(ctx *sql.Context, interval sql.Win return nil, err } if js, ok := val.(sql.JSONWrapper); ok { - val, err = js.ToInterface() + val, err = js.ToInterface(ctx) if err != nil { return nil, err } diff --git a/sql/expression/function/json/json_array.go b/sql/expression/function/json/json_array.go index ab0fb49bb9..3d321dded4 100644 --- a/sql/expression/function/json/json_array.go +++ b/sql/expression/function/json/json_array.go @@ -113,7 +113,7 @@ func (j *JSONArray) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { switch v := val.(type) { case sql.JSONWrapper: - val, err = v.ToInterface() + val, err = v.ToInterface(ctx) if err != nil { return nil, err } diff --git a/sql/expression/function/json/json_array_append.go b/sql/expression/function/json/json_array_append.go index daba7ef3bb..4a13557496 100644 --- a/sql/expression/function/json/json_array_append.go +++ b/sql/expression/function/json/json_array_append.go @@ -88,7 +88,7 @@ func (j JSONArrayAppend) Eval(ctx *sql.Context, row sql.Row) (interface{}, error // Apply the path-value pairs to the document. for _, pair := range pairs { - doc, _, err = doc.ArrayAppend(pair.path, pair.val) + doc, _, err = doc.ArrayAppend(ctx, pair.path, pair.val) if err != nil { return nil, err } diff --git a/sql/expression/function/json/json_array_insert.go b/sql/expression/function/json/json_array_insert.go index 74623da9b1..52e7a99430 100644 --- a/sql/expression/function/json/json_array_insert.go +++ b/sql/expression/function/json/json_array_insert.go @@ -89,7 +89,7 @@ func (j JSONArrayInsert) Eval(ctx *sql.Context, row sql.Row) (interface{}, error // Apply the path-value pairs to the document. for _, pair := range pairs { - doc, _, err = doc.ArrayInsert(pair.path, pair.val) + doc, _, err = doc.ArrayInsert(ctx, pair.path, pair.val) if err != nil { return nil, err } diff --git a/sql/expression/function/json/json_common.go b/sql/expression/function/json/json_common.go index 72d249a7b8..14a800aca4 100644 --- a/sql/expression/function/json/json_common.go +++ b/sql/expression/function/json/json_common.go @@ -102,7 +102,7 @@ func MutableJsonDoc(ctx context.Context, wrapper sql.JSONWrapper) (types.Mutable return mutable, nil } - val, err := clonedJsonWrapper.ToInterface() + val, err := clonedJsonWrapper.ToInterface(ctx) if err != nil { return nil, err } diff --git a/sql/expression/function/json/json_contains.go b/sql/expression/function/json/json_contains.go index 38683036a6..2faab384df 100644 --- a/sql/expression/function/json/json_contains.go +++ b/sql/expression/function/json/json_contains.go @@ -147,7 +147,7 @@ func (j *JSONContains) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) return nil, err } - extracted, err := types.LookupJSONValue(target, path.(string)) + extracted, err := types.LookupJSONValue(ctx, target, path.(string)) if err != nil { return nil, err } @@ -162,11 +162,11 @@ func (j *JSONContains) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) } // Now determine whether the candidate value exists in the target - targetVal, err := target.ToInterface() + targetVal, err := target.ToInterface(ctx) if err != nil { return nil, err } - candidateVal, err := candidate.ToInterface() + candidateVal, err := candidate.ToInterface(ctx) if err != nil { return nil, err } diff --git a/sql/expression/function/json/json_contains_path.go b/sql/expression/function/json/json_contains_path.go index 3a59b71d36..4f79fe4ca0 100644 --- a/sql/expression/function/json/json_contains_path.go +++ b/sql/expression/function/json/json_contains_path.go @@ -76,7 +76,7 @@ func (j JSONContainsPath) Eval(ctx *sql.Context, row sql.Row) (interface{}, erro return nil, err } - result, err := types.LookupJSONValue(target, path.(string)) + result, err := types.LookupJSONValue(ctx, target, path.(string)) if err != nil { return nil, err } diff --git a/sql/expression/function/json/json_depth.go b/sql/expression/function/json/json_depth.go index f29ede00c7..604ca003c7 100644 --- a/sql/expression/function/json/json_depth.go +++ b/sql/expression/function/json/json_depth.go @@ -114,7 +114,7 @@ func (j *JSONDepth) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, nil } - val, err := doc.ToInterface() + val, err := doc.ToInterface(ctx) if err != nil { return nil, err } diff --git a/sql/expression/function/json/json_extract.go b/sql/expression/function/json/json_extract.go index 065d417a22..3638271f05 100644 --- a/sql/expression/function/json/json_extract.go +++ b/sql/expression/function/json/json_extract.go @@ -106,7 +106,7 @@ func (j *JSONExtract) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, nil } - results[i], err = types.LookupJSONValue(searchable, path.(string)) + results[i], err = types.LookupJSONValue(ctx, searchable, path.(string)) if err != nil { return nil, fmt.Errorf("failed to extract from expression '%s'; %s", j.JSON.String(), err.Error()) } diff --git a/sql/expression/function/json/json_keys.go b/sql/expression/function/json/json_keys.go index 233efae1d1..b7101451cc 100644 --- a/sql/expression/function/json/json_keys.go +++ b/sql/expression/function/json/json_keys.go @@ -105,7 +105,7 @@ func (j *JSONKeys) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, nil } - js, err := types.LookupJSONValue(doc, *path) + js, err := types.LookupJSONValue(ctx, doc, *path) if err != nil { if errors.Is(err, jsonpath.ErrKeyError) { return nil, nil @@ -117,7 +117,7 @@ func (j *JSONKeys) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, nil } - val, err := js.ToInterface() + val, err := js.ToInterface(ctx) if err != nil { return nil, err } diff --git a/sql/expression/function/json/json_length.go b/sql/expression/function/json/json_length.go index 35a27869f6..8c39c4cef1 100644 --- a/sql/expression/function/json/json_length.go +++ b/sql/expression/function/json/json_length.go @@ -95,7 +95,7 @@ func (j *JsonLength) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, strErr } - res, err := types.LookupJSONValue(doc, path) + res, err := types.LookupJSONValue(ctx, doc, path) if err != nil { return nil, err } @@ -104,7 +104,7 @@ func (j *JsonLength) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, nil } - val, err := res.ToInterface() + val, err := res.ToInterface(ctx) if err != nil { return nil, err } diff --git a/sql/expression/function/json/json_merge_patch.go b/sql/expression/function/json/json_merge_patch.go index 86537814e1..9ca8b58730 100644 --- a/sql/expression/function/json/json_merge_patch.go +++ b/sql/expression/function/json/json_merge_patch.go @@ -116,7 +116,7 @@ func (j *JSONMergePatch) Eval(ctx *sql.Context, row sql.Row) (interface{}, error return nil, nil } - val, err := initDoc.ToInterface() + val, err := initDoc.ToInterface(ctx) if err != nil { return nil, err } @@ -131,7 +131,7 @@ func (j *JSONMergePatch) Eval(ctx *sql.Context, row sql.Row) (interface{}, error if doc == nil { return nil, nil } - val, err = doc.ToInterface() + val, err = doc.ToInterface(ctx) if err != nil { return nil, err } diff --git a/sql/expression/function/json/json_merge_preserve.go b/sql/expression/function/json/json_merge_preserve.go index 64d9861a5e..77aa713c4e 100644 --- a/sql/expression/function/json/json_merge_preserve.go +++ b/sql/expression/function/json/json_merge_preserve.go @@ -124,7 +124,7 @@ func (j *JSONMergePreserve) Eval(ctx *sql.Context, row sql.Row) (interface{}, er return nil, nil } - val, err := initDoc.ToInterface() + val, err := initDoc.ToInterface(ctx) if err != nil { return nil, err } @@ -138,7 +138,7 @@ func (j *JSONMergePreserve) Eval(ctx *sql.Context, row sql.Row) (interface{}, er if doc == nil { return nil, nil } - val, err = doc.ToInterface() + val, err = doc.ToInterface(ctx) if err != nil { return nil, err } diff --git a/sql/expression/function/json/json_object.go b/sql/expression/function/json/json_object.go index 42f08dfd46..0cdd1e1ae8 100644 --- a/sql/expression/function/json/json_object.go +++ b/sql/expression/function/json/json_object.go @@ -116,7 +116,7 @@ func (j JSONObject) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, err } if json, ok := val.(sql.JSONWrapper); ok { - val, err = json.ToInterface() + val, err = json.ToInterface(ctx) if err != nil { return nil, err } diff --git a/sql/expression/function/json/json_overlaps.go b/sql/expression/function/json/json_overlaps.go index 8e3b58e5df..3815158b37 100644 --- a/sql/expression/function/json/json_overlaps.go +++ b/sql/expression/function/json/json_overlaps.go @@ -190,7 +190,7 @@ func (j *JSONOverlaps) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) if left == nil { return nil, nil } - leftVal, err := left.ToInterface() + leftVal, err := left.ToInterface(ctx) if err != nil { return nil, err } @@ -202,7 +202,7 @@ func (j *JSONOverlaps) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) if right == nil { return nil, nil } - rightVal, err := right.ToInterface() + rightVal, err := right.ToInterface(ctx) if err != nil { return nil, err } diff --git a/sql/expression/function/json/json_pretty.go b/sql/expression/function/json/json_pretty.go index f7f1a24070..627dde26ce 100644 --- a/sql/expression/function/json/json_pretty.go +++ b/sql/expression/function/json/json_pretty.go @@ -83,7 +83,7 @@ func (j *JSONPretty) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { if doc == nil { return nil, nil } - val, err := doc.ToInterface() + val, err := doc.ToInterface(ctx) if err != nil { return nil, err } diff --git a/sql/expression/function/json/json_search.go b/sql/expression/function/json/json_search.go index 660fe257ba..15f81d5643 100644 --- a/sql/expression/function/json/json_search.go +++ b/sql/expression/function/json/json_search.go @@ -295,7 +295,7 @@ func (j *JSONSearch) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { } } - val, err := doc.ToInterface() + val, err := doc.ToInterface(ctx) if err != nil { return nil, err } diff --git a/sql/expression/function/json/json_type.go b/sql/expression/function/json/json_type.go index 0b5841e293..b4ecbb9886 100644 --- a/sql/expression/function/json/json_type.go +++ b/sql/expression/function/json/json_type.go @@ -109,7 +109,7 @@ func (j JSONType) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return comparableDoc.Type(ctx) } - val, err := doc.ToInterface() + val, err := doc.ToInterface(ctx) if err != nil { return nil, err } diff --git a/sql/expression/function/json/json_value.go b/sql/expression/function/json/json_value.go index af051115e0..6bddbfe42b 100644 --- a/sql/expression/function/json/json_value.go +++ b/sql/expression/function/json/json_value.go @@ -95,7 +95,7 @@ func (j *JsonValue) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { } // json NULLs also result in sql NULLs. - cmp, err := types.CompareJSON(js, types.JSONDocument{Val: nil}) + cmp, err := types.CompareJSON(ctx, js, types.JSONDocument{Val: nil}) if cmp == 0 { return nil, nil } @@ -111,7 +111,7 @@ func (j *JsonValue) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { } var res interface{} - res, err = types.LookupJSONValue(searchable, path.(string)) + res, err = types.LookupJSONValue(ctx, searchable, path.(string)) if err != nil || res == nil { return nil, err } @@ -120,7 +120,7 @@ func (j *JsonValue) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { // bad lookups on arrays, instead of an error. Note that this will cause lookups that expect [] to return incorrect // results. // See https://github.com/dolthub/dolt/issues/7905 for more information. - cmp, err = types.CompareJSON(res, types.JSONDocument{Val: []interface{}{}}) + cmp, err = types.CompareJSON(ctx, res, types.JSONDocument{Val: []interface{}{}}) if err != nil { return nil, err } @@ -185,7 +185,7 @@ func GetJSONFromWrapperOrCoercibleString(ctx *sql.Context, js interface{}, funct } return jsonData, nil case sql.JSONWrapper: - return jsType.ToInterface() + return jsType.ToInterface(ctx) default: return nil, sql.ErrInvalidJSONArgument.New(argumentPosition, functionName) } diff --git a/sql/expression/function/json/jsontests/json_extract_tests.go b/sql/expression/function/json/jsontests/json_extract_tests.go index 58e74a96ca..74ba522286 100644 --- a/sql/expression/function/json/jsontests/json_extract_tests.go +++ b/sql/expression/function/json/jsontests/json_extract_tests.go @@ -15,6 +15,7 @@ package jsontests import ( + "context" "testing" "github.com/stretchr/testify/require" @@ -67,7 +68,7 @@ func JsonExtractTestCases(t *testing.T, prepare prepareJsonValue) []testCase { }} // Workaround for https://github.com/dolthub/dolt/issues/7998 // Otherwise, converting this to a string will create invalid JSON - jsonBytes, err := types.MarshallJson(jsonDocument) + jsonBytes, err := types.MarshallJson(context.Background(), jsonDocument) require.NoError(t, err) jsonInput := prepare(t, jsonBytes) diff --git a/sql/expression/function/json/jsontests/json_function_tests.go b/sql/expression/function/json/jsontests/json_function_tests.go index 064cd61036..10646974ee 100644 --- a/sql/expression/function/json/jsontests/json_function_tests.go +++ b/sql/expression/function/json/jsontests/json_function_tests.go @@ -49,7 +49,7 @@ var jsonFormatTests = []jsonFormatTest{ prepareFunc: func(t *testing.T, js interface{}) interface{} { doc, _, err := types.JSON.Convert(sqlCtx, js) require.NoError(t, err) - val, err := doc.(sql.JSONWrapper).ToInterface() + val, err := doc.(sql.JSONWrapper).ToInterface(t.Context()) require.NoError(t, err) return types.JSONDocument{Val: val} }, @@ -59,7 +59,7 @@ var jsonFormatTests = []jsonFormatTest{ prepareFunc: func(t *testing.T, js interface{}) interface{} { doc, _, err := types.JSON.Convert(sqlCtx, js) require.NoError(t, err) - bytes, err := types.MarshallJson(doc.(sql.JSONWrapper)) + bytes, err := types.MarshallJson(sqlCtx, doc.(sql.JSONWrapper)) require.NoError(t, err) return types.NewLazyJSONDocument(bytes) }, diff --git a/sql/expression/function/vector/distance.go b/sql/expression/function/vector/distance.go index f1d302cc8f..26f3f23a88 100644 --- a/sql/expression/function/vector/distance.go +++ b/sql/expression/function/vector/distance.go @@ -15,6 +15,7 @@ package vector import ( + "context" "fmt" "github.com/dolthub/go-mysql-server/sql" @@ -127,18 +128,18 @@ func (d Distance) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, nil } - return MeasureDistance(lval, rval, d.DistanceType) + return MeasureDistance(ctx, lval, rval, d.DistanceType) } -func MeasureDistance(left, right interface{}, distanceType DistanceType) (interface{}, error) { - leftVec, err := sql.ConvertToVector(left) +func MeasureDistance(ctx context.Context, left, right interface{}, distanceType DistanceType) (interface{}, error) { + leftVec, err := sql.ConvertToVector(ctx, left) if err != nil { return nil, err } if leftVec == nil { return nil, nil } - rightVec, err := sql.ConvertToVector(right) + rightVec, err := sql.ConvertToVector(ctx, right) if err != nil { return nil, err } diff --git a/sql/fulltext/fulltext.go b/sql/fulltext/fulltext.go index 60cea459c2..06f5d94602 100644 --- a/sql/fulltext/fulltext.go +++ b/sql/fulltext/fulltext.go @@ -173,7 +173,7 @@ func writeHashedValue(ctx context.Context, h hash.Hash, val interface{}) (valIsN return false, err } case sql.JSONWrapper: - str, err := types.JsonToMySqlString(val) + str, err := types.JsonToMySqlString(ctx, val) if err != nil { return false, err } diff --git a/sql/plan/histogram.go b/sql/plan/histogram.go index 3ee1b6dcc9..7dce2443d0 100644 --- a/sql/plan/histogram.go +++ b/sql/plan/histogram.go @@ -1,6 +1,7 @@ package plan import ( + "context" "fmt" "strings" @@ -59,7 +60,7 @@ func (u *UpdateHistogram) Resolved() bool { } func (u *UpdateHistogram) String() string { - statBytes, _ := types.MarshallJson(u.stats) + statBytes, _ := types.MarshallJson(context.TODO(), u.stats) return fmt.Sprintf("update histogram %s.(%s) using %s", u.table, strings.Join(u.cols, ","), statBytes) } diff --git a/sql/statistics.go b/sql/statistics.go index 18506a71c2..ae1f9fda95 100644 --- a/sql/statistics.go +++ b/sql/statistics.go @@ -173,7 +173,7 @@ func (h Histogram) Clone(context.Context) JSONWrapper { return h } -func (h Histogram) ToInterface() (interface{}, error) { +func (h Histogram) ToInterface(context.Context) (interface{}, error) { ret := make([]interface{}, len(h)) for i, b := range h { var upperBound Row diff --git a/sql/stats/statistic.go b/sql/stats/statistic.go index 3ec787c1f3..c72dc89565 100644 --- a/sql/stats/statistic.go +++ b/sql/stats/statistic.go @@ -226,13 +226,13 @@ func (s *Statistic) Clone(context.Context) sql.JSONWrapper { return s } -func (s *Statistic) ToInterface() (interface{}, error) { +func (s *Statistic) ToInterface(ctx context.Context) (interface{}, error) { typs := make([]string, len(s.Typs)) for i, t := range s.Typs { typs[i] = t.String() } - buckets, err := s.Histogram().ToInterface() + buckets, err := s.Histogram().ToInterface(ctx) if err != nil { return nil, err } diff --git a/sql/types/json.go b/sql/types/json.go index e4f1757145..44cb3434f9 100644 --- a/sql/types/json.go +++ b/sql/types/json.go @@ -38,11 +38,11 @@ var _ sql.CollationCoercible = JsonType{} type JsonType struct{} // Compare implements Type interface. -func (t JsonType) Compare(s context.Context, a interface{}, b interface{}) (int, error) { +func (t JsonType) Compare(ctx context.Context, a interface{}, b interface{}) (int, error) { if hasNulls, res := CompareNulls(a, b); hasNulls { return res, nil } - return CompareJSON(a, b) + return CompareJSON(ctx, a, b) } // Convert implements Type interface. @@ -137,7 +137,7 @@ func (t JsonType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Va // This is kind of a hack, and it means that reading JSON from tables no longer matches MySQL byte-for-byte. // But its worth it to avoid the round-trip, which can be very slow. if j, ok := v.(JSONBytes); ok { - str, err := MarshallJson(j) + str, err := MarshallJson(ctx, j) if err != nil { return sqltypes.NULL, err } @@ -150,7 +150,7 @@ func (t JsonType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Va } js := jsVal.(sql.JSONWrapper) - str, err := JsonToMySqlString(js) + str, err := JsonToMySqlString(ctx, js) if err != nil { return sqltypes.NULL, err } diff --git a/sql/types/json_value.go b/sql/types/json_value.go index 01a1b97020..3a3405f18d 100644 --- a/sql/types/json_value.go +++ b/sql/types/json_value.go @@ -35,8 +35,8 @@ import ( ) // JsonToMySqlString generates a string representation of a sql.JSONWrapper that is compatible with MySQL's JSON output, including spaces. -func JsonToMySqlString(jsonWrapper sql.JSONWrapper) (string, error) { - val, err := jsonWrapper.ToInterface() +func JsonToMySqlString(ctx context.Context, jsonWrapper sql.JSONWrapper) (string, error) { + val, err := jsonWrapper.ToInterface(ctx) if err != nil { return "", err } @@ -44,8 +44,8 @@ func JsonToMySqlString(jsonWrapper sql.JSONWrapper) (string, error) { } // JsonToMySqlBytes generates a byte slice representation of a sql.JSONWrapper that is compatible with MySQL's JSON output, including spaces. -func JsonToMySqlBytes(jsonWrapper sql.JSONWrapper) ([]byte, error) { - val, err := jsonWrapper.ToInterface() +func JsonToMySqlBytes(ctx context.Context, jsonWrapper sql.JSONWrapper) ([]byte, error) { + val, err := jsonWrapper.ToInterface(ctx) if err != nil { return nil, err } @@ -55,7 +55,7 @@ func JsonToMySqlBytes(jsonWrapper sql.JSONWrapper) ([]byte, error) { // JSONBytes are values which can be represented as JSON. type JSONBytes interface { sql.JSONWrapper - GetBytes() ([]byte, error) + GetBytes(ctx context.Context) ([]byte, error) } func MarshallJsonValue(value interface{}) ([]byte, error) { @@ -74,11 +74,11 @@ func MarshallJsonValue(value interface{}) ([]byte, error) { } // JSONBytes returns or generates a byte array for the JSON representation of the underlying sql.JSONWrapper -func MarshallJson(jsonWrapper sql.JSONWrapper) ([]byte, error) { +func MarshallJson(ctx context.Context, jsonWrapper sql.JSONWrapper) ([]byte, error) { if bytes, ok := jsonWrapper.(JSONBytes); ok { - return bytes.GetBytes() + return bytes.GetBytes(ctx) } - val, err := jsonWrapper.ToInterface() + val, err := jsonWrapper.ToInterface(ctx) if err != nil { return []byte{}, err } @@ -114,11 +114,11 @@ type MutableJSON interface { // Replace the value at the given path with the new value. If the path does not exist, no modification is made. Replace(ctx context.Context, path string, val sql.JSONWrapper) (MutableJSON, bool, error) // ArrayInsert inserts into the array object referenced by the given path. If the path does not exist, no modification is made. - ArrayInsert(path string, val sql.JSONWrapper) (MutableJSON, bool, error) + ArrayInsert(ctx context.Context, path string, val sql.JSONWrapper) (MutableJSON, bool, error) // ArrayAppend appends to an array object referenced by the given path. If the path does not exist, no modification is made, // or if the path exists and is not an array, the element will be converted into an array and the element will be // appended to it. - ArrayAppend(path string, val sql.JSONWrapper) (MutableJSON, bool, error) + ArrayAppend(ctx context.Context, path string, val sql.JSONWrapper) (MutableJSON, bool, error) } type JSONDocument struct { @@ -129,16 +129,16 @@ var _ sql.JSONWrapper = JSONDocument{} var _ MutableJSON = JSONDocument{} var _ SearchableJSON = JSONDocument{} -func (doc JSONDocument) ToInterface() (interface{}, error) { +func (doc JSONDocument) ToInterface(context.Context) (interface{}, error) { return doc.Val, nil } -func (doc JSONDocument) Compare(other sql.JSONWrapper) (int, error) { - otherVal, err := other.ToInterface() +func (doc JSONDocument) Compare(ctx context.Context, other sql.JSONWrapper) (int, error) { + otherVal, err := other.ToInterface(ctx) if err != nil { return 0, err } - return CompareJSON(doc.Val, otherVal) + return CompareJSON(ctx, doc.Val, otherVal) } func (doc JSONDocument) JSONString() (string, error) { @@ -193,29 +193,29 @@ func (j *LazyJSONDocument) Clone(context.Context) sql.JSONWrapper { return NewLazyJSONDocument(j.Bytes) } -func (j *LazyJSONDocument) ToInterface() (interface{}, error) { +func (j *LazyJSONDocument) ToInterface(context.Context) (interface{}, error) { return j.interfaceFunc() } -func (j *LazyJSONDocument) GetBytes() ([]byte, error) { +func (j *LazyJSONDocument) GetBytes(_ context.Context) ([]byte, error) { return j.Bytes, nil } // Value implements driver.Valuer for interoperability with other go libraries func (j *LazyJSONDocument) Value() (driver.Value, error) { - return JsonToMySqlString(j) + return JsonToMySqlString(context.Background(), j) } // LazyJSONDocument implements the fmt.Stringer interface. func (j *LazyJSONDocument) String() string { - s, err := JsonToMySqlString(j) + s, err := JsonToMySqlString(context.Background(), j) if err != nil { return fmt.Sprintf("error while stringifying JSON: %s", err.Error()) } return s } -func LookupJSONValue(j sql.JSONWrapper, path string) (sql.JSONWrapper, error) { +func LookupJSONValue(ctx context.Context, j sql.JSONWrapper, path string) (sql.JSONWrapper, error) { if path == "$" { // Special case the identity operation to handle a nil value for doc.Val return j, nil @@ -226,7 +226,7 @@ func LookupJSONValue(j sql.JSONWrapper, path string) (sql.JSONWrapper, error) { return searchableJson.Lookup(ctx, path) } - r, err := j.ToInterface() + r, err := j.ToInterface(ctx) if err != nil { return nil, err } @@ -301,7 +301,7 @@ func ConcatenateJSONValues(ctx *sql.Context, vals ...sql.JSONWrapper) (sql.JSONW var err error arr := make(JsonArray, len(vals)) for i, v := range vals { - arr[i], err = v.ToInterface() + arr[i], err = v.ToInterface(ctx) if err != nil { return nil, err } @@ -503,7 +503,7 @@ func containsJSONNumber(a float64, b interface{}) (bool, error) { // TODO(andy): BLOB, BIT, OPAQUE, DATETIME, TIME, DATE, INTEGER // // https://dev.mysql.com/doc/refman/8.0/en/json.html#json-comparison -func CompareJSON(a, b interface{}) (int, error) { +func CompareJSON(ctx context.Context, a, b interface{}) (int, error) { var err error if hasNulls, res := CompareNulls(b, a); hasNulls { return res, nil @@ -522,9 +522,9 @@ func CompareJSON(a, b interface{}) (int, error) { case bool: return compareJSONBool(a, b) case JsonArray: - return compareJSONArray(a, b) + return compareJSONArray(ctx, a, b) case JsonObject: - return compareJSONObject(a, b) + return compareJSONObject(ctx, a, b) case string: return compareJSONString(a, b) case int: @@ -554,16 +554,16 @@ func CompareJSON(a, b interface{}) (int, error) { return compareJSONNumber(af, b) case sql.JSONWrapper: if jw, ok := b.(sql.JSONWrapper); ok { - b, err = jw.ToInterface() + b, err = jw.ToInterface(ctx) if err != nil { return 0, err } } - aVal, err := a.ToInterface() + aVal, err := a.ToInterface(ctx) if err != nil { return 0, err } - return CompareJSON(aVal, b) + return CompareJSON(ctx, aVal, b) default: return 0, sql.ErrInvalidType.New(a) } @@ -590,7 +590,7 @@ func compareJSONBool(a bool, b interface{}) (int, error) { } } -func compareJSONArray(a JsonArray, b interface{}) (int, error) { +func compareJSONArray(ctx context.Context, a JsonArray, b interface{}) (int, error) { switch b := b.(type) { case bool: // a is lower precedence @@ -607,7 +607,7 @@ func compareJSONArray(a JsonArray, b interface{}) (int, error) { return 1, nil } - cmp, err := CompareJSON(aa, b[i]) + cmp, err := CompareJSON(ctx, aa, b[i]) if err != nil { return 0, err } @@ -627,7 +627,7 @@ func compareJSONArray(a JsonArray, b interface{}) (int, error) { } } -func compareJSONObject(a JsonObject, b interface{}) (int, error) { +func compareJSONObject(ctx context.Context, a JsonObject, b interface{}) (int, error) { switch b := b.(type) { case bool, @@ -640,7 +640,7 @@ func compareJSONObject(a JsonObject, b interface{}) (int, error) { // objects. The order of two objects that are not equal is unspecified but deterministic. inter := jsonObjectKeyIntersection(a, b) for _, key := range inter { - cmp, err := CompareJSON(a[key], b[key]) + cmp, err := CompareJSON(ctx, a[key], b[key]) if err != nil { return 0, err } @@ -768,9 +768,9 @@ func jsonObjectDeterministicOrder(a, b JsonObject, inter []string) (int, error) return strings.Compare(aa, bb), nil } -func (doc JSONDocument) Insert(_ context.Context, path string, val sql.JSONWrapper) (MutableJSON, bool, error) { +func (doc JSONDocument) Insert(ctx context.Context, path string, val sql.JSONWrapper) (MutableJSON, bool, error) { path = strings.TrimSpace(path) - return doc.unwrapAndExecute(path, val, INSERT) + return doc.unwrapAndExecute(ctx, path, val, INSERT) } func (doc JSONDocument) Remove(ctx context.Context, path string) (MutableJSON, bool, error) { @@ -779,25 +779,25 @@ func (doc JSONDocument) Remove(ctx context.Context, path string) (MutableJSON, b return nil, false, fmt.Errorf("The path expression '$' is not allowed in this context.") } - return doc.unwrapAndExecute(path, nil, REMOVE) + return doc.unwrapAndExecute(ctx, path, nil, REMOVE) } func (doc JSONDocument) Set(ctx context.Context, path string, val sql.JSONWrapper) (MutableJSON, bool, error) { path = strings.TrimSpace(path) - return doc.unwrapAndExecute(path, val, SET) + return doc.unwrapAndExecute(ctx, path, val, SET) } func (doc JSONDocument) Replace(ctx context.Context, path string, val sql.JSONWrapper) (MutableJSON, bool, error) { path = strings.TrimSpace(path) - return doc.unwrapAndExecute(path, val, REPLACE) + return doc.unwrapAndExecute(ctx, path, val, REPLACE) } -func (doc JSONDocument) ArrayAppend(path string, val sql.JSONWrapper) (MutableJSON, bool, error) { +func (doc JSONDocument) ArrayAppend(ctx context.Context, path string, val sql.JSONWrapper) (MutableJSON, bool, error) { path = strings.TrimSpace(path) - return doc.unwrapAndExecute(path, val, ARRAY_APPEND) + return doc.unwrapAndExecute(ctx, path, val, ARRAY_APPEND) } -func (doc JSONDocument) ArrayInsert(path string, val sql.JSONWrapper) (MutableJSON, bool, error) { +func (doc JSONDocument) ArrayInsert(ctx context.Context, path string, val sql.JSONWrapper) (MutableJSON, bool, error) { path = strings.TrimSpace(path) if path == "$" { @@ -805,7 +805,7 @@ func (doc JSONDocument) ArrayInsert(path string, val sql.JSONWrapper) (MutableJS return nil, false, fmt.Errorf("Path expression is not a path to a cell in an array: $") } - return doc.unwrapAndExecute(path, val, ARRAY_INSERT) + return doc.unwrapAndExecute(ctx, path, val, ARRAY_INSERT) } const ( @@ -819,7 +819,7 @@ const ( // unwrapAndExecute unwraps the JSONDocument and executes the given path on the unwrapped value. The path string passed // in at this point should be unmodified. -func (doc JSONDocument) unwrapAndExecute(path string, val sql.JSONWrapper, mode int) (MutableJSON, bool, error) { +func (doc JSONDocument) unwrapAndExecute(ctx context.Context, path string, val sql.JSONWrapper, mode int) (MutableJSON, bool, error) { if path == "" { return nil, false, fmt.Errorf("Invalid JSON path expression. Empty path") } @@ -827,7 +827,7 @@ func (doc JSONDocument) unwrapAndExecute(path string, val sql.JSONWrapper, mode var err error var unmarshalled interface{} if val != nil { - unmarshalled, err = val.ToInterface() + unmarshalled, err = val.ToInterface(ctx) if err != nil { return nil, false, err } diff --git a/sql/types/jsontests/json_test.go b/sql/types/jsontests/json_test.go index 5a068bf263..e83aa4005d 100644 --- a/sql/types/jsontests/json_test.go +++ b/sql/types/jsontests/json_test.go @@ -144,14 +144,14 @@ func TestLazyJsonDocument(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.s, func(t *testing.T) { doc := types.NewLazyJSONDocument([]byte(testCase.s)) - val, err := doc.ToInterface() + val, err := doc.ToInterface(context.Background()) require.NoError(t, err) require.Equal(t, testCase.json, val) }) } t.Run("lazy docs only error when deserialized", func(t *testing.T) { doc := types.NewLazyJSONDocument([]byte("not valid json")) - _, err := doc.ToInterface() + _, err := doc.ToInterface(context.Background()) require.Error(t, err) }) } @@ -366,7 +366,7 @@ func TestJsonInsertErrors(t *testing.T) { for _, test := range JsonArrayInsertErrors { t.Run("JSON Path: "+test.desc, func(t *testing.T) { - _, changed, err := doc.ArrayInsert(test.path, types.MustJSON(`{"a": 42}`)) + _, changed, err := doc.ArrayInsert(t.Context(), test.path, types.MustJSON(`{"a": 42}`)) assert.Equal(t, false, changed) require.Error(t, err) assert.Equal(t, test.expectErrStr, err.Error()) diff --git a/sql/types/jsontests/json_tests.go b/sql/types/jsontests/json_tests.go index 22b0da427e..d455c15e74 100644 --- a/sql/types/jsontests/json_tests.go +++ b/sql/types/jsontests/json_tests.go @@ -36,7 +36,7 @@ func ConvertToJson(t *testing.T, val interface{}) types.MutableJSON { require.NoError(t, err) require.True(t, bool(inRange)) require.Implements(t, (*sql.JSONWrapper)(nil), val) - val, err = val.(sql.JSONWrapper).ToInterface() + val, err = val.(sql.JSONWrapper).ToInterface(t.Context()) require.NoError(t, err) return types.JSONDocument{Val: val} } @@ -994,17 +994,17 @@ func RunJsonMutationTests(ctx context.Context, t *testing.T, tests []JsonMutatio case "replace": return doc.Replace(ctx, test.path, val) case "arrayappend": - return doc.ArrayAppend(test.path, val) + return doc.ArrayAppend(ctx, test.path, val) case "arrayinsert": - return doc.ArrayInsert(test.path, val) + return doc.ArrayInsert(ctx, test.path, val) default: panic("unexpected operation for test") } }() require.NoError(t, err) - expected, err := result.ToInterface() + expected, err := result.ToInterface(ctx) require.NoError(t, err) - actual, err := res.ToInterface() + actual, err := res.ToInterface(ctx) require.NoError(t, err) assert.Equal(t, expected, actual) assert.Equal(t, test.changed, changed) diff --git a/sql/types/number.go b/sql/types/number.go index c346efbfc9..b6152a800d 100644 --- a/sql/types/number.go +++ b/sql/types/number.go @@ -243,7 +243,7 @@ func (t NumberTypeImpl_) Compare(s context.Context, a interface{}, b interface{} } // Convert implements Type interface. -func (t NumberTypeImpl_) Convert(c context.Context, v interface{}) (interface{}, sql.ConvertInRange, error) { +func (t NumberTypeImpl_) Convert(ctx context.Context, v interface{}) (interface{}, sql.ConvertInRange, error) { var err error if v == nil { return nil, sql.InRange, nil @@ -254,7 +254,7 @@ func (t NumberTypeImpl_) Convert(c context.Context, v interface{}) (interface{}, } if jv, ok := v.(sql.JSONWrapper); ok { - v, err = jv.ToInterface() + v, err = jv.ToInterface(ctx) if err != nil { return nil, sql.OutOfRange, err } @@ -559,7 +559,7 @@ func (t NumberTypeImpl_) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqlt var err error if jv, ok := v.(sql.JSONWrapper); ok { - v, err = jv.ToInterface() + v, err = jv.ToInterface(ctx) if err != nil { return sqltypes.Value{}, err } diff --git a/sql/types/strings.go b/sql/types/strings.go index 0c17c69b67..50d8f07bb6 100644 --- a/sql/types/strings.go +++ b/sql/types/strings.go @@ -425,7 +425,7 @@ func ConvertToBytes(ctx context.Context, v interface{}, t sql.StringType, dest [ val = append(dest, s.Decimal.String()...) case sql.JSONWrapper: var err error - val, err = JsonToMySqlBytes(s) + val, err = JsonToMySqlBytes(ctx, s) if err != nil { return nil, err } @@ -708,7 +708,7 @@ func (t StringType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes. var valueBytes []byte switch v := v.(type) { case JSONBytes: - valueBytes, err = v.GetBytes() + valueBytes, err = v.GetBytes(ctx) if err != nil { return sqltypes.Value{}, err } diff --git a/sql/wrapper.go b/sql/wrapper.go index c914dda106..0d083e155a 100644 --- a/sql/wrapper.go +++ b/sql/wrapper.go @@ -91,5 +91,5 @@ type JSONWrapper interface { // Clone creates a new value that can be mutated without affecting the original. Clone(ctx context.Context) JSONWrapper // ToInterface converts a JSONWrapper to an interface{} of simple types - ToInterface() (interface{}, error) + ToInterface(ctx context.Context) (interface{}, error) }