diff --git a/enginetest/queries/json_scripts.go b/enginetest/queries/json_scripts.go index cb981cd1c8..3c0b09744e 100644 --- a/enginetest/queries/json_scripts.go +++ b/enginetest/queries/json_scripts.go @@ -17,11 +17,35 @@ package queries import ( "github.com/dolthub/vitess/go/vt/sqlparser" + "github.com/dolthub/go-mysql-server/sql/plan" + "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/types" ) var JsonScripts = []ScriptTest{ + { + // https://github.com/dolthub/dolt/issues/10050 + Name: "TextStorage converts to JSON when using dolt wrapper", + SetUpScript: []string{ + "CREATE TABLE pages (id INT PRIMARY KEY, text_col TEXT, text_json JSON)", + "INSERT INTO pages VALUES (1, '{\"message\":\"hello\"}', NULL)", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "UPDATE pages SET text_json = text_col", + Expected: []sql.Row{ + {types.OkResult{RowsAffected: 1, Info: plan.UpdateInfo{Matched: 1, Updated: 1}}}, + }, + }, + { + Query: "SELECT text_json FROM pages", + Expected: []sql.Row{ + {types.MustJSON("{\"message\":\"hello\"}")}, + }, + }, + }, + }, { Name: "json_type scripts", Assertions: []ScriptTestAssertion{ diff --git a/sql/types/json.go b/sql/types/json.go index 44cb3434f9..64ba15842c 100644 --- a/sql/types/json.go +++ b/sql/types/json.go @@ -45,29 +45,55 @@ func (t JsonType) Compare(ctx context.Context, a interface{}, b interface{}) (in return CompareJSON(ctx, a, b) } +// convertJSONValue parses JSON-encoded data if the input is a string or []byte, returning the resulting JSONDocument. For +// other types, the value is returned if it can be marshalled. +func convertJSONValue(v interface{}) (interface{}, sql.ConvertInRange, error) { + var data []byte + var charsetMaxLength int64 = 1 + switch x := v.(type) { + case []byte: + data = x + case string: + data = []byte(x) + charsetMaxLength = sql.Collation_Default.CharacterSet().MaxLength() + default: + // if |v| can be marshalled, it contains + // a valid JSON document representation + if b, berr := json.Marshal(v); berr == nil { + data = b + } else { + return nil, sql.InRange, nil + } + } + + if int64(len(data))*charsetMaxLength > MaxJsonFieldByteLength { + return nil, sql.InRange, ErrLengthTooLarge.New(len(data), MaxJsonFieldByteLength) + } + + var val interface{} + if err := json.Unmarshal(data, &val); err != nil { + return nil, sql.OutOfRange, sql.ErrInvalidJson.New(err.Error()) + } + + return JSONDocument{Val: val}, sql.InRange, nil +} + // Convert implements Type interface. -func (t JsonType) Convert(c context.Context, v interface{}) (doc interface{}, inRange sql.ConvertInRange, err error) { +func (t JsonType) Convert(c context.Context, v interface{}) (interface{}, sql.ConvertInRange, error) { switch v := v.(type) { case sql.JSONWrapper: return v, sql.InRange, nil case []byte: - if int64(len(v)) > MaxJsonFieldByteLength { - return nil, sql.InRange, ErrLengthTooLarge.New(len(v), MaxJsonFieldByteLength) - } - err = json.Unmarshal(v, &doc) - if err != nil { - return nil, sql.OutOfRange, sql.ErrInvalidJson.New(err.Error()) - } + return convertJSONValue(v) case string: - charsetMaxLength := sql.Collation_Default.CharacterSet().MaxLength() - length := int64(len(v)) * charsetMaxLength - if length > MaxJsonFieldByteLength { - return nil, sql.InRange, ErrLengthTooLarge.New(length, MaxJsonFieldByteLength) - } - err = json.Unmarshal([]byte(v), &doc) + return convertJSONValue(v) + // Text values may be stored in wrappers (e.g. Dolt's TextStorage), so unwrap to the raw string before decoding. + case sql.StringWrapper: + str, err := v.Unwrap(c) if err != nil { - return nil, sql.OutOfRange, sql.ErrInvalidJson.New(err.Error()) + return nil, sql.OutOfRange, err } + return convertJSONValue(str) case int8: return JSONDocument{Val: int64(v)}, sql.InRange, nil case int16: @@ -91,22 +117,8 @@ func (t JsonType) Convert(c context.Context, v interface{}) (doc interface{}, in case decimal.Decimal: return JSONDocument{Val: v}, sql.InRange, nil default: - // if |v| can be marshalled, it contains - // a valid JSON document representation - if b, berr := json.Marshal(v); berr == nil { - if int64(len(b)) > MaxJsonFieldByteLength { - return nil, sql.InRange, ErrLengthTooLarge.New(len(b), MaxJsonFieldByteLength) - } - err = json.Unmarshal(b, &doc) - if err != nil { - return nil, sql.OutOfRange, sql.ErrInvalidJson.New(err.Error()) - } - } - } - if err != nil { - return nil, sql.OutOfRange, err + return convertJSONValue(v) } - return JSONDocument{Val: doc}, sql.InRange, nil } // Equals implements the Type interface. diff --git a/sql/types/jsontests/json_test.go b/sql/types/jsontests/json_test.go index e83aa4005d..1540106e0e 100644 --- a/sql/types/jsontests/json_test.go +++ b/sql/types/jsontests/json_test.go @@ -28,6 +28,34 @@ import ( "github.com/dolthub/go-mysql-server/sql/types" ) +type mockStringWrapper struct { + val string +} + +func (m mockStringWrapper) Unwrap(ctx context.Context) (string, error) { + return m.val, nil +} + +func (m mockStringWrapper) UnwrapAny(ctx context.Context) (interface{}, error) { + return m.val, nil +} + +func (m mockStringWrapper) IsExactLength() bool { + return false +} + +func (m mockStringWrapper) MaxByteLength() int64 { + return int64(len(m.val)) +} + +func (m mockStringWrapper) Compare(ctx context.Context, other interface{}) (int, bool, error) { + return 0, false, nil +} + +func (m mockStringWrapper) Hash() interface{} { + return m.val +} + func TestJsonCompare(t *testing.T) { RunJsonCompareTests(t, JsonCompareTests, func(t *testing.T, left, right interface{}) (interface{}, interface{}) { return ConvertToJson(t, left), ConvertToJson(t, right) @@ -58,6 +86,7 @@ func TestJsonConvert(t *testing.T) { {types.MustJSON(`{"field":"test"}`), types.MustJSON(`{"field":"test"}`), false}, {[]string{}, types.MustJSON(`[]`), false}, {[]string{`555-555-5555`}, types.MustJSON(`["555-555-5555"]`), false}, + {mockStringWrapper{val: `{"c": 1}`}, types.MustJSON(`{"c":1}`), false}, } for _, test := range tests {