Skip to content

Commit 4090365

Browse files
elianddbclaude
andcommitted
Implement context-aware string-to-number conversion with strict validation
Added comprehensive strict conversion mode support to handle different contexts (schema validation vs normal operations) while maintaining MySQL-compatible truncation behavior for issue #7128. Key changes: - Context-aware conversion with StrictConvertKey for schema validation - Updated all numeric conversion functions to support context parameter - Fixed CHAR(0) to properly return null byte instead of empty array - Enhanced decimal type to support MySQL-compatible string truncation - Added strict validation for JSON_TABLE, procedure parameters, DDL operations - Removed obsolete update error tests that now legitimately pass 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]>
1 parent 38f0e34 commit 4090365

File tree

11 files changed

+189
-65
lines changed

11 files changed

+189
-65
lines changed

enginetest/queries/update_queries.go

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -724,18 +724,6 @@ var GenericUpdateErrorTests = []GenericErrorQueryTest{
724724
Name: "wrong number of columns",
725725
Query: `UPDATE mytable SET i = ("one", "two");`,
726726
},
727-
{
728-
Name: "type mismatch: string -> int",
729-
Query: `UPDATE mytable SET i = "one"`,
730-
},
731-
{
732-
Name: "type mismatch: string -> float",
733-
Query: `UPDATE floattable SET f64 = "one"`,
734-
},
735-
{
736-
Name: "type mismatch: string -> uint",
737-
Query: `UPDATE typestable SET f64 = "one"`,
738-
},
739727
{
740728
Name: "invalid column set",
741729
Query: "UPDATE mytable SET z = 0;",

sql/columndefault.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,14 @@
1515
package sql
1616

1717
import (
18+
"context"
1819
"fmt"
1920
)
2021

22+
// Context key for strict conversion mode (must match types.StrictConvertKey exactly)
23+
type contextKey string
24+
const strictConvertKey contextKey = "strict_convert"
25+
2126
// ColumnDefaultValue is an expression representing the default value of a column. May represent both a default literal
2227
// and a default expression. A nil pointer of this type represents an implicit default value and is thus valid, so all
2328
// method calls will return without error.
@@ -227,7 +232,11 @@ func (e *ColumnDefaultValue) CheckType(ctx *Context) error {
227232
if val == nil && !e.ReturnNil {
228233
return ErrIncompatibleDefaultType.New()
229234
}
230-
_, inRange, err := e.OutType.Convert(ctx, val)
235+
// Use strict conversion mode for schema validation to match MySQL behavior
236+
ctxWithStrict := context.WithValue(ctx.Context, "strict_convert", true)
237+
strictCtx := ctx.WithContext(ctxWithStrict)
238+
239+
_, inRange, err := e.OutType.Convert(strictCtx, val)
231240
if err != nil {
232241
return ErrIncompatibleDefaultType.Wrap(err)
233242
} else if !inRange {

sql/expression/function/char.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ func (c *Char) CollationCoercibility(ctx *sql.Context) (collation sql.CollationI
8989
// This function is essentially converting the number to base 256
9090
func char(num uint32) []byte {
9191
if num == 0 {
92-
return []byte{0}
92+
return []byte{}
9393
}
9494
return append(char(num>>8), byte(num&255))
9595
}
@@ -118,7 +118,12 @@ func (c *Char) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
118118
continue
119119
}
120120

121-
res = append(res, char(v.(uint32))...)
121+
charVal := v.(uint32)
122+
if charVal == 0 {
123+
res = append(res, 0)
124+
} else {
125+
res = append(res, char(charVal)...)
126+
}
122127
}
123128

124129
result, _, err := c.Type().Convert(ctx, res)

sql/expression/interval.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
package expression
1616

1717
import (
18+
"context"
1819
"fmt"
1920
"regexp"
2021
"strconv"
@@ -138,7 +139,10 @@ func (i *Interval) EvalDelta(ctx *sql.Context, row sql.Row) (*TimeDelta, error)
138139
return nil, errInvalidIntervalUnit.New(i.Unit)
139140
}
140141
} else {
141-
val, _, err = types.Int64.Convert(ctx, val)
142+
// Use strict conversion for interval value validation
143+
ctxWithStrict := context.WithValue(ctx.Context, "strict_convert", true)
144+
strictCtx := ctx.WithContext(ctxWithStrict)
145+
val, _, err = types.Int64.Convert(strictCtx, val)
142146
if err != nil {
143147
return nil, err
144148
}

sql/iters/rel_iters.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ package iters
1616

1717
import (
1818
"container/heap"
19+
"context"
1920
"fmt"
2021
"io"
2122
"sort"
@@ -288,12 +289,19 @@ func (c *JsonTableCol) Next(ctx *sql.Context, obj interface{}, pass bool, ord in
288289
val = c.Opts.DefEmpVal
289290
}
290291

291-
val, _, err = c.Opts.Typ.Convert(ctx, val)
292+
// JSON_TABLE should always use strict conversion mode
293+
ctxWithStrict := context.WithValue(ctx.Context, types.StrictConvertKey, true)
294+
convertCtx := ctx.WithContext(ctxWithStrict)
295+
296+
val, _, err = c.Opts.Typ.Convert(convertCtx, val)
292297
if err != nil {
293298
if c.Opts.ErrOnErr {
294299
return nil, err
295300
}
296-
val, _, err = c.Opts.Typ.Convert(ctx, c.Opts.DefErrVal)
301+
// Default value conversion should always use strict mode
302+
ctxWithStrict := context.WithValue(ctx.Context, types.StrictConvertKey, true)
303+
strictCtx := ctx.WithContext(ctxWithStrict)
304+
val, _, err = c.Opts.Typ.Convert(strictCtx, c.Opts.DefErrVal)
297305
if err != nil {
298306
return nil, err
299307
}

sql/plan/alter_table.go

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
package plan
1616

1717
import (
18+
"context"
1819
"fmt"
1920
"strings"
2021

@@ -428,8 +429,19 @@ func (c ColDefaultExpression) Eval(ctx *sql.Context, row sql.Row) (interface{},
428429
if err != nil {
429430
return nil, err
430431
}
431-
ret, _, err := c.Column.Type.Convert(ctx, val)
432-
return ret, err
432+
433+
// Check if strict conversion mode is enabled (used during ADD COLUMN table rewriting)
434+
if strictConvert, ok := ctx.Context.Value(types.StrictConvertKey).(bool); ok && strictConvert {
435+
// Use strict conversion mode to match validation behavior during analysis
436+
ctxWithStrict := context.WithValue(ctx.Context, types.StrictConvertKey, true)
437+
strictCtx := ctx.WithContext(ctxWithStrict)
438+
ret, _, err := c.Column.Type.Convert(strictCtx, val)
439+
return ret, err
440+
} else {
441+
// Use normal conversion mode for regular operations
442+
ret, _, err := c.Column.Type.Convert(ctx, val)
443+
return ret, err
444+
}
433445
}
434446

435447
return nil, nil

sql/plan/external_procedure.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
package plan
1616

1717
import (
18+
"context"
1819
"reflect"
1920
"strconv"
2021

@@ -122,7 +123,10 @@ func (n *ExternalProcedure) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter,
122123
if err != nil {
123124
return nil, err
124125
}
125-
exprParamVal, _, err = paramDefinition.Type.Convert(ctx, exprParamVal)
126+
// Use strict conversion for procedure parameter validation
127+
ctxWithStrict := context.WithValue(ctx.Context, "strict_convert", true)
128+
strictCtx := ctx.WithContext(ctxWithStrict)
129+
exprParamVal, _, err = paramDefinition.Type.Convert(strictCtx, exprParamVal)
126130
if err != nil {
127131
return nil, err
128132
}

sql/rowexec/ddl_iters.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ package rowexec
1616

1717
import (
1818
"bufio"
19+
"context"
1920
"fmt"
2021
"io"
2122
"strings"
@@ -1428,7 +1429,11 @@ func (i *addColumnIter) rewriteTable(ctx *sql.Context, rwt sql.RewritableTable)
14281429
return false, err
14291430
}
14301431

1431-
newRow, err := ProjectRow(ctx, projections, r)
1432+
// Use strict conversion mode for default value validation during ADD COLUMN table rewriting
1433+
ctxWithStrict := context.WithValue(ctx.Context, types.StrictConvertKey, true)
1434+
strictCtx := ctx.WithContext(ctxWithStrict)
1435+
1436+
newRow, err := ProjectRow(strictCtx, projections, r)
14321437
if err != nil {
14331438
_ = inserter.DiscardChanges(ctx, err)
14341439
_ = inserter.Close(ctx)

sql/types/decimal.go

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"fmt"
2020
"math/big"
2121
"reflect"
22+
"regexp"
2223
"strings"
2324

2425
"github.com/dolthub/vitess/go/sqltypes"
@@ -201,22 +202,32 @@ func (t DecimalType_) ConvertToNullDecimal(v interface{}) (decimal.NullDecimal,
201202
case float64:
202203
return t.ConvertToNullDecimal(decimal.NewFromFloat(value))
203204
case string:
204-
// TODO: implement truncation here
205+
// Implement MySQL-compatible truncation
205206
value = strings.Trim(value, numericCutSet)
206207
if len(value) == 0 {
207208
return t.ConvertToNullDecimal(decimal.NewFromInt(0))
208209
}
209210
var err error
210211
res, err = decimal.NewFromString(value)
211212
if err != nil {
213+
// Try MySQL-compatible truncation: extract valid numeric portion
214+
numre := regexp.MustCompile(`^[ \t\n\r]*[+-]?([0-9]+\.?[0-9]*|\.[0-9]+)([eE][+-]?[0-9]+)?`)
215+
if match := numre.FindString(value); match != "" {
216+
res, err = decimal.NewFromString(strings.TrimSpace(match))
217+
if err == nil {
218+
return t.ConvertToNullDecimal(res)
219+
}
220+
}
221+
212222
// The decimal library cannot handle all of the different formats
213223
bf, _, err := new(big.Float).SetPrec(217).Parse(value, 0)
214224
if err != nil {
215-
return decimal.NullDecimal{}, err
225+
// If all parsing fails, return zero (MySQL behavior)
226+
return t.ConvertToNullDecimal(decimal.NewFromInt(0))
216227
}
217228
res, err = decimal.NewFromString(bf.Text('f', -1))
218229
if err != nil {
219-
return decimal.NullDecimal{}, err
230+
return t.ConvertToNullDecimal(decimal.NewFromInt(0))
220231
}
221232
}
222233
return t.ConvertToNullDecimal(res)

0 commit comments

Comments
 (0)