Skip to content

Commit 487575e

Browse files
authored
Merge pull request #1713 from dolthub/fulghum/dolt-5735
Changing `MaxTextResponseByteLength()` to respect `character_set_results`
2 parents 5f335ee + e48304e commit 487575e

34 files changed

+188
-102
lines changed

server/handler.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ func (h *Handler) ComPrepare(c *mysql.Conn, query string) ([]*query.Field, error
114114
if types.IsOkResultSchema(analyzed.Schema()) {
115115
return nil, nil
116116
}
117-
return schemaToFields(analyzed.Schema()), nil
117+
return schemaToFields(ctx, analyzed.Schema()), nil
118118
}
119119

120120
func (h *Handler) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error {
@@ -420,7 +420,7 @@ func (h *Handler) doQuery(
420420
defer wg.Done()
421421
for {
422422
if r == nil {
423-
r = &sqltypes.Result{Fields: schemaToFields(schema)}
423+
r = &sqltypes.Result{Fields: schemaToFields(ctx, schema)}
424424
}
425425

426426
if r.RowsAffected == rowsBatch {
@@ -691,7 +691,7 @@ func row2ToSQL(s sql.Schema, row sql.Row2) ([]sqltypes.Value, error) {
691691
return o, nil
692692
}
693693

694-
func schemaToFields(s sql.Schema) []*query.Field {
694+
func schemaToFields(ctx *sql.Context, s sql.Schema) []*query.Field {
695695
fields := make([]*query.Field, len(s))
696696
for i, c := range s {
697697
var charset uint32 = mysql.CharacterSetUtf8
@@ -703,7 +703,7 @@ func schemaToFields(s sql.Schema) []*query.Field {
703703
Name: c.Name,
704704
Type: c.Type.Type(),
705705
Charset: charset,
706-
ColumnLength: c.Type.MaxTextResponseByteLength(),
706+
ColumnLength: c.Type.MaxTextResponseByteLength(ctx),
707707
}
708708
}
709709

server/handler_test.go

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,10 @@ import (
3535
"github.com/dolthub/go-mysql-server/sql/analyzer"
3636
"github.com/dolthub/go-mysql-server/sql/expression"
3737
"github.com/dolthub/go-mysql-server/sql/types"
38+
"github.com/dolthub/go-mysql-server/sql/variables"
3839
)
3940

4041
func TestHandlerOutput(t *testing.T) {
41-
4242
e := setupMemDB(require.New(t))
4343
dummyConn := newConn(1)
4444
handler := &Handler{
@@ -655,14 +655,79 @@ func TestSchemaToFields(t *testing.T) {
655655

656656
require.Equal(len(schema), len(expected))
657657

658-
fields := schemaToFields(schema)
658+
handler := &Handler{
659+
e: setupMemDB(require),
660+
sm: NewSessionManager(
661+
testSessionBuilder,
662+
sql.NoopTracer,
663+
func(ctx *sql.Context, db string) bool { return db == "test" },
664+
sql.NewMemoryManager(nil),
665+
sqle.NewProcessList(),
666+
"foo",
667+
),
668+
readTimeout: time.Second,
669+
}
670+
671+
conn := newConn(1)
672+
handler.NewConnection(conn)
673+
674+
ctx, err := handler.sm.NewContextWithQuery(conn, "SELECT 1")
675+
require.NoError(err)
676+
677+
fields := schemaToFields(ctx, schema)
659678
for i := 0; i < len(fields); i++ {
660679
t.Run(schema[i].Name, func(t *testing.T) {
661680
assert.Equal(t, expected[i], fields[i])
662681
})
663682
}
664683
}
665684

685+
// TestHandlerMaxTextResponseBytes tests that the handler calculates the correct max text response byte
686+
// metadata for TEXT types, including honoring the character_set_results session variable. This is tested
687+
// here, instead of in string type unit tests, because of the dependency on system variables being loaded.
688+
func TestHandlerMaxTextResponseBytes(t *testing.T) {
689+
variables.InitSystemVariables()
690+
session := sql.NewBaseSession()
691+
ctx := sql.NewContext(
692+
context.Background(),
693+
sql.WithSession(session),
694+
)
695+
696+
tinyTextUtf8mb4 := types.MustCreateString(sqltypes.Text, types.TinyTextBlobMax, sql.Collation_Default)
697+
textUtf8mb4 := types.MustCreateString(sqltypes.Text, types.TextBlobMax, sql.Collation_Default)
698+
mediumTextUtf8mb4 := types.MustCreateString(sqltypes.Text, types.MediumTextBlobMax, sql.Collation_Default)
699+
longTextUtf8mb4 := types.MustCreateString(sqltypes.Text, types.LongTextBlobMax, sql.Collation_Default)
700+
701+
// When character_set_results is set to utf8mb4, the multibyte character multiplier is 4
702+
require.NoError(t, session.SetSessionVariable(ctx, "character_set_results", "utf8mb4"))
703+
require.EqualValues(t, types.TinyTextBlobMax*4, tinyTextUtf8mb4.MaxTextResponseByteLength(ctx))
704+
require.EqualValues(t, types.TextBlobMax*4, textUtf8mb4.MaxTextResponseByteLength(ctx))
705+
require.EqualValues(t, types.MediumTextBlobMax*4, mediumTextUtf8mb4.MaxTextResponseByteLength(ctx))
706+
require.EqualValues(t, types.LongTextBlobMax, longTextUtf8mb4.MaxTextResponseByteLength(ctx))
707+
708+
// When character_set_results is set to utf8mb3, the multibyte character multiplier is 3
709+
require.NoError(t, session.SetSessionVariable(ctx, "character_set_results", "utf8mb3"))
710+
require.EqualValues(t, types.TinyTextBlobMax*3, tinyTextUtf8mb4.MaxTextResponseByteLength(ctx))
711+
require.EqualValues(t, types.TextBlobMax*3, textUtf8mb4.MaxTextResponseByteLength(ctx))
712+
require.EqualValues(t, types.MediumTextBlobMax*3, mediumTextUtf8mb4.MaxTextResponseByteLength(ctx))
713+
require.EqualValues(t, types.LongTextBlobMax, longTextUtf8mb4.MaxTextResponseByteLength(ctx))
714+
715+
// When character_set_results is set to utf8, the multibyte character multiplier is 3
716+
require.NoError(t, session.SetSessionVariable(ctx, "character_set_results", "utf8"))
717+
require.EqualValues(t, types.TinyTextBlobMax*3, tinyTextUtf8mb4.MaxTextResponseByteLength(ctx))
718+
require.EqualValues(t, types.TextBlobMax*3, textUtf8mb4.MaxTextResponseByteLength(ctx))
719+
require.EqualValues(t, types.MediumTextBlobMax*3, mediumTextUtf8mb4.MaxTextResponseByteLength(ctx))
720+
require.EqualValues(t, types.LongTextBlobMax, longTextUtf8mb4.MaxTextResponseByteLength(ctx))
721+
722+
// When character_set_results is set to NULL, the multibyte character multiplier is taken from
723+
// the type's charset (4 in this case)
724+
require.NoError(t, session.SetSessionVariable(ctx, "character_set_results", nil))
725+
require.EqualValues(t, types.TinyTextBlobMax*4, tinyTextUtf8mb4.MaxTextResponseByteLength(ctx))
726+
require.EqualValues(t, types.TextBlobMax*4, textUtf8mb4.MaxTextResponseByteLength(ctx))
727+
require.EqualValues(t, types.MediumTextBlobMax*4, mediumTextUtf8mb4.MaxTextResponseByteLength(ctx))
728+
require.EqualValues(t, types.LongTextBlobMax, longTextUtf8mb4.MaxTextResponseByteLength(ctx))
729+
}
730+
666731
func TestHandlerTimeout(t *testing.T) {
667732
require := require.New(t)
668733

sql/analyzer/validate_create_table.go

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ func resolveAlterColumn(ctx *sql.Context, a *Analyzer, n sql.Node, scope *Scope,
159159
if err != nil {
160160
return nil, transform.SameTree, err
161161
}
162-
sch, err = validatePrimaryKey(initialSch, sch, n.(*plan.AlterPK))
162+
sch, err = validatePrimaryKey(ctx, initialSch, sch, n.(*plan.AlterPK))
163163
if err != nil {
164164
return nil, transform.SameTree, err
165165
}
@@ -397,7 +397,7 @@ func validateAlterIndex(ctx *sql.Context, initialSch, sch sql.Schema, ai *plan.A
397397
if !ok {
398398
return nil, sql.ErrKeyColumnDoesNotExist.New(badColName)
399399
}
400-
err := validateIndexType(ai.Columns, sch)
400+
err := validateIndexType(ctx, ai.Columns, sch)
401401
if err != nil {
402402
return nil, err
403403
}
@@ -455,7 +455,7 @@ func validateAlterIndex(ctx *sql.Context, initialSch, sch sql.Schema, ai *plan.A
455455
}
456456

457457
// validatePrefixLength handles all errors related to creating indexes with prefix lengths
458-
func validatePrefixLength(schCol *sql.Column, idxCol sql.IndexColumn) error {
458+
func validatePrefixLength(ctx *sql.Context, schCol *sql.Column, idxCol sql.IndexColumn) error {
459459
// Throw prefix length error for non-string types with prefixes
460460
if idxCol.Length > 0 && !types.IsText(schCol.Type) {
461461
return sql.ErrInvalidIndexPrefix.New(schCol.Name)
@@ -473,7 +473,7 @@ func validatePrefixLength(schCol *sql.Column, idxCol sql.IndexColumn) error {
473473
}
474474

475475
// The specified prefix length is longer than the column
476-
maxByteLength := int64(schCol.Type.MaxTextResponseByteLength())
476+
maxByteLength := int64(schCol.Type.MaxTextResponseByteLength(ctx))
477477
if prefixByteLength > maxByteLength {
478478
return sql.ErrInvalidIndexPrefix.New(schCol.Name)
479479
}
@@ -487,10 +487,10 @@ func validatePrefixLength(schCol *sql.Column, idxCol sql.IndexColumn) error {
487487
}
488488

489489
// validateIndexType prevents creating invalid indexes
490-
func validateIndexType(cols []sql.IndexColumn, sch sql.Schema) error {
490+
func validateIndexType(ctx *sql.Context, cols []sql.IndexColumn, sch sql.Schema) error {
491491
for _, idxCol := range cols {
492492
schCol := sch[sch.IndexOfColName(idxCol.Name)]
493-
err := validatePrefixLength(schCol, idxCol)
493+
err := validatePrefixLength(ctx, schCol, idxCol)
494494
if err != nil {
495495
return err
496496
}
@@ -610,7 +610,7 @@ func validateIndexes(ctx *sql.Context, tableSpec *plan.TableSpec) error {
610610
if !ok {
611611
return sql.ErrUnknownIndexColumn.New(idxCol.Name, idx.IndexName)
612612
}
613-
err := validatePrefixLength(schCol, idxCol)
613+
err := validatePrefixLength(ctx, schCol, idxCol)
614614
if err != nil {
615615
return err
616616
}
@@ -683,7 +683,7 @@ func getTableIndexNames(ctx *sql.Context, a *Analyzer, table sql.Node) ([]string
683683
}
684684

685685
// validatePrimaryKey validates a primary key add or drop operation.
686-
func validatePrimaryKey(initialSch, sch sql.Schema, ai *plan.AlterPK) (sql.Schema, error) {
686+
func validatePrimaryKey(ctx *sql.Context, initialSch, sch sql.Schema, ai *plan.AlterPK) (sql.Schema, error) {
687687
tableName := getTableName(ai.Table)
688688
switch ai.Action {
689689
case plan.PrimaryKeyAction_Create:
@@ -698,7 +698,7 @@ func validatePrimaryKey(initialSch, sch sql.Schema, ai *plan.AlterPK) (sql.Schem
698698

699699
for _, idxCol := range ai.Columns {
700700
schCol := sch[sch.IndexOf(idxCol.Name, tableName)]
701-
err := validatePrefixLength(schCol, idxCol)
701+
err := validatePrefixLength(ctx, schCol, idxCol)
702702
if err != nil {
703703
return nil, err
704704
}

sql/information_schema/columns_table.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ func getRowFromColumn(ctx *sql.Context, curOrdPos int, col *sql.Column, dbName,
209209
}
210210
}
211211

212-
charName, collName, charMaxLen, charOctetLen := getCharAndCollNamesAndCharMaxAndOctetLens(col.Type)
212+
charName, collName, charMaxLen, charOctetLen := getCharAndCollNamesAndCharMaxAndOctetLens(ctx, col.Type)
213213

214214
numericPrecision, numericScale := getColumnPrecisionAndScale(col.Type)
215215
if types.IsDatetimeType(col.Type) || types.IsTimestampType(col.Type) {
@@ -539,7 +539,7 @@ func getColumnPrecisionAndScale(colType sql.Type) (interface{}, interface{}) {
539539
}
540540
}
541541

542-
func getCharAndCollNamesAndCharMaxAndOctetLens(colType sql.Type) (interface{}, interface{}, interface{}, interface{}) {
542+
func getCharAndCollNamesAndCharMaxAndOctetLens(ctx *sql.Context, colType sql.Type) (interface{}, interface{}, interface{}, interface{}) {
543543
var (
544544
charName interface{}
545545
collName interface{}
@@ -551,8 +551,8 @@ func getCharAndCollNamesAndCharMaxAndOctetLens(colType sql.Type) (interface{}, i
551551
collName = colColl.Name()
552552
charName = colColl.CharacterSet().String()
553553
if types.IsEnum(colType) || types.IsSet(colType) {
554-
charOctetLen = int64(colType.MaxTextResponseByteLength())
555-
charMaxLen = int64(colType.MaxTextResponseByteLength()) / colColl.CharacterSet().MaxLength()
554+
charOctetLen = int64(colType.MaxTextResponseByteLength(ctx))
555+
charMaxLen = int64(colType.MaxTextResponseByteLength(ctx)) / colColl.CharacterSet().MaxLength()
556556
}
557557
}
558558
if st, ok := colType.(sql.StringType); ok {

sql/information_schema/routines_table.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ func parametersRowIter(ctx *Context, c Catalog, p map[string][]*plan.Procedure)
285285
parameterMode = "OUT"
286286
}
287287

288-
charName, collName, charMaxLen, charOctetLen := getCharAndCollNamesAndCharMaxAndOctetLens(param.Type)
288+
charName, collName, charMaxLen, charOctetLen := getCharAndCollNamesAndCharMaxAndOctetLens(ctx, param.Type)
289289
numericPrecision, numericScale := getColumnPrecisionAndScale(param.Type)
290290
// float types get nil for numericScale, but it gets 0 for this table
291291
if _, ok := param.Type.(NumberType); ok {

sql/type.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,10 @@ type Type interface {
7777
// Equals returns whether the given type is equivalent to the calling type. All parameters are included in the
7878
// comparison, so ENUM("a", "b") is not equivalent to ENUM("a", "b", "c").
7979
Equals(otherType Type) bool
80-
// MaxTextResponseByteLength returns the maximum number of bytes needed to serialize an instance of this type as a string in a response over the wire for MySQL's text protocol – in other words, this is the maximum bytes needed to serialize any value of this type as human-readable text, NOT in a more compact, binary representation.
81-
MaxTextResponseByteLength() uint32
80+
// MaxTextResponseByteLength returns the maximum number of bytes needed to serialize an instance of this type
81+
// as a string in a response over the wire for MySQL's text protocol – in other words, this is the maximum bytes
82+
// needed to serialize any value of this type as human-readable text, NOT in a more compact, binary representation.
83+
MaxTextResponseByteLength(ctx *Context) uint32
8284
// Promote will promote the current type to the largest representing type of the same kind, such as Int8 to Int64.
8385
Promote() Type
8486
// SQL returns the sqltypes.Value for the given value.

sql/types/bit.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ func MustCreateBitType(numOfBits uint8) BitType {
7272
}
7373

7474
// MaxTextResponseByteLength implements Type interface
75-
func (t BitType_) MaxTextResponseByteLength() uint32 {
75+
func (t BitType_) MaxTextResponseByteLength(_ *sql.Context) uint32 {
7676
// Because this is a text serialization format, each bit requires one byte in the text response format
7777
return uint32(t.numOfBits)
7878
}

sql/types/datetime.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,7 @@ func (t datetimeType) Equals(otherType sql.Type) bool {
323323
}
324324

325325
// MaxTextResponseByteLength implements the Type interface
326-
func (t datetimeType) MaxTextResponseByteLength() uint32 {
326+
func (t datetimeType) MaxTextResponseByteLength(_ *sql.Context) uint32 {
327327
switch t.baseType {
328328
case sqltypes.Date:
329329
return uint32(len(sql.DateLayout))

sql/types/decimal.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ func (t DecimalType_) Equals(otherType sql.Type) bool {
267267
}
268268

269269
// MaxTextResponseByteLength implements the Type interface
270-
func (t DecimalType_) MaxTextResponseByteLength() uint32 {
270+
func (t DecimalType_) MaxTextResponseByteLength(_ *sql.Context) uint32 {
271271
if t.scale == 0 {
272272
// if no digits are reserved for the right-hand side of the decimal point,
273273
// just return precision plus one byte for sign

sql/types/deferred.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ func (t deferredType) Convert(v interface{}) (interface{}, sql.ConvertInRange, e
5454
}
5555

5656
// MaxTextResponseByteLength implements the Type interface
57-
func (t deferredType) MaxTextResponseByteLength() uint32 {
57+
func (t deferredType) MaxTextResponseByteLength(_ *sql.Context) uint32 {
5858
// deferredType is never actually sent over the wire
5959
return 0
6060
}

0 commit comments

Comments
 (0)