diff --git a/sql/base_session.go b/sql/base_session.go index ecb1172119..9e8e4a1193 100644 --- a/sql/base_session.go +++ b/sql/base_session.go @@ -117,6 +117,22 @@ func (s *BaseSession) GetAllSessionVariables() map[string]interface{} { return m } +// GetBaseSession implements the Session interface. +func (s *BaseSession) GetBaseSession() *BaseSession { + return s +} + +// SetBaseSession implements the Session interface. +func (s *BaseSession) SetBaseSession(base *BaseSession) { + *s = *base +} + +// Copy implements the Session interface. +func (s *BaseSession) Copy() Session { + cpy := *s + return &cpy +} + // SetSessionVariable implements the Session interface. func (s *BaseSession) SetSessionVariable(ctx *Context, sysVarName string, value interface{}) error { sysVarName = strings.ToLower(sysVarName) diff --git a/sql/session.go b/sql/session.go index ad655f2cc3..924366a373 100644 --- a/sql/session.go +++ b/sql/session.go @@ -86,6 +86,11 @@ type Session interface { GetUserVariable(ctx *Context, varName string) (Type, interface{}, error) // GetAllSessionVariables returns a copy of all session variable values. GetAllSessionVariables() map[string]interface{} + // GetBaseSession returns the BaseSession embedded in this session. For BaseSession itself, it returns itself. + // For sessions that embed BaseSession (like DoltSession), it returns the embedded BaseSession. + GetBaseSession() *BaseSession + // GetSessionVariablesMap returns the system variables map with SystemVarValue structs. + SetBaseSession(*BaseSession) // GetStatusVariable returns the value of the status variable with session scope with the given name. // To access global scope, use sql.StatusVariables instead. GetStatusVariable(ctx *Context, statVarName string) (interface{}, error) @@ -182,6 +187,8 @@ type Session interface { // ValidateSession provides integrators a chance to do any custom validation of this session before any query is // executed in it. For example, Dolt uses this hook to validate that the session's working set is valid. ValidateSession(ctx *Context) error + // Copy creates a shallow copy of the session. + Copy() Session } // PersistableSession supports serializing/deserializing global system variables/ @@ -337,6 +344,49 @@ func WithServices(services Services) ContextOption { } } +// WithSessionVariables creates a shallow copy of the session's BaseSession with a new systemVars map, then sets the +// specified variables. The original session's systemVars remain unchanged because the context uses a copied BaseSession +// with an independent systemVars map. If setting a variable fails (e.g., read-only), the error is ignored and processing +// continues with the remaining variables. +func WithSessionVariables(sessVars map[string]interface{}) ContextOption { + return func(ctx *Context) { + if ctx == nil || ctx.Session == nil { + return + } + + if len(sessVars) == 0 { + return + } + + // Get the BaseSession and create a copy with new systemVars map + baseSess := ctx.Session.GetBaseSession() + if baseSess == nil { + return + } + + // Create a shallow copy of the BaseSession with a new systemVars map + newBaseSess := *baseSess + newBaseSess.systemVars = make(map[string]SystemVarValue, len(baseSess.systemVars)) + for k, v := range baseSess.systemVars { + newBaseSess.systemVars[k] = v + } + + // Copy the session and set the new BaseSession (preserves outer session type) + newSess := ctx.Session.Copy() + newSess.SetBaseSession(&newBaseSess) + ctx.Session = newSess + + // Set the modified variables + for k, v := range sessVars { + if err := newSess.SetSessionVariable(ctx, k, v); err != nil { + // If setting a variable fails (e.g., read-only), continue with other variables + // Errors are not returned because ContextOption doesn't support error returns + continue + } + } + } +} + var ctxNowFunc = time.Now var ctxNowFuncMutex = &sync.Mutex{}