Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions server/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,15 @@ func (s *SessionManager) getOrCreateSession(ctx context.Context, conn *mysql.Con
return sess, nil
}

// InitSessionDefaultVariable sets a default value to a parameter of a session at start.
func (s *SessionManager) InitSessionDefaultVariable(ctx context.Context, conn *mysql.Conn, name, value string) error {
sess, err := s.getOrCreateSession(ctx, conn)
if err != nil {
return err
}
return sess.InitSessionVariableDefault(s.ctxFactory(ctx, sql.WithSession(sess)), name, value)
}

// NewContextWithQuery creates a new context for the session at the given conn.
func (s *SessionManager) NewContextWithQuery(ctx context.Context, conn *mysql.Conn, query string) (*sql.Context, error) {
sess, err := s.getOrCreateSession(ctx, conn)
Expand Down
49 changes: 45 additions & 4 deletions sql/base_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,21 +159,46 @@ func (s *BaseSession) InitSessionVariable(ctx *Context, sysVarName string, value
return s.setSessVar(ctx, sysVar, value, true)
}

func (s *BaseSession) setSessVar(ctx *Context, sysVar SystemVariable, value interface{}, init bool) error {
// InitSessionVariableDefault implements the Session interface and is used to initialize variables (Including read-only variables)
func (s *BaseSession) InitSessionVariableDefault(ctx *Context, sysVarName string, value interface{}) error {
sysVar, _, ok := SystemVariables.GetGlobal(sysVarName)
if !ok {
return ErrUnknownSystemVariable.New(sysVarName)
}

sysVar.SetDefault(value)
svv, err := sysVar.InitValue(ctx, sysVar.GetDefault(), value, false)
if err != nil {
return err
}

sysVarName = strings.ToLower(sysVarName)
s.systemVars[sysVarName] = svv
if sysVarName == characterSetResultsSysVarName {
s.charset = CharacterSet_Unspecified
}
return nil
}

func (s *BaseSession) setSessVar(ctx *Context, sysVar SystemVariable, newVal interface{}, init bool) error {
var svv SystemVarValue
var err error
sysVarName := strings.ToLower(sysVar.GetName())
var currVal = sysVar.GetDefault()
if ov, ok := s.systemVars[sysVarName]; ok {
currVal = ov.Val
}
if init {
svv, err = sysVar.InitValue(ctx, value, false)
svv, err = sysVar.InitValue(ctx, currVal, newVal, false)
if err != nil {
return err
}
} else {
svv, err = sysVar.SetValue(ctx, value, false)
svv, err = sysVar.SetValue(ctx, currVal, newVal, false)
if err != nil {
return err
}
}
sysVarName := strings.ToLower(sysVar.GetName())
s.systemVars[sysVarName] = svv
if sysVarName == characterSetResultsSysVarName {
s.charset = CharacterSet_Unspecified
Expand Down Expand Up @@ -202,6 +227,22 @@ func (s *BaseSession) GetSessionVariable(ctx *Context, sysVarName string) (inter
return sysVar.Val, nil
}

// GetSessionVariableDefault implements the Session interface.
func (s *BaseSession) GetSessionVariableDefault(ctx *Context, sysVarName string) (interface{}, error) {
sysVarName = strings.ToLower(sysVarName)
sysVar, ok := s.systemVars[sysVarName]
if !ok {
return nil, ErrUnknownSystemVariable.New(sysVarName)
}
// TODO: this is duplicated from within variables.globalSystemVariables, suggesting the need for an interface
if sysType, ok := sysVar.Var.GetType().(SetType); ok {
if sv, ok := sysVar.Var.GetDefault().(uint64); ok {
return sysType.BitsToString(sv)
}
}
return sysVar.Var.GetDefault(), nil
}

// GetUserVariable implements the Session interface.
func (s *BaseSession) GetUserVariable(ctx *Context, varName string) (Type, interface{}, error) {
return s.userVars.GetUserVariable(ctx, varName)
Expand Down
12 changes: 6 additions & 6 deletions sql/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -496,12 +496,12 @@ type SystemVariable interface {
// InitValue sets value without validation.
// This is used for setting the initial values internally
// using pre-defined variables or for test-purposes.
InitValue(ctx *Context, val any, global bool) (SystemVarValue, error)
InitValue(ctx *Context, currVal, newVal any, global bool) (SystemVarValue, error)
// SetValue sets the value of the sv of given scope, global or session
// It validates setting value of correct scope,
// converts the given value to appropriate value depending on the sv
// and it returns the SystemVarValue with the updated value.
SetValue(ctx *Context, val any, global bool) (SystemVarValue, error)
SetValue(ctx *Context, currVal, newVal any, global bool) (SystemVarValue, error)
// IsReadOnly checks whether the variable is read only.
// It returns false if variable can be set to a value.
IsReadOnly() bool
Expand Down Expand Up @@ -575,8 +575,8 @@ func (m *MysqlSystemVariable) GetDefault() any {
}

// InitValue implements SystemVariable.
func (m *MysqlSystemVariable) InitValue(ctx *Context, val any, global bool) (SystemVarValue, error) {
convertedVal, _, err := m.Type.Convert(ctx, val)
func (m *MysqlSystemVariable) InitValue(ctx *Context, currVal, newVal any, global bool) (SystemVarValue, error) {
convertedVal, _, err := m.Type.Convert(ctx, newVal)
if err != nil {
return SystemVarValue{}, err
}
Expand All @@ -598,7 +598,7 @@ func (m *MysqlSystemVariable) InitValue(ctx *Context, val any, global bool) (Sys
}

// SetValue implements SystemVariable.
func (m *MysqlSystemVariable) SetValue(ctx *Context, val any, global bool) (SystemVarValue, error) {
func (m *MysqlSystemVariable) SetValue(ctx *Context, currVal, newVal any, global bool) (SystemVarValue, error) {
if global && m.Scope.Type == SystemVariableScope_Session {
return SystemVarValue{}, ErrSystemVariableSessionOnly.New(m.Name)
}
Expand All @@ -608,7 +608,7 @@ func (m *MysqlSystemVariable) SetValue(ctx *Context, val any, global bool) (Syst
if !m.Dynamic || m.ValueFunction != nil {
return SystemVarValue{}, ErrSystemVariableReadOnly.New(m.Name)
}
return m.InitValue(ctx, val, global)
return m.InitValue(ctx, currVal, newVal, global)
}

// IsReadOnly implements SystemVariable.
Expand Down
5 changes: 3 additions & 2 deletions sql/planbuilder/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -341,8 +341,9 @@ func (b *Builder) simplifySetExpr(name *ast.ColName, varScope ast.SetScope, val

switch varScope {
case ast.SetScope_None, ast.SetScope_Session, ast.SetScope_Global:
_, value, ok := sql.SystemVariables.GetGlobal(varName)
if ok {
// cannot use sql.SystemVariables.GetGlobal as the default value can be defined at session start runtime.
value, err := b.ctx.GetSessionVariableDefault(b.ctx, varName)
if err == nil {
return expression.NewLiteral(value, types.ApproximateTypeFromValue(value)), true
}
err = sql.ErrUnknownSystemVariable.New(varName)
Expand Down
5 changes: 5 additions & 0 deletions sql/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ type Session interface {
Client() Client
// SetClient returns a new session with the given client.
SetClient(Client)
// InitSessionVariableDefault sets this session's default value of the system variable with the given name.
InitSessionVariableDefault(ctx *Context, sysVarName string, value interface{}) error
// SetSessionVariable sets the given system variable to the value given for this session.
SetSessionVariable(ctx *Context, sysVarName string, value interface{}) error
// InitSessionVariable sets the given system variable to the value given for this session and will allow for
Expand All @@ -76,6 +78,9 @@ type Session interface {
// GetSessionVariable returns this session's value of the system variable with the given name.
// To access global scope, use sql.SystemVariables.GetGlobal instead.
GetSessionVariable(ctx *Context, sysVarName string) (interface{}, error)
// GetSessionVariableDefault returns this session's default value of the system variable with the given name.
// To access global scope, use sql.SystemVariables.GetGlobal instead.
GetSessionVariableDefault(ctx *Context, sysVarName string) (interface{}, error)
// GetUserVariable returns this session's value of the user variable with the given name, along with its most
// appropriate type.
GetUserVariable(ctx *Context, varName string) (Type, interface{}, error)
Expand Down
10 changes: 7 additions & 3 deletions sql/variables/system_variables.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func (sv *globalSystemVariables) AssignValues(vals map[string]interface{}) error
if !ok {
return sql.ErrUnknownSystemVariable.New(varName)
}
svv, err := sysVar.InitValue(ctx, val, true)
svv, err := sysVar.InitValue(ctx, sysVar.GetDefault(), val, true)
if err != nil {
return err
}
Expand Down Expand Up @@ -137,15 +137,19 @@ func (sv *globalSystemVariables) GetGlobal(name string) (sql.SystemVariable, int
// Only global dynamic variables may be set through this function, as it is intended for use through the SET GLOBAL
// statement. To set session system variables, use the appropriate function on the session context. To set values
// directly (such as when loading persisted values), use AssignValues. Case-insensitive.
func (sv *globalSystemVariables) SetGlobal(ctx *sql.Context, name string, val interface{}) error {
func (sv *globalSystemVariables) SetGlobal(ctx *sql.Context, name string, newVal interface{}) error {
sv.mutex.Lock()
defer sv.mutex.Unlock()
name = strings.ToLower(name)
sysVar, ok := systemVars[name]
if !ok {
return sql.ErrUnknownSystemVariable.New(name)
}
svv, err := sysVar.SetValue(ctx, val, true)
var currVal = sysVar.GetDefault()
if sysVarVal, exists := sv.sysVarVals[name]; exists {
currVal = sysVarVal.Val
}
svv, err := sysVar.SetValue(ctx, currVal, newVal, true)
if err != nil {
return err
}
Expand Down
Loading