Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sql/analyzer/validation_rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ func validateGroupBy(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop
defer span.End()

// only enforce strict group by when this variable is set
if !sql.LoadSqlMode(ctx).ModeEnabled(sql.OnlyFullGroupBy) {
if !sql.LoadSqlMode(ctx).OnlyFullGroupBy() {
return n, transform.SameTree, nil
}

Expand Down
75 changes: 55 additions & 20 deletions sql/sql_mode.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,31 @@ import (
const (
SqlModeSessionVar = "SQL_MODE"

ANSI = "ANSI"
ANSIQuotes = "ANSI_QUOTES"
OnlyFullGroupBy = "ONLY_FULL_GROUP_BY"
NoAutoValueOnZero = "NO_AUTO_VALUE_ON_ZERO"
NoEngineSubstitution = "NO_ENGINE_SUBSTITUTION"
StrictTransTables = "STRICT_TRANS_TABLES"
PipesAsConcat = "PIPES_AS_CONCAT"
DefaultSqlMode = "NO_ENGINE_SUBSTITUTION,ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES"
AllowInvalidDates = "ALLOW_INVALID_DATES"
ANSIQuotes = "ANSI_QUOTES"
ErrorForDivisionByZero = "ERROR_FOR_DIVISION_BY_ZERO"
HighNotPrecedence = "HIGH_NOT_PRECEDENCE"
IgnoreSpaces = "IGNORE_SPACE"
NoAutoValueOnZero = "NO_AUTO_VALUE_ON_ZERO"
NoBackslashEscapes = "NO_BACKSLASH_ESCAPES"
NoDirInCreate = "NO_DIR_IN_CREATE"
NoEngineSubstitution = "NO_ENGINE_SUBSTITUTION"
NoUnsignedSubtraction = "NO_UNSIGNED_SUBTRACTION"
NoZeroInDate = "NO_ZERO_IN_DATE"
OnlyFullGroupBy = "ONLY_FULL_GROUP_BY"
PadCharToFullLength = "PAD_CHAR_TO_FULL_LENGTH"
PipesAsConcat = "PIPES_AS_CONCAT"
RealAsFloat = "REAL_AS_FLOAT"
StrictTransTables = "STRICT_TRANS_TABLES"
StrictAllTables = "STRICT_ALL_TABLES"
TimeTruncateFractional = "TIME_TRUNCATE_FRACTIONAL"

// ANSI mode includes REAL_AS_FLOAT, PIPES_AS_CONCAT, ANSI_QUOTES, IGNORE_SPACE, and ONLY_FULL_GROUP_BY
ANSI = "ANSI"
// Traditional mode includes STRICT_TRANS_TABLES, STRICT_ALL_TABLES, NO_ZERO_IN_DATE, ERROR_FOR_DIVISION_BY_ZERO,
// and NO_ENGINE_SUBSTITUTION
Traditional = "TRADITIONAL"
DefaultSqlMode = "NO_ENGINE_SUBSTITUTION,ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES"
)

var defaultMode *SqlMode
Expand Down Expand Up @@ -74,7 +91,7 @@ func LoadSqlMode(ctx *Context) *SqlMode {
}

// NewSqlModeFromString returns a new SqlMode instance, constructed from the specified |sqlModeString| that
// has a comma delimited list of SQL modes (e.g. "ONLY_FULLY_GROUP_BY,ANSI_QUOTES").
// has a comma-delimited list of SQL modes (e.g. "ONLY_FULLY_GROUP_BY,ANSI_QUOTES").
func NewSqlModeFromString(sqlModeString string) *SqlMode {
if sqlModeString == DefaultSqlMode {
return defaultMode
Expand All @@ -99,9 +116,36 @@ func (s *SqlMode) AnsiQuotes() bool {
return s.ModeEnabled(ANSIQuotes) || s.ModeEnabled(ANSI)
}

// PipesAsConcat returns true if PIPES_AS_CONCAT SQL mode is enabled.
// OnlyFullGroupBy returns true is ONLY_TRUE_GROUP_BY SQL mode is enabled. Note that ANSI mode is a compound mode that
// includes ONLY_FULL_GROUP_BY and other options, so if ANSI or ONLY_TRUE_GROUP_BY is enabled, this function will
// return true.
func (s *SqlMode) OnlyFullGroupBy() bool {
return s.ModeEnabled(OnlyFullGroupBy) || s.ModeEnabled(ANSI)
}

// PipesAsConcat returns true if PIPES_AS_CONCAT SQL mode is enabled. Note that ANSI mode is a compound mode that
// includes PIPES_AS_CONCAT and other options, so if ANSI or PIPES_AS_CONCAT is enabled, this function will return true.
func (s *SqlMode) PipesAsConcat() bool {
return s.ModeEnabled(PipesAsConcat)
return s.ModeEnabled(PipesAsConcat) || s.ModeEnabled(ANSI)
}

// StrictTransTables returns true if STRICT_TRANS_TABLES SQL mode is enabled. Note that TRADITIONAL mode is a compound
// mode that includes STRICT_TRANS_TABLES and other options, so if TRADITIONAL or STRICT_TRANS_TABLES is enabled, this
// function will return true.
func (s *SqlMode) StrictTransTables() bool {
return s.ModeEnabled(StrictTransTables) || s.ModeEnabled(Traditional)
}

// StrictAllTables returns true if STRICT_ALL_TABLES SQL mode is enabled. Note that TRADITIONAL mode is a compound
// mode that includes STRICT_ALL_TABLES and other options, so if TRADITIONAL or STRICT_ALL_TABLES is enabled, this
// function will return true.
func (s *SqlMode) StrictAllTables() bool {
return s.ModeEnabled(StrictAllTables) || s.ModeEnabled(Traditional)
}

// Strict mode is enabled when either STRICT_TRANS_TABLES or STRICT_ALL_TABLES is enabled.
func (s *SqlMode) Strict() bool {
return s.StrictAllTables() || s.StrictTransTables()
}

// ModeEnabled returns true if |mode| was explicitly specified in the SQL_MODE string that was used to
Expand All @@ -126,12 +170,3 @@ func (s *SqlMode) ParserOptions() sqlparser.ParserOptions {
func (s *SqlMode) String() string {
return s.modeString
}

// ValidateStrictMode returns true if either STRICT_TRANS_TABLES or STRICT_ALL_TABLES is enabled
func ValidateStrictMode(ctx *Context) bool {
if ctx == nil {
return false
}
sqlMode := LoadSqlMode(ctx)
return sqlMode.ModeEnabled("STRICT_TRANS_TABLES") || sqlMode.ModeEnabled("STRICT_ALL_TABLES")
}
13 changes: 7 additions & 6 deletions sql/sql_mode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,16 @@ import (
)

func TestSqlMode(t *testing.T) {
// Test that ANSI mode includes ANSI_QUOTES mode
sqlMode := NewSqlModeFromString("only_full_group_by,ansi")
// Test that ANSI mode includes ANSI_QUOTES, PIPES_AS_CONCAT, and ONLY_FULL_GROUP_BY mode
sqlMode := NewSqlModeFromString("ansi")
assert.True(t, sqlMode.AnsiQuotes())
assert.True(t, sqlMode.ModeEnabled("ansi"))
assert.True(t, sqlMode.ModeEnabled("ANSI"))
assert.True(t, sqlMode.ModeEnabled("ONLY_FULL_GROUP_BY"))
assert.False(t, sqlMode.ModeEnabled("fake_mode"))
assert.True(t, sqlMode.ParserOptions().AnsiQuotes)
assert.Equal(t, "ANSI,ONLY_FULL_GROUP_BY", sqlMode.String())
assert.False(t, sqlMode.PipesAsConcat())
assert.Equal(t, "ANSI", sqlMode.String())
assert.True(t, sqlMode.PipesAsConcat()) // PIPES_AS_CONCAT is included in ANSI mode
assert.True(t, sqlMode.OnlyFullGroupBy()) // ONLY_FULL_GROUP_BY is included in ANSI mode
assert.False(t, sqlMode.ModeEnabled("pipes_as_concat"))

// Test a mixed case SQL_MODE string where only ANSI_QUOTES is enabled
Expand All @@ -44,13 +44,14 @@ func TestSqlMode(t *testing.T) {
assert.False(t, sqlMode.PipesAsConcat())
assert.False(t, sqlMode.ModeEnabled("pipes_as_concat"))

// Test when SQL_MODE does not include ANSI_QUOTES, includes PIPES_AS_CONCAT
// Test when SQL_MODE does not include ANSI_QUOTES, includes PIPES_AS_CONCAT and STRICT_TRANS_TABLES
sqlMode = NewSqlModeFromString("ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES,PIPES_AS_CONCAT")
assert.False(t, sqlMode.AnsiQuotes())
assert.True(t, sqlMode.ModeEnabled("STRICT_TRANS_TABLES"))
assert.False(t, sqlMode.ModeEnabled("ansi_quotes"))
assert.False(t, sqlMode.ParserOptions().AnsiQuotes)
assert.True(t, sqlMode.PipesAsConcat())
assert.True(t, sqlMode.ModeEnabled("pipes_as_concat"))
assert.True(t, sqlMode.Strict())
assert.Equal(t, "ONLY_FULL_GROUP_BY,PIPES_AS_CONCAT,STRICT_TRANS_TABLES", sqlMode.String())
}
2 changes: 1 addition & 1 deletion sql/types/enum.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ func (t EnumType) Convert(ctx context.Context, v interface{}) (interface{}, sql.
case int:
// MySQL rejects 0 values in strict mode regardless of enum definition
if value == 0 {
if sqlCtx, ok := ctx.(*sql.Context); ok && sql.ValidateStrictMode(sqlCtx) {
if sqlCtx, ok := ctx.(*sql.Context); ok && sql.LoadSqlMode(sqlCtx).Strict() {
return nil, sql.OutOfRange, ErrConvertingToEnum.New(value)
}
}
Expand Down
2 changes: 1 addition & 1 deletion sql/types/strings.go
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ func ConvertToBytes(ctx context.Context, v interface{}, t sql.StringType, dest [
if !IsBinaryType(t) && !utf8.Valid(bytesVal) {
charset := t.CharacterSet()
if charset == sql.CharacterSet_utf8mb4 {
if sqlCtx, ok := ctx.(*sql.Context); ok && sql.ValidateStrictMode(sqlCtx) {
if sqlCtx, ok := ctx.(*sql.Context); ok && sql.LoadSqlMode(sqlCtx).Strict() {
// Strict mode: reject invalid UTF8
invalidByte := formatInvalidByteForError(bytesVal)
colName, rowNum := getColumnContext(ctx)
Expand Down
Loading