diff --git a/enginetest/queries/queries.go b/enginetest/queries/queries.go index d9e025212b..7a989e04c0 100644 --- a/enginetest/queries/queries.go +++ b/enginetest/queries/queries.go @@ -5031,6 +5031,8 @@ SELECT * FROM cte WHERE d = 2;`, {"gtid_next", "AUTOMATIC"}, {"gtid_owned", ""}, {"gtid_purged", ""}, + {"gtid_domain_id", 0}, + {"gtid_seq_no", 0}, }, }, { diff --git a/sql/analyzer/catalog.go b/sql/analyzer/catalog.go index 02b6eaa106..cd5c08b896 100644 --- a/sql/analyzer/catalog.go +++ b/sql/analyzer/catalog.go @@ -34,6 +34,8 @@ type Catalog struct { DbProvider sql.DatabaseProvider AuthHandler sql.AuthorizationHandler + // BinlogConsumer holds an optional consumer that processes binlog events (e.g. for BINLOG statements). + BinlogConsumer binlogreplication.BinlogConsumer // BinlogReplicaController holds an optional controller that receives forwarded binlog // replication messages (e.g. "start replica"). BinlogReplicaController binlogreplication.BinlogReplicaController @@ -53,8 +55,9 @@ func (c *Catalog) DropDbStats(ctx *sql.Context, db string, flush bool) error { } var _ sql.Catalog = (*Catalog)(nil) -var _ binlogreplication.BinlogReplicaCatalog = (*Catalog)(nil) -var _ binlogreplication.BinlogPrimaryCatalog = (*Catalog)(nil) +var _ binlogreplication.BinlogConsumerProvider = (*Catalog)(nil) +var _ binlogreplication.BinlogReplicaProvider = (*Catalog)(nil) +var _ binlogreplication.BinlogPrimaryProvider = (*Catalog)(nil) type tableLocks map[string]struct{} @@ -76,6 +79,14 @@ func NewCatalog(provider sql.DatabaseProvider) *Catalog { return c } +func (c *Catalog) HasBinlogConsumer() bool { + return c.BinlogConsumer != nil +} + +func (c *Catalog) GetBinlogConsumer() binlogreplication.BinlogConsumer { + return c.BinlogConsumer +} + func (c *Catalog) HasBinlogReplicaController() bool { return c.BinlogReplicaController != nil } diff --git a/sql/binlogreplication/binlog_replication.go b/sql/binlogreplication/binlog_replication.go index 0e5ace0efd..6fbd3afdbd 100644 --- a/sql/binlogreplication/binlog_replication.go +++ b/sql/binlogreplication/binlog_replication.go @@ -25,6 +25,17 @@ import ( "github.com/dolthub/vitess/go/mysql" ) +// BinlogConsumer processes binlog events. This interface can be used by any component that needs to consume +// and apply binlog events, such as BINLOG statement execution, streaming replication, or other binlog processing. +type BinlogConsumer interface { + // ProcessEvent processes a single binlog event. + ProcessEvent(ctx *sql.Context, event mysql.BinlogEvent) error + + // HasFormatDescription returns true if a FORMAT_DESCRIPTION_EVENT has been processed. + // This is required before processing TABLE_MAP and row events in BINLOG statements. + HasFormatDescription() bool +} + // BinlogReplicaController allows callers to control a binlog replica. Providers built on go-mysql-server may optionally // implement this interface and use it when constructing a SQL engine in order to receive callbacks when replication // statements (e.g. START REPLICA, SHOW REPLICA STATUS) are being handled. @@ -147,21 +158,28 @@ type ReplicaStatus struct { SourceSsl bool } -// BinlogReplicaCatalog extends the Catalog interface and provides methods for accessing a BinlogReplicaController -// for a Catalog. -type BinlogReplicaCatalog interface { - // HasBinlogReplicaController returns true if a non-nil BinlogReplicaController is available for this BinlogReplicaCatalog. +// BinlogConsumerProvider provides methods for accessing a BinlogConsumer for BINLOG statement execution and other binlog +// event processing. Typically implemented by sql.Catalog. +type BinlogConsumerProvider interface { + // HasBinlogConsumer returns true if a non-nil BinlogConsumer is available. + HasBinlogConsumer() bool + // GetBinlogConsumer returns the BinlogConsumer. + GetBinlogConsumer() BinlogConsumer +} + +// BinlogReplicaProvider provides methods for accessing a BinlogReplicaController for binlog replica operations. +type BinlogReplicaProvider interface { + // HasBinlogReplicaController returns true if a non-nil BinlogReplicaController is available. HasBinlogReplicaController() bool - // GetBinlogReplicaController returns the BinlogReplicaController registered with this BinlogReplicaCatalog. + // GetBinlogReplicaController returns the BinlogReplicaController. GetBinlogReplicaController() BinlogReplicaController } -// BinlogPrimaryCatalog extends the Catalog interface and provides methods for accessing a BinlogPrimaryController -// for a Catalog. -type BinlogPrimaryCatalog interface { - // HasBinlogPrimaryController returns true if a non-nil BinlogPrimaryController is available for this BinlogPrimaryCatalog. +// BinlogPrimaryProvider provides methods for accessing a BinlogPrimaryController for binlog primary operations. +type BinlogPrimaryProvider interface { + // HasBinlogPrimaryController returns true if a non-nil BinlogPrimaryController is available. HasBinlogPrimaryController() bool - // GetBinlogPrimaryController returns the BinlogPrimaryController registered with this BinlogPrimaryCatalog. + // GetBinlogPrimaryController returns the BinlogPrimaryController. GetBinlogPrimaryController() BinlogPrimaryController } diff --git a/sql/collations.go b/sql/collations.go index 6b44dc9fa5..70c73bffa8 100644 --- a/sql/collations.go +++ b/sql/collations.go @@ -17,6 +17,7 @@ package sql import ( "fmt" "io" + "strconv" "strings" "sync" "unicode/utf8" @@ -972,3 +973,45 @@ type TypeWithCollation interface { // whether to include the character set and/or collation information. StringWithTableCollation(tableCollation CollationID) string } + +// ConvertCollationID converts numeric collation IDs to their string names. +func ConvertCollationID(val any) (string, error) { + var collationID uint64 + switch v := val.(type) { + case []byte: + if n, err := strconv.ParseUint(string(v), 10, 64); err == nil { + collationID = n + } else { + return string(v), nil + } + case int8: + collationID = uint64(v) + case int16: + collationID = uint64(v) + case int: + collationID = uint64(v) + case int32: + collationID = uint64(v) + case int64: + collationID = uint64(v) + case uint8: + collationID = uint64(v) + case uint16: + collationID = uint64(v) + case uint: + collationID = uint64(v) + case uint32: + collationID = uint64(v) + case uint64: + collationID = v + default: + return fmt.Sprintf("%v", val), nil + } + + if collationID >= uint64(len(collationArray)) { + return fmt.Sprintf("%v", val), nil + } + + collation := CollationID(collationID).Collation() + return collation.Name, nil +} diff --git a/sql/collations_test.go b/sql/collations_test.go index faf2f4c006..0766ccda8b 100644 --- a/sql/collations_test.go +++ b/sql/collations_test.go @@ -69,3 +69,61 @@ func testParseCollation(t *testing.T, charset string, collation string, binaryAt } }) } + +func TestConvertCollationID(t *testing.T) { + tests := []struct { + input any + expected string + }{ + {uint64(33), "utf8mb3_general_ci"}, + {int64(33), "utf8mb3_general_ci"}, + {[]byte("33"), "utf8mb3_general_ci"}, + {uint64(8), "latin1_swedish_ci"}, + {int32(8), "latin1_swedish_ci"}, + + {45, "utf8mb4_general_ci"}, + {uint64(46), "utf8mb4_bin"}, + {255, "utf8mb4_0900_ai_ci"}, + {uint64(309), "utf8mb4_0900_bin"}, + + {83, "utf8mb3_bin"}, + {uint64(223), "utf8mb3_general_mysql500_ci"}, + + {uint64(47), "latin1_bin"}, + {48, "latin1_general_ci"}, + {49, "latin1_general_cs"}, + + {uint64(63), "binary"}, + + {uint64(11), "ascii_general_ci"}, + {65, "ascii_bin"}, + + {uint64(15), "latin1_danish_ci"}, + {31, "latin1_german2_ci"}, + {94, "latin1_spanish_ci"}, + + {int8(8), "latin1_swedish_ci"}, + {int16(8), "latin1_swedish_ci"}, + {int(8), "latin1_swedish_ci"}, + {uint8(8), "latin1_swedish_ci"}, + {uint16(8), "latin1_swedish_ci"}, + {uint(8), "latin1_swedish_ci"}, + {uint32(8), "latin1_swedish_ci"}, + + {"utf8mb4_0900_bin", "utf8mb4_0900_bin"}, + {"utf8mb3_general_ci", "utf8mb3_general_ci"}, + {"", ""}, + + {uint64(99999), "99999"}, + {uint64(1000), "1000"}, + {int(500), "500"}, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("%T(%v)", tt.input, tt.input), func(t *testing.T) { + result, err := ConvertCollationID(tt.input) + assert.NoError(t, err) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/sql/errors.go b/sql/errors.go index a8e6296084..7d97ea8ee2 100644 --- a/sql/errors.go +++ b/sql/errors.go @@ -953,6 +953,15 @@ var ( // ErrUnresolvedTableLock is returned when a FOR UPDATE OF clause references a table that doesn't exist in the query context. ErrUnresolvedTableLock = errors.NewKind("unresolved table name `%s` in locking clause.") + + // ErrBase64DecodeError is returned when decoding a base64 string fails. + ErrBase64DecodeError = errors.NewKind("Decoding of base64 string failed") + + // ErrNoFormatDescriptionEventBeforeBinlogStatement is returned when a BINLOG statement is not preceded by a format description event. + ErrNoFormatDescriptionEventBeforeBinlogStatement = errors.NewKind("The BINLOG statement of type `%s` was not preceded by a format description BINLOG statement.") + + // ErrOnlyFDAndRBREventsAllowedInBinlogStatement is returned when an unsupported event type is used in a BINLOG statement. + ErrOnlyFDAndRBREventsAllowedInBinlogStatement = errors.NewKind("Only Format_description_log_event and row events are allowed in BINLOG statements (but %s was provided)") ) // CastSQLError returns a *mysql.SQLError with the error code and in some cases, also a SQL state, populated for the @@ -1034,6 +1043,12 @@ func CastSQLError(err error) *mysql.SQLError { // https://en.wikipedia.org/wiki/SQLSTATE code = mysql.ERLockDeadlock sqlState = mysql.SSLockDeadlock + case ErrBase64DecodeError.Is(err): + code = mysql.ERBase64DecodeError + case ErrNoFormatDescriptionEventBeforeBinlogStatement.Is(err): + code = mysql.ERNoFormatDescriptionEventBeforeBinlogStatement + case ErrOnlyFDAndRBREventsAllowedInBinlogStatement.Is(err): + code = mysql.EROnlyFDAndRBREventsAllowedInBinlogStatement default: code = mysql.ERUnknownError } diff --git a/sql/plan/binlog.go b/sql/plan/binlog.go new file mode 100644 index 0000000000..7fb3724491 --- /dev/null +++ b/sql/plan/binlog.go @@ -0,0 +1,84 @@ +// Copyright 2025 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package plan + +import ( + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/binlogreplication" + "github.com/dolthub/go-mysql-server/sql/types" +) + +// DynamicPrivilege_BinlogAdmin enables binary log control by means of the PURGE BINARY LOGS and BINLOG statements. +// https://dev.mysql.com/doc/refman/8.0/en/privileges-provided.html#priv_binlog-admin +const DynamicPrivilege_BinlogAdmin = "binlog_admin" + +// Binlog replays binary log events, which record database changes in a binary format for efficiency. Tools like +// mysqldump, mysqlbinlog, and mariadb-binlog read these binary events from log files and output them as base64-encoded +// BINLOG statements for replay. +// +// The BINLOG statement execution is delegated to the BinlogConsumer. The base64-encoded event data is decoded +// and passed to the consumer's ProcessEvent method for processing. This allows integrators like Dolt to handle +// BINLOG statement execution using their binlog event processing infrastructure. +// +// See https://dev.mysql.com/doc/refman/8.4/en/binlog.html for the BINLOG statement specification. +type Binlog struct { + Base64Str string + Consumer binlogreplication.BinlogConsumer +} + +var _ sql.Node = (*Binlog)(nil) +var _ BinlogConsumerCommand = (*Binlog)(nil) + +// NewBinlog creates a new Binlog node. +func NewBinlog(base64Str string) *Binlog { + return &Binlog{ + Base64Str: base64Str, + } +} + +// WithBinlogConsumer implements the BinlogConsumerCommand interface. +func (b *Binlog) WithBinlogConsumer(consumer binlogreplication.BinlogConsumer) sql.Node { + nc := *b + nc.Consumer = consumer + return &nc +} + +func (b *Binlog) String() string { + return "BINLOG" +} + +func (b *Binlog) Resolved() bool { + return true +} + +func (b *Binlog) Schema() sql.Schema { + return types.OkResultSchema +} + +func (b *Binlog) Children() []sql.Node { + return nil +} + +func (b *Binlog) IsReadOnly() bool { + return false +} + +// WithChildren implements the Node interface. +func (b *Binlog) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(b, len(children), 0) + } + return b, nil +} diff --git a/sql/plan/replication_commands.go b/sql/plan/replication_commands.go index a34f9e620a..d053f4af3c 100644 --- a/sql/plan/replication_commands.go +++ b/sql/plan/replication_commands.go @@ -32,6 +32,19 @@ var ErrNoReplicationController = errors.NewKind("no replication controller avail // https://dev.mysql.com/doc/refman/8.0/en/privileges-provided.html#priv_replication-slave-admin const DynamicPrivilege_ReplicationSlaveAdmin = "replication_slave_admin" +// DynamicPrivilege_ReplicationApplier is a dynamic privilege that permits executing BINLOG statements. +// See https://dev.mysql.com/doc/refman/8.0/en/privileges-provided.html#priv_replication-applier +const DynamicPrivilege_ReplicationApplier = "replication_applier" + +// BinlogConsumerCommand represents a SQL statement that requires a BinlogConsumer +// (e.g. BINLOG statement). +type BinlogConsumerCommand interface { + sql.Node + + // WithBinlogConsumer returns a new instance of this command, with the binlog consumer configured. + WithBinlogConsumer(consumer binlogreplication.BinlogConsumer) sql.Node +} + // BinlogReplicaControllerCommand represents a SQL statement that requires a BinlogReplicaController // (e.g. Start Replica, Show Replica Status). type BinlogReplicaControllerCommand interface { @@ -54,6 +67,11 @@ type BinlogPrimaryControllerCommand interface { // ChangeReplicationSource is the plan node for the "CHANGE REPLICATION SOURCE TO" statement. // https://dev.mysql.com/doc/refman/8.0/en/change-replication-source-to.html +// +// TODO: When PRIVILEGE_CHECKS_USER option is specified, validate that the assigned user account has the +// REPLICATION_APPLIER privilege. This validation should happen before the option is passed to the integrator's +// BinlogReplicaController.SetReplicationSourceOptions(). +// See https://github.com/mysql/mysql-server/blob/8.0/sql/rpl_replica.cc change_master_cmd type ChangeReplicationSource struct { ReplicaController binlogreplication.BinlogReplicaController Options []binlogreplication.ReplicationOption diff --git a/sql/planbuilder/auth_default.go b/sql/planbuilder/auth_default.go index 21308f94b8..4b980f8cb1 100644 --- a/sql/planbuilder/auth_default.go +++ b/sql/planbuilder/auth_default.go @@ -24,6 +24,7 @@ import ( "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/mysql_db" + "github.com/dolthub/go-mysql-server/sql/plan" ) // defaultAuthorizationQueryState contains query-specific state for defaultAuthorizationHandler. @@ -131,6 +132,10 @@ func (h defaultAuthorizationHandler) HandleAuth(ctx *sql.Context, aqs sql.Author hasPrivileges = state.db.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(sql.PrivilegeCheckSubject{Database: "mysql"}, sql.PrivilegeType_Update)) || state.db.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(sql.PrivilegeCheckSubject{}, sql.PrivilegeType_CreateUser)) || state.user.User == auth.TargetNames[0] + case ast.AuthType_BINLOG: + hasPrivileges = state.db.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(sql.PrivilegeCheckSubject{}, sql.PrivilegeType_Super)) || + state.db.UserHasPrivileges(ctx, sql.NewDynamicPrivilegedOperation(plan.DynamicPrivilege_BinlogAdmin)) || + state.db.UserHasPrivileges(ctx, sql.NewDynamicPrivilegedOperation(plan.DynamicPrivilege_ReplicationApplier)) case ast.AuthType_CALL: hasPrivileges, err = h.call(ctx, state, auth) if err != nil { diff --git a/sql/planbuilder/binlog_replication.go b/sql/planbuilder/binlog_replication.go index 4302e1099f..ed913aa3c6 100644 --- a/sql/planbuilder/binlog_replication.go +++ b/sql/planbuilder/binlog_replication.go @@ -35,7 +35,7 @@ func (b *Builder) buildChangeReplicationSource(inScope *scope, n *ast.ChangeRepl convertedOptions = append(convertedOptions, *convertedOption) } repSrc := plan.NewChangeReplicationSource(convertedOptions) - if binCat, ok := b.cat.(binlogreplication.BinlogReplicaCatalog); ok && binCat.HasBinlogReplicaController() { + if binCat, ok := b.cat.(binlogreplication.BinlogReplicaProvider); ok && binCat.HasBinlogReplicaController() { repSrc.ReplicaController = binCat.GetBinlogReplicaController() } outScope.node = repSrc @@ -77,7 +77,7 @@ func (b *Builder) buildChangeReplicationFilter(inScope *scope, n *ast.ChangeRepl convertedOptions = append(convertedOptions, *convertedOption) } changeFilter := plan.NewChangeReplicationFilter(convertedOptions) - if binCat, ok := b.cat.(binlogreplication.BinlogReplicaCatalog); ok && binCat.HasBinlogReplicaController() { + if binCat, ok := b.cat.(binlogreplication.BinlogReplicaProvider); ok && binCat.HasBinlogReplicaController() { changeFilter.ReplicaController = binCat.GetBinlogReplicaController() } outScope.node = changeFilter diff --git a/sql/planbuilder/builder.go b/sql/planbuilder/builder.go index de69783d25..f3b744a3ea 100644 --- a/sql/planbuilder/builder.go +++ b/sql/planbuilder/builder.go @@ -307,7 +307,7 @@ func (b *Builder) buildSubquery(inScope *scope, stmt ast.Statement, subQuery str } outScope = inScope.push() startRep := plan.NewStartReplica() - if binCat, ok := b.cat.(binlogreplication.BinlogReplicaCatalog); ok && binCat.HasBinlogReplicaController() { + if binCat, ok := b.cat.(binlogreplication.BinlogReplicaProvider); ok && binCat.HasBinlogReplicaController() { startRep.ReplicaController = binCat.GetBinlogReplicaController() } outScope.node = startRep @@ -317,7 +317,7 @@ func (b *Builder) buildSubquery(inScope *scope, stmt ast.Statement, subQuery str } outScope = inScope.push() stopRep := plan.NewStopReplica() - if binCat, ok := b.cat.(binlogreplication.BinlogReplicaCatalog); ok && binCat.HasBinlogReplicaController() { + if binCat, ok := b.cat.(binlogreplication.BinlogReplicaProvider); ok && binCat.HasBinlogReplicaController() { stopRep.ReplicaController = binCat.GetBinlogReplicaController() } outScope.node = stopRep @@ -327,7 +327,7 @@ func (b *Builder) buildSubquery(inScope *scope, stmt ast.Statement, subQuery str } outScope = inScope.push() resetRep := plan.NewResetReplica(n.All) - if binCat, ok := b.cat.(binlogreplication.BinlogReplicaCatalog); ok && binCat.HasBinlogReplicaController() { + if binCat, ok := b.cat.(binlogreplication.BinlogReplicaProvider); ok && binCat.HasBinlogReplicaController() { resetRep.ReplicaController = binCat.GetBinlogReplicaController() } outScope.node = resetRep @@ -401,6 +401,16 @@ func (b *Builder) buildSubquery(inScope *scope, stmt ast.Statement, subQuery str return b.buildDeallocate(inScope, n) case ast.InjectedStatement: return b.buildInjectedStatement(inScope, n) + case *ast.Binlog: + if err := b.cat.AuthorizationHandler().HandleAuth(b.ctx, b.authQueryState, n.Auth); err != nil && b.authEnabled { + b.handleErr(err) + } + outScope = inScope.push() + binlogNode := plan.NewBinlog(n.Base64Str) + if binCat, ok := b.cat.(binlogreplication.BinlogConsumerProvider); ok && binCat.HasBinlogConsumer() { + binlogNode = binlogNode.WithBinlogConsumer(binCat.GetBinlogConsumer()).(*plan.Binlog) + } + outScope.node = binlogNode } return } diff --git a/sql/planbuilder/set.go b/sql/planbuilder/set.go index 9ed6457373..41d5f583d5 100644 --- a/sql/planbuilder/set.go +++ b/sql/planbuilder/set.go @@ -151,6 +151,27 @@ func (b *Builder) setExprsToExpressions(inScope *scope, e ast.SetVarExprs) []sql } } + if sysVar, ok := setVar.(*expression.SystemVar); ok { + if sqlVal, ok := setExpr.Expr.(*ast.SQLVal); ok && sqlVal.Type == ast.IntVal { + switch strings.ToLower(sysVar.Name) { + case "sql_mode": + converted, err := sql.ConvertSqlModeBitmask(sqlVal.Val) + if err != nil { + b.handleErr(err) + } + setExpr.Expr = ast.NewStrVal([]byte(converted)) + case "collation_database", "collation_connection", "collation_server": + converted, err := sql.ConvertCollationID(sqlVal.Val) + if err != nil { + b.handleErr(err) + } + setExpr.Expr = ast.NewStrVal([]byte(converted)) + case "lc_time_names": + setExpr.Expr = ast.NewStrVal(sqlVal.Val) + } + } + } + sysVarType, _ := setVar.Type().(sql.SystemVariableType) innerExpr, ok := b.simplifySetExpr(setExpr.Name, setScope, setExpr.Expr, sysVarType) if !ok { diff --git a/sql/planbuilder/show.go b/sql/planbuilder/show.go index 3669658d00..480fae635e 100644 --- a/sql/planbuilder/show.go +++ b/sql/planbuilder/show.go @@ -84,21 +84,21 @@ func (b *Builder) buildShow(inScope *scope, s *ast.Show) (outScope *scope) { case "binary log status": outScope = inScope.push() showRep := plan.NewShowBinlogStatus() - if binCat, ok := b.cat.(binlogreplication.BinlogPrimaryCatalog); ok && binCat.HasBinlogPrimaryController() { + if binCat, ok := b.cat.(binlogreplication.BinlogPrimaryProvider); ok && binCat.HasBinlogPrimaryController() { showRep.PrimaryController = binCat.GetBinlogPrimaryController() } outScope.node = showRep case "binary logs": outScope = inScope.push() showRep := plan.NewShowBinlogs() - if binCat, ok := b.cat.(binlogreplication.BinlogPrimaryCatalog); ok && binCat.HasBinlogPrimaryController() { + if binCat, ok := b.cat.(binlogreplication.BinlogPrimaryProvider); ok && binCat.HasBinlogPrimaryController() { showRep.PrimaryController = binCat.GetBinlogPrimaryController() } outScope.node = showRep case "replica status": outScope = inScope.push() showRep := plan.NewShowReplicaStatus() - if binCat, ok := b.cat.(binlogreplication.BinlogReplicaCatalog); ok && binCat.HasBinlogReplicaController() { + if binCat, ok := b.cat.(binlogreplication.BinlogReplicaProvider); ok && binCat.HasBinlogReplicaController() { showRep.ReplicaController = binCat.GetBinlogReplicaController() } outScope.node = showRep @@ -107,7 +107,7 @@ func (b *Builder) buildShow(inScope *scope, s *ast.Show) (outScope *scope) { // but uses a schema with different column names so we create the node differently here. outScope = inScope.push() showRep := plan.NewShowSlaveStatus() - if binCat, ok := b.cat.(binlogreplication.BinlogReplicaCatalog); ok && binCat.HasBinlogReplicaController() { + if binCat, ok := b.cat.(binlogreplication.BinlogReplicaProvider); ok && binCat.HasBinlogReplicaController() { showRep.ReplicaController = binCat.GetBinlogReplicaController() } outScope.node = showRep diff --git a/sql/rowexec/binlog.go b/sql/rowexec/binlog.go new file mode 100644 index 0000000000..d0a6479e36 --- /dev/null +++ b/sql/rowexec/binlog.go @@ -0,0 +1,142 @@ +// Copyright 2025 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rowexec + +import ( + "encoding/base64" + "encoding/binary" + "fmt" + "io" + "strings" + + "github.com/dolthub/vitess/go/mysql" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/binlogreplication" + "github.com/dolthub/go-mysql-server/sql/plan" + "github.com/dolthub/go-mysql-server/sql/types" +) + +// buildBinlog decodes base64 binlog events, parses them into individual events, and delegates processing to the +// BinlogReplicaController. This allows integrators like Dolt to handle BINLOG statement execution using their +// existing binlog replication infrastructure. +// +// The BINLOG statement is used by tools like mysqldump and mysqlbinlog to replay binary log events. The base64-encoded +// event data is decoded, parsed into individual BinlogEvents, and each event is passed to the BinlogReplicaController's +// ConsumeBinlogEvent method for processing. +// +// See https://dev.mysql.com/doc/refman/8.4/en/binlog.html for the BINLOG statement specification. +func (b *BaseBuilder) buildBinlog(ctx *sql.Context, n *plan.Binlog, row sql.Row) (sql.RowIter, error) { + if n.Consumer == nil { + return nil, fmt.Errorf("BINLOG statement requires BinlogConsumer") + } + + var decoded []byte + lines := strings.Split(n.Base64Str, "\n") + for _, line := range lines { + line = strings.TrimSpace(line) + if line == "" { + continue + } + + block, err := base64.StdEncoding.DecodeString(line) + if err != nil { + return nil, sql.ErrBase64DecodeError.New() + } + + decoded = append(decoded, block...) + } + + return &binlogIter{ + consumer: n.Consumer, + decoded: decoded, + }, nil +} + +// binlogIter processes decoded binlog events one at a time, returning an OkResult when all events are processed. +type binlogIter struct { + consumer binlogreplication.BinlogConsumer + decoded []byte + offset int +} + +var _ sql.RowIter = (*binlogIter)(nil) + +const ( + eventHeaderSize = 19 + eventLengthOffset = 9 +) + +// Next processes one binlog event per call and recursively processes remaining events. +// Only the final call returns OkResult, which bubbles up through the recursive calls. +func (bi *binlogIter) Next(ctx *sql.Context) (sql.Row, error) { + // Check if offset is negative (already returned OkResult) + if bi.offset < 0 { + return nil, io.EOF + } + + // If all events processed, mark as done and return OkResult once + if bi.offset >= len(bi.decoded) { + bi.offset = -1 // Mark as completed + return sql.Row{types.OkResult{}}, nil + } + + // Validate we have enough bytes for the event header + if bi.offset+eventHeaderSize > len(bi.decoded) { + return nil, fmt.Errorf("incomplete event header at offset %d", bi.offset) + } + + // Read the event length from the header + eventLength := binary.LittleEndian.Uint32(bi.decoded[bi.offset+eventLengthOffset : bi.offset+eventLengthOffset+4]) + + // Validate we have the complete event + if bi.offset+int(eventLength) > len(bi.decoded) { + return nil, fmt.Errorf("incomplete event at offset %d: event length %d exceeds buffer", bi.offset, eventLength) + } + + eventBytes := bi.decoded[bi.offset : bi.offset+int(eventLength)] + + // Parse the event using Vitess's binlog event parser + // MariaDB format is backward compatible with MySQL events + event := mysql.NewMariadbBinlogEvent(eventBytes) + + if !event.IsFormatDescription() && !event.IsQuery() && !event.IsTableMap() && + !event.IsWriteRows() && !event.IsUpdateRows() && !event.IsDeleteRows() { + return nil, sql.ErrOnlyFDAndRBREventsAllowedInBinlogStatement.New(event.TypeName()) + } + + // Check that TABLE_MAP and row events have a FORMAT_DESCRIPTION first + if event.IsTableMap() || event.IsWriteRows() || event.IsUpdateRows() || event.IsDeleteRows() { + if !bi.consumer.HasFormatDescription() { + return nil, sql.ErrNoFormatDescriptionEventBeforeBinlogStatement.New(event.TypeName()) + } + } + + // Process this event using the consumer + err := bi.consumer.ProcessEvent(ctx, event) + if err != nil { + return nil, err + } + + bi.offset += int(eventLength) + + // Recursively process next event - final OkResult bubbles up + return bi.Next(ctx) +} + +// Close implements sql.RowIter. +func (bi *binlogIter) Close(ctx *sql.Context) error { + return nil +} diff --git a/sql/rowexec/binlog_test.go b/sql/rowexec/binlog_test.go new file mode 100644 index 0000000000..3240f59462 --- /dev/null +++ b/sql/rowexec/binlog_test.go @@ -0,0 +1,168 @@ +// Copyright 2025 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rowexec + +import ( + "encoding/base64" + "encoding/binary" + "io" + "testing" + + "github.com/dolthub/vitess/go/mysql" + "github.com/stretchr/testify/require" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/plan" + "github.com/dolthub/go-mysql-server/sql/types" +) + +func TestBuildBinlog_InvalidBase64(t *testing.T) { + builder := &BaseBuilder{} + ctx := sql.NewEmptyContext() + + binlogNode := plan.NewBinlog("invalid!@#$base64") + + _, err := builder.buildBinlog(ctx, binlogNode, nil) + require.Error(t, err) + require.Contains(t, err.Error(), "BinlogConsumer") +} + +func TestBuildBinlog_NoBinlogReplicaController(t *testing.T) { + builder := &BaseBuilder{} + ctx := sql.NewEmptyContext() + + // Create some valid base64 data + eventData := make([]byte, 10) + encoded := base64.StdEncoding.EncodeToString(eventData) + + binlogNode := plan.NewBinlog(encoded) + // Don't set controller - should get error + + _, err := builder.buildBinlog(ctx, binlogNode, nil) + require.Error(t, err) + require.Contains(t, err.Error(), "BinlogConsumer") +} + +// mockBinlogConsumer is a test implementation of BinlogConsumer +type mockBinlogConsumer struct { + consumedEvents []mysql.BinlogEvent + returnError error + hasFormatDesc bool +} + +func (m *mockBinlogConsumer) ProcessEvent(ctx *sql.Context, event mysql.BinlogEvent) error { + m.consumedEvents = append(m.consumedEvents, event) + if event.IsFormatDescription() { + m.hasFormatDesc = true + } + return m.returnError +} + +func (m *mockBinlogConsumer) HasFormatDescription() bool { + return m.hasFormatDesc +} + +func TestBuildBinlog_WithBinlogReplicaController(t *testing.T) { + builder := &BaseBuilder{} + ctx := sql.NewEmptyContext() + + mockConsumer := &mockBinlogConsumer{} + + // Create a minimal valid binlog event (FORMAT_DESCRIPTION_EVENT) + // Event header: timestamp(4) + type(1) + server_id(4) + event_length(4) + next_position(4) + flags(2) + eventData := make([]byte, 19) + eventData[4] = 0x0f + binary.LittleEndian.PutUint32(eventData[9:13], 19) // event length + + encoded := base64.StdEncoding.EncodeToString(eventData) + + binlogNode := plan.NewBinlog(encoded).WithBinlogConsumer(mockConsumer).(*plan.Binlog) + + iter, err := builder.buildBinlog(ctx, binlogNode, nil) + require.NoError(t, err) + require.NotNil(t, iter) + + row, err := iter.Next(ctx) + require.NoError(t, err) + require.NotNil(t, row) + require.Equal(t, types.OkResult{}, row[0]) + + // Verify controller received one event + require.Len(t, mockConsumer.consumedEvents, 1) + + // Next call should return EOF + _, err = iter.Next(ctx) + require.Equal(t, io.EOF, err) +} + +func TestBuildBinlog_MultilineBase64WithController(t *testing.T) { + builder := &BaseBuilder{} + ctx := sql.NewEmptyContext() + + mockConsumer := &mockBinlogConsumer{} + + // Create two minimal events + event1 := make([]byte, 19) + event1[4] = 0x0f // FORMAT_DESCRIPTION_EVENT + binary.LittleEndian.PutUint32(event1[9:13], 19) + + event2 := make([]byte, 19) + event2[4] = 0x02 // QUERY_EVENT + binary.LittleEndian.PutUint32(event2[9:13], 19) + + combined := append(event1, event2...) + part1 := base64.StdEncoding.EncodeToString(combined[:10]) + part2 := base64.StdEncoding.EncodeToString(combined[10:]) + multiline := part1 + "\n" + part2 + + binlogNode := plan.NewBinlog(multiline).WithBinlogConsumer(mockConsumer).(*plan.Binlog) + + iter, err := builder.buildBinlog(ctx, binlogNode, nil) + require.NoError(t, err) + + // Next() processes all events and returns single OkResult + row, err := iter.Next(ctx) + require.NoError(t, err) + require.NotNil(t, row) + require.Equal(t, types.OkResult{}, row[0]) + + require.Len(t, mockConsumer.consumedEvents, 2) + + _, err = iter.Next(ctx) + require.Equal(t, io.EOF, err) +} + +func TestBuildBinlog_ControllerError(t *testing.T) { + builder := &BaseBuilder{} + ctx := sql.NewEmptyContext() + + mockConsumer := &mockBinlogConsumer{ + returnError: sql.ErrUnsupportedFeature.New("test error"), + } + + eventData := make([]byte, 19) + eventData[4] = 0x0f // FORMAT_DESCRIPTION_EVENT + binary.LittleEndian.PutUint32(eventData[9:13], 19) + encoded := base64.StdEncoding.EncodeToString(eventData) + + binlogNode := plan.NewBinlog(encoded).WithBinlogConsumer(mockConsumer).(*plan.Binlog) + + iter, err := builder.buildBinlog(ctx, binlogNode, nil) + require.NoError(t, err) + + _, err = iter.Next(ctx) + require.Error(t, err) + require.Contains(t, err.Error(), "test error") +} diff --git a/sql/rowexec/node_builder.gen.go b/sql/rowexec/node_builder.gen.go index a72dbdb0a0..10b0e81b99 100644 --- a/sql/rowexec/node_builder.gen.go +++ b/sql/rowexec/node_builder.gen.go @@ -64,6 +64,8 @@ func (b *BaseBuilder) buildNodeExecNoAnalyze(ctx *sql.Context, n sql.Node, row s return b.buildUpdateHistogram(ctx, n, row) case *plan.DropHistogram: return b.buildDropHistogram(ctx, n, row) + case *plan.Binlog: + return b.buildBinlog(ctx, n, row) case *plan.ShowBinlogs: return b.buildShowBinlogs(ctx, n, row) case *plan.ShowBinlogStatus: diff --git a/sql/sql_mode.go b/sql/sql_mode.go index 50721111bb..a95673a09c 100644 --- a/sql/sql_mode.go +++ b/sql/sql_mode.go @@ -15,7 +15,9 @@ package sql import ( + "fmt" "sort" + "strconv" "strings" "github.com/dolthub/vitess/go/vt/sqlparser" @@ -24,33 +26,103 @@ import ( const ( SqlModeSessionVar = "SQL_MODE" + RealAsFloat = "REAL_AS_FLOAT" + PipesAsConcat = "PIPES_AS_CONCAT" + ANSIQuotes = "ANSI_QUOTES" + IgnoreSpace = "IGNORE_SPACE" + OnlyFullGroupBy = "ONLY_FULL_GROUP_BY" + NoUnsignedSubtraction = "NO_UNSIGNED_SUBTRACTION" + NoDirInCreate = "NO_DIR_IN_CREATE" + // ANSI mode includes REAL_AS_FLOAT, PIPES_AS_CONCAT, ANSI_QUOTES, IGNORE_SPACE, and ONLY_FULL_GROUP_BY + ANSI = "ANSI" + NoAutoValueOnZero = "NO_AUTO_VALUE_ON_ZERO" + NoBackslashEscapes = "NO_BACKSLASH_ESCAPES" + StrictTransTables = "STRICT_TRANS_TABLES" + StrictAllTables = "STRICT_ALL_TABLES" + NoZeroInDate = "NO_ZERO_IN_DATE" AllowInvalidDates = "ALLOW_INVALID_DATES" - ANSIQuotes = "ANSI_QUOTES" ErrorForDivisionByZero = "ERROR_FOR_DIVISION_BY_ZERO" + // Traditional mode includes STRICT_TRANS_TABLES, STRICT_ALL_TABLES, NO_ZERO_IN_DATE, ERROR_FOR_DIVISION_BY_ZERO, + // and NO_ENGINE_SUBSTITUTION + Traditional = "TRADITIONAL" HighNotPrecedence = "HIGH_NOT_PRECEDENCE" - IgnoreSpaces = "IGNORE_SPACE" - NoAutoValueOnZero = "NO_AUTO_VALUE_ON_ZERO" - NoBackslashEscapes = "NO_BACKSLASH_ESCAPES" - NoDirInCreate = "NO_DIR_IN_CREATE" NoEngineSubstitution = "NO_ENGINE_SUBSTITUTION" - NoUnsignedSubtraction = "NO_UNSIGNED_SUBTRACTION" - NoZeroInDate = "NO_ZERO_IN_DATE" - OnlyFullGroupBy = "ONLY_FULL_GROUP_BY" PadCharToFullLength = "PAD_CHAR_TO_FULL_LENGTH" - PipesAsConcat = "PIPES_AS_CONCAT" - RealAsFloat = "REAL_AS_FLOAT" - StrictTransTables = "STRICT_TRANS_TABLES" - StrictAllTables = "STRICT_ALL_TABLES" TimeTruncateFractional = "TIME_TRUNCATE_FRACTIONAL" +) - // ANSI mode includes REAL_AS_FLOAT, PIPES_AS_CONCAT, ANSI_QUOTES, IGNORE_SPACE, and ONLY_FULL_GROUP_BY - ANSI = "ANSI" - // Traditional mode includes STRICT_TRANS_TABLES, STRICT_ALL_TABLES, NO_ZERO_IN_DATE, ERROR_FOR_DIVISION_BY_ZERO, - // and NO_ENGINE_SUBSTITUTION - Traditional = "TRADITIONAL" - DefaultSqlMode = "NO_ENGINE_SUBSTITUTION,ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES" +// Bits for different SQL modes +// https://github.com/mysql/mysql-server/blob/f362272c18e856930c867952a7dd1d840cbdb3b1/sql/system_variables.h#L123-L173 +const ( + modeRealAsFloat = 1 + modePipesAsConcat = 2 + modeAnsiQuotes = 4 + modeIgnoreSpace = 8 + modeOnlyFullGroupBy = 32 + modeNoUnsignedSubtraction = 64 + modeNoDirInCreate = 128 + modeAnsi = 0x40000 + modeNoAutoValueOnZero = modeAnsi * 2 + modeNoBackslashEscapes = modeNoAutoValueOnZero * 2 + modeStrictTransTables = modeNoBackslashEscapes * 2 + modeStrictAllTables = modeStrictTransTables * 2 + modeNoZeroInDate = modeStrictAllTables * 2 + modeNoZeroDate = modeNoZeroInDate * 2 + modeAllowInvalidDates = modeNoZeroDate * 2 + modeErrorForDivisionByZero = modeAllowInvalidDates * 2 + modeTraditional = modeErrorForDivisionByZero * 2 + modeHighNotPrecedence = 1 << 29 + modeNoEngineSubstitution = modeHighNotPrecedence * 2 + modePadCharToFullLength = 1 << 31 + modeTimeTruncateFractional = 1 << 32 + + // modeIgnoredMask contains deprecated/obsolete SQL mode bits that can be safely ignored + // during binlog replication. These modes existed in older MySQL versions but are no longer used. + // See: https://github.com/mysql/mysql-server/blob/trunk/sql/system_variables.h MODE_IGNORED_MASK + modeIgnoredMask = 0x00100 | // was: MODE_POSTGRESQL + 0x00200 | // was: MODE_ORACLE + 0x00400 | // was: MODE_MSSQL + 0x00800 | // was: MODE_DB2 + 0x01000 | // was: MODE_MAXDB + 0x02000 | // was: MODE_NO_KEY_OPTIONS + 0x04000 | // was: MODE_NO_TABLE_OPTIONS + 0x08000 | // was: MODE_NO_FIELD_OPTIONS + 0x10000 | // was: MODE_MYSQL323 + 0x20000 | // was: MODE_MYSQL40 + 0x10000000 // was: MODE_NO_AUTO_CREATE_USER ) +// sqlModeBitMap maps SQL mode bit flags to their string names. +var sqlModeBitMap = map[uint64]string{ + modeRealAsFloat: RealAsFloat, + modePipesAsConcat: PipesAsConcat, + modeAnsiQuotes: ANSIQuotes, + modeIgnoreSpace: IgnoreSpace, + modeOnlyFullGroupBy: OnlyFullGroupBy, + modeNoUnsignedSubtraction: NoUnsignedSubtraction, + modeNoDirInCreate: NoDirInCreate, + modeAnsi: ANSI, + modeNoAutoValueOnZero: NoAutoValueOnZero, + modeNoBackslashEscapes: NoBackslashEscapes, + modeStrictTransTables: StrictTransTables, + modeStrictAllTables: StrictAllTables, + modeNoZeroInDate: NoZeroInDate, + modeAllowInvalidDates: AllowInvalidDates, + modeErrorForDivisionByZero: ErrorForDivisionByZero, + modeTraditional: Traditional, + // Note: modeNoAutoCreateUser is NOT in this map - it's in modeIgnoredMask and filtered out + modeHighNotPrecedence: HighNotPrecedence, + modeNoEngineSubstitution: NoEngineSubstitution, + modePadCharToFullLength: PadCharToFullLength, + modeTimeTruncateFractional: TimeTruncateFractional, +} + +var DefaultSqlMode = strings.Join([]string{ + NoEngineSubstitution, + OnlyFullGroupBy, + StrictTransTables, +}, ",") + var defaultMode *SqlMode func init() { @@ -170,3 +242,57 @@ func (s *SqlMode) ParserOptions() sqlparser.ParserOptions { func (s *SqlMode) String() string { return s.modeString } + +// ConvertSqlModeBitmask converts sql_mode values to their string representation. +func ConvertSqlModeBitmask(val any) (string, error) { + var bitmask uint64 + switch v := val.(type) { + case []byte: + if n, err := strconv.ParseUint(string(v), 10, 64); err == nil { + bitmask = n + } + case int8: + bitmask = uint64(v) + case int16: + bitmask = uint64(v) + case int: + bitmask = uint64(v) + case int32: + bitmask = uint64(v) + case int64: + bitmask = uint64(v) + case uint8: + bitmask = uint64(v) + case uint16: + bitmask = uint64(v) + case uint: + bitmask = uint64(v) + case uint32: + bitmask = uint64(v) + case uint64: + bitmask = v + default: + return fmt.Sprintf("%v", val), nil + } + + bitmask = bitmask &^ modeIgnoredMask + + var modes []string + var matchedBits uint64 + for bit, modeName := range sqlModeBitMap { + if bitmask&bit != 0 { + modes = append(modes, modeName) + matchedBits |= bit + } + } + + if bitmask != 0 && matchedBits != bitmask { + return fmt.Sprintf("%v", val), nil + } + + if len(modes) == 0 { + return "", nil + } + + return strings.Join(modes, ","), nil +} diff --git a/sql/sql_mode_test.go b/sql/sql_mode_test.go index 9011e8c389..3851a850ae 100644 --- a/sql/sql_mode_test.go +++ b/sql/sql_mode_test.go @@ -15,6 +15,7 @@ package sql import ( + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -55,3 +56,85 @@ func TestSqlMode(t *testing.T) { assert.True(t, sqlMode.Strict()) assert.Equal(t, "ONLY_FULL_GROUP_BY,PIPES_AS_CONCAT,STRICT_TRANS_TABLES", sqlMode.String()) } + +func TestConvertSqlModeBitmask(t *testing.T) { + tests := []struct { + input any + expected []string + }{ + {uint64(1411383296), []string{ErrorForDivisionByZero, NoEngineSubstitution, StrictTransTables}}, + {int64(1411383296), []string{ErrorForDivisionByZero, NoEngineSubstitution, StrictTransTables}}, + {modeStrictTransTables | modeErrorForDivisionByZero | modeNoEngineSubstitution | 0x1, []string{StrictTransTables, ErrorForDivisionByZero, NoEngineSubstitution}}, + + {modeRealAsFloat, []string{RealAsFloat}}, + {modePipesAsConcat, []string{PipesAsConcat}}, + {modeAnsiQuotes, []string{ANSIQuotes}}, + {modeIgnoreSpace, []string{IgnoreSpace}}, + {modeOnlyFullGroupBy, []string{OnlyFullGroupBy}}, + {modeNoEngineSubstitution, []string{NoEngineSubstitution}}, + {uint64(modeNoEngineSubstitution), []string{NoEngineSubstitution}}, + + {modeAnsiQuotes | modePipesAsConcat, []string{ANSIQuotes, PipesAsConcat}}, + {modeAnsiQuotes | modeIgnoreSpace, []string{ANSIQuotes, IgnoreSpace}}, + + {modeRealAsFloat | modePipesAsConcat, []string{RealAsFloat, PipesAsConcat}}, + {modeRealAsFloat | modePipesAsConcat | modeAnsiQuotes, []string{RealAsFloat, PipesAsConcat, ANSIQuotes}}, + {modeRealAsFloat | modePipesAsConcat | modeAnsiQuotes | modeIgnoreSpace, []string{RealAsFloat, PipesAsConcat, ANSIQuotes, IgnoreSpace}}, + {modeAnsiQuotes | modeOnlyFullGroupBy, []string{ANSIQuotes, OnlyFullGroupBy}}, + {modeIgnoreSpace | modeOnlyFullGroupBy, []string{IgnoreSpace, OnlyFullGroupBy}}, + + {modeStrictTransTables, []string{StrictTransTables}}, + {modeStrictTransTables | modeAnsiQuotes, []string{StrictTransTables, ANSIQuotes}}, + {modeStrictAllTables, []string{StrictAllTables}}, + {modeNoZeroInDate, []string{NoZeroInDate}}, + {modeAllowInvalidDates, []string{AllowInvalidDates}}, + {modeErrorForDivisionByZero, []string{ErrorForDivisionByZero}}, + {modeNoBackslashEscapes, []string{NoBackslashEscapes}}, + {modeNoAutoValueOnZero, []string{NoAutoValueOnZero}}, + {modeNoUnsignedSubtraction, []string{NoUnsignedSubtraction}}, + {modeNoDirInCreate, []string{NoDirInCreate}}, + {modeHighNotPrecedence, []string{HighNotPrecedence}}, + {modePadCharToFullLength, []string{PadCharToFullLength}}, + + {0x10000000, []string{}}, + {modeStrictTransTables | 0x10000000, []string{StrictTransTables}}, + + {modeNoEngineSubstitution | modeAnsiQuotes, []string{NoEngineSubstitution, ANSIQuotes}}, + {modeNoEngineSubstitution | modeOnlyFullGroupBy, []string{NoEngineSubstitution, OnlyFullGroupBy}}, + {modeStrictTransTables | modeErrorForDivisionByZero | modeNoEngineSubstitution, []string{StrictTransTables, ErrorForDivisionByZero, NoEngineSubstitution}}, + {modeStrictTransTables | modeNoZeroInDate | modeErrorForDivisionByZero, []string{StrictTransTables, NoZeroInDate, ErrorForDivisionByZero}}, + + {uint64(0), []string{}}, + {int(0), []string{}}, + + {int8(4), []string{ANSIQuotes}}, + {int16(4), []string{ANSIQuotes}}, + {int32(4), []string{ANSIQuotes}}, + {uint8(4), []string{ANSIQuotes}}, + {uint16(4), []string{ANSIQuotes}}, + {uint32(4), []string{ANSIQuotes}}, + + {"TRADITIONAL", []string{"TRADITIONAL"}}, + {"ANSI", []string{"ANSI"}}, + {"STRICT_TRANS_TABLES,NO_ZERO_DATE", []string{"STRICT_TRANS_TABLES,NO_ZERO_DATE"}}, + {"", []string{}}, + + {uint64(9999999999), []string{"9999999999"}}, + {"not_a_number", []string{"not_a_number"}}, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("%T(%v)", tt.input, tt.input), func(t *testing.T) { + result, err := ConvertSqlModeBitmask(tt.input) + assert.NoError(t, err) + + if len(tt.expected) == 0 { + assert.Equal(t, "", result) + } else { + for _, exp := range tt.expected { + assert.Contains(t, result, exp) + } + } + }) + } +} diff --git a/sql/variables/system_variables.go b/sql/variables/system_variables.go index 13f73aa74d..24ec81e7a6 100644 --- a/sql/variables/system_variables.go +++ b/sql/variables/system_variables.go @@ -74,7 +74,7 @@ func (sv *globalSystemVariables) AssignValues(vals map[string]interface{}) error defer sv.mutex.Unlock() for varName, val := range vals { varName = strings.ToLower(varName) - sysVar, ok := systemVars[varName] + sysVar, ok := getSystemVar(varName) if !ok { return sql.ErrUnknownSystemVariable.New(varName) } @@ -104,7 +104,7 @@ func (sv *globalSystemVariables) GetGlobal(name string) (sql.SystemVariable, int sv.mutex.RLock() defer sv.mutex.RUnlock() name = strings.ToLower(name) - v, ok := systemVars[name] + v, ok := getSystemVar(name) if !ok { return nil, nil, false } @@ -141,7 +141,7 @@ func (sv *globalSystemVariables) SetGlobal(ctx *sql.Context, name string, val in sv.mutex.Lock() defer sv.mutex.Unlock() name = strings.ToLower(name) - sysVar, ok := systemVars[name] + sysVar, ok := getSystemVar(name) if !ok { return sql.ErrUnknownSystemVariable.New(name) } @@ -166,19 +166,26 @@ func (sv *globalSystemVariables) GetAllGlobalVariables() map[string]interface{} return m } -// InitSystemVariables resets the systemVars singleton in the sql package +// InitSystemVariables resets the global systemVars singleton in the sql package func InitSystemVariables() { - vars := &globalSystemVariables{ - mutex: &sync.RWMutex{}, - sysVarVals: make(map[string]sql.SystemVarValue, len(systemVars)), + out := &globalSystemVariables{ + mutex: &sync.RWMutex{}, + sysVarVals: make(map[string]sql.SystemVarValue, + len(systemVars)+len(mariadbSystemVars)), } - for _, sysVar := range systemVars { - vars.sysVarVals[sysVar.GetName()] = sql.SystemVarValue{ - Var: sysVar, - Val: sysVar.GetDefault(), + + for _, vars := range []map[string]sql.SystemVariable{ + systemVars, + mariadbSystemVars, + } { + for _, sysVar := range vars { + out.sysVarVals[sysVar.GetName()] = sql.SystemVarValue{ + Var: sysVar, + Val: sysVar.GetDefault(), + } } } - sql.SystemVariables = vars + sql.SystemVariables = out } // init initializes SystemVariables as it functions as a global variable. @@ -192,6 +199,19 @@ func getHostname() string { return hostname } +// getSystemVar looks up a system variable by name in both systemVars and mariadbSystemVars. +// Returns the variable and true if found, or nil and false if not found. +func getSystemVar(name string) (sql.SystemVariable, bool) { + name = strings.ToLower(name) + if v, ok := systemVars[name]; ok { + return v, true + } + if v, ok := mariadbSystemVars[name]; ok { + return v, true + } + return nil, false +} + // systemVars is the internal collection of all MySQL system variables according to the following pages: // https://dev.mysql.com/doc/refman/8.0/en/server-system-variables.html // https://dev.mysql.com/doc/refman/8.0/en/replication-options-gtids.html @@ -3041,6 +3061,67 @@ var systemVars = map[string]sql.SystemVariable{ Type: types.NewSystemBoolType("windowing_use_high_precision"), Default: int8(1), }, + "insert_id": &sql.MysqlSystemVariable{ + Name: "insert_id", + Scope: sql.GetMysqlScope(sql.SystemVariableScope_Session), + Dynamic: true, + SetVarHintApplies: false, + Type: types.NewSystemIntType("insert_id", 0, 9223372036854775807, false), + Default: int64(0), + }, +} + +// mariadbSystemVars contains MariaDB-specific system variables that are not part of MySQL. +// These variables are merged into systemVars during initialization. +var mariadbSystemVars = map[string]sql.SystemVariable{ + "skip_parallel_replication": &sql.MysqlSystemVariable{ + Name: "skip_parallel_replication", + Scope: sql.GetMysqlScope(sql.SystemVariableScope_Session), + Dynamic: true, + SetVarHintApplies: false, + Type: types.NewSystemBoolType("skip_parallel_replication"), + Default: int8(0), + }, + "gtid_domain_id": &sql.MysqlSystemVariable{ + Name: "gtid_domain_id", + Scope: sql.GetMysqlScope(sql.SystemVariableScope_Both), + Dynamic: true, + SetVarHintApplies: false, + Type: types.NewSystemIntType("gtid_domain_id", 0, 4294967295, false), + Default: int64(0), + }, + "gtid_seq_no": &sql.MysqlSystemVariable{ + Name: "gtid_seq_no", + Scope: sql.GetMysqlScope(sql.SystemVariableScope_Session), + Dynamic: true, + SetVarHintApplies: false, + Type: types.NewSystemIntType("gtid_seq_no", 0, 9223372036854775807, false), + Default: int64(0), + }, + "check_constraint_checks": &sql.MysqlSystemVariable{ + Name: "check_constraint_checks", + Scope: sql.GetMysqlScope(sql.SystemVariableScope_Both), + Dynamic: true, + SetVarHintApplies: false, + Type: types.NewSystemBoolType("check_constraint_checks"), + Default: int8(1), + }, + "sql_if_exists": &sql.MysqlSystemVariable{ + Name: "sql_if_exists", + Scope: sql.GetMysqlScope(sql.SystemVariableScope_Both), + Dynamic: true, + SetVarHintApplies: false, + Type: types.NewSystemBoolType("sql_if_exists"), + Default: int8(0), + }, + "system_versioning_insert_history": &sql.MysqlSystemVariable{ + Name: "system_versioning_insert_history", + Scope: sql.GetMysqlScope(sql.SystemVariableScope_Both), + Dynamic: true, + SetVarHintApplies: false, + Type: types.NewSystemBoolType("system_versioning_insert_history"), + Default: int8(0), + }, } // TODO: need to implement SystemDateTime type