Skip to content
24 changes: 24 additions & 0 deletions enginetest/queries/json_scripts.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,35 @@ package queries
import (
"github.com/dolthub/vitess/go/vt/sqlparser"

"github.com/dolthub/go-mysql-server/sql/plan"

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/types"
)

var JsonScripts = []ScriptTest{
{
// https://github.com/dolthub/dolt/issues/10050
Name: "TextStorage converts to JSON when using dolt wrapper",
SetUpScript: []string{
"CREATE TABLE pages (id INT PRIMARY KEY, text_col TEXT, text_json JSON)",
"INSERT INTO pages VALUES (1, '{\"message\":\"hello\"}', NULL)",
},
Assertions: []ScriptTestAssertion{
{
Query: "UPDATE pages SET text_json = text_col",
Expected: []sql.Row{
{types.OkResult{RowsAffected: 1, Info: plan.UpdateInfo{Matched: 1, Updated: 1}}},
},
},
{
Query: "SELECT text_json FROM pages",
Expected: []sql.Row{
{types.MustJSON("{\"message\":\"hello\"}")},
},
},
},
},
{
Name: "json_type scripts",
Assertions: []ScriptTestAssertion{
Expand Down
70 changes: 47 additions & 23 deletions sql/types/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,28 +45,62 @@ func (t JsonType) Compare(ctx context.Context, a interface{}, b interface{}) (in
return CompareJSON(ctx, a, b)
}

// convertJSONValue parses JSON-encoded data if the input is a string or []byte, returning the resulting Go value. For
// other types, the value is returned as-is. The returned value is the raw, unwrapped JSON representation and is later
// wrapped in a JSONDocument by JsonType.Convert.
func convertJSONValue(v interface{}) (val interface{}, inRange sql.ConvertInRange, err error) {
var data []byte
var charsetMaxLength int64 = 1
switch x := v.(type) {
case []byte:
data = x
case string:
data = []byte(x)
charsetMaxLength = sql.Collation_Default.CharacterSet().MaxLength()
default:
// if |v| can be marshalled, it contains
// a valid JSON document representation
if b, berr := json.Marshal(v); berr == nil {
data = b
}
}

if int64(len(data))*charsetMaxLength > MaxJsonFieldByteLength {
return nil, sql.InRange, ErrLengthTooLarge.New(len(data), MaxJsonFieldByteLength)
}

if err := json.Unmarshal(data, &val); err != nil {
return nil, sql.OutOfRange, sql.ErrInvalidJson.New(err.Error())
}

return val, sql.InRange, nil
}

// Convert implements Type interface.
func (t JsonType) Convert(c context.Context, v interface{}) (doc interface{}, inRange sql.ConvertInRange, err error) {
docVal := v
switch v := v.(type) {
case sql.JSONWrapper:
return v, sql.InRange, nil
case []byte:
if int64(len(v)) > MaxJsonFieldByteLength {
return nil, sql.InRange, ErrLengthTooLarge.New(len(v), MaxJsonFieldByteLength)
}
err = json.Unmarshal(v, &doc)
docVal, inRange, err = convertJSONValue(v)
if err != nil {
return nil, sql.OutOfRange, sql.ErrInvalidJson.New(err.Error())
return nil, inRange, err
}
case string:
charsetMaxLength := sql.Collation_Default.CharacterSet().MaxLength()
length := int64(len(v)) * charsetMaxLength
if length > MaxJsonFieldByteLength {
return nil, sql.InRange, ErrLengthTooLarge.New(length, MaxJsonFieldByteLength)
docVal, inRange, err = convertJSONValue(v)
if err != nil {
return nil, inRange, err
}
err = json.Unmarshal([]byte(v), &doc)
// Text values may be stored in wrappers (e.g. Dolt's TextStorage), so unwrap to the raw string before decoding.
case sql.StringWrapper:
str, err := v.Unwrap(c)
if err != nil {
return nil, sql.OutOfRange, sql.ErrInvalidJson.New(err.Error())
return nil, sql.OutOfRange, err
}
docVal, inRange, err = convertJSONValue(str)
if err != nil {
return nil, inRange, err
}
case int8:
return JSONDocument{Val: int64(v)}, sql.InRange, nil
Expand All @@ -91,22 +125,12 @@ func (t JsonType) Convert(c context.Context, v interface{}) (doc interface{}, in
case decimal.Decimal:
return JSONDocument{Val: v}, sql.InRange, nil
default:
// if |v| can be marshalled, it contains
// a valid JSON document representation
if b, berr := json.Marshal(v); berr == nil {
if int64(len(b)) > MaxJsonFieldByteLength {
return nil, sql.InRange, ErrLengthTooLarge.New(len(b), MaxJsonFieldByteLength)
}
err = json.Unmarshal(b, &doc)
if err != nil {
return nil, sql.OutOfRange, sql.ErrInvalidJson.New(err.Error())
}
}
docVal, inRange, err = convertJSONValue(v)
}
if err != nil {
return nil, sql.OutOfRange, err
}
return JSONDocument{Val: doc}, sql.InRange, nil
return JSONDocument{Val: docVal}, sql.InRange, nil
}

// Equals implements the Type interface.
Expand Down
29 changes: 29 additions & 0 deletions sql/types/jsontests/json_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,34 @@ import (
"github.com/dolthub/go-mysql-server/sql/types"
)

type mockStringWrapper struct {
val string
}

func (m mockStringWrapper) Unwrap(ctx context.Context) (string, error) {
return m.val, nil
}

func (m mockStringWrapper) UnwrapAny(ctx context.Context) (interface{}, error) {
return m.val, nil
}

func (m mockStringWrapper) IsExactLength() bool {
return false
}

func (m mockStringWrapper) MaxByteLength() int64 {
return int64(len(m.val))
}

func (m mockStringWrapper) Compare(ctx context.Context, other interface{}) (int, bool, error) {
return 0, false, nil
}

func (m mockStringWrapper) Hash() interface{} {
return m.val
}

func TestJsonCompare(t *testing.T) {
RunJsonCompareTests(t, JsonCompareTests, func(t *testing.T, left, right interface{}) (interface{}, interface{}) {
return ConvertToJson(t, left), ConvertToJson(t, right)
Expand Down Expand Up @@ -58,6 +86,7 @@ func TestJsonConvert(t *testing.T) {
{types.MustJSON(`{"field":"test"}`), types.MustJSON(`{"field":"test"}`), false},
{[]string{}, types.MustJSON(`[]`), false},
{[]string{`555-555-5555`}, types.MustJSON(`["555-555-5555"]`), false},
{mockStringWrapper{val: `{"c": 1}`}, types.MustJSON(`{"c":1}`), false},
}

for _, test := range tests {
Expand Down