Skip to content

Commit 4957f80

Browse files
committed
add WithSessionVariables to ctx
1 parent 16344ed commit 4957f80

File tree

2 files changed

+66
-0
lines changed

2 files changed

+66
-0
lines changed

sql/base_session.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,22 @@ func (s *BaseSession) GetAllSessionVariables() map[string]interface{} {
117117
return m
118118
}
119119

120+
// GetBaseSession implements the Session interface.
121+
func (s *BaseSession) GetBaseSession() *BaseSession {
122+
return s
123+
}
124+
125+
// SetBaseSession implements the Session interface.
126+
func (s *BaseSession) SetBaseSession(base *BaseSession) {
127+
*s = *base
128+
}
129+
130+
// Copy implements the Session interface.
131+
func (s *BaseSession) Copy() Session {
132+
cpy := *s
133+
return &cpy
134+
}
135+
120136
// SetSessionVariable implements the Session interface.
121137
func (s *BaseSession) SetSessionVariable(ctx *Context, sysVarName string, value interface{}) error {
122138
sysVarName = strings.ToLower(sysVarName)

sql/session.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,11 @@ type Session interface {
8686
GetUserVariable(ctx *Context, varName string) (Type, interface{}, error)
8787
// GetAllSessionVariables returns a copy of all session variable values.
8888
GetAllSessionVariables() map[string]interface{}
89+
// GetBaseSession returns the BaseSession embedded in this session. For BaseSession itself, it returns itself.
90+
// For sessions that embed BaseSession (like DoltSession), it returns the embedded BaseSession.
91+
GetBaseSession() *BaseSession
92+
// GetSessionVariablesMap returns the system variables map with SystemVarValue structs.
93+
SetBaseSession(*BaseSession)
8994
// GetStatusVariable returns the value of the status variable with session scope with the given name.
9095
// To access global scope, use sql.StatusVariables instead.
9196
GetStatusVariable(ctx *Context, statVarName string) (interface{}, error)
@@ -182,6 +187,8 @@ type Session interface {
182187
// ValidateSession provides integrators a chance to do any custom validation of this session before any query is
183188
// executed in it. For example, Dolt uses this hook to validate that the session's working set is valid.
184189
ValidateSession(ctx *Context) error
190+
// Copy creates a shallow copy of the session.
191+
Copy() Session
185192
}
186193

187194
// PersistableSession supports serializing/deserializing global system variables/
@@ -337,6 +344,49 @@ func WithServices(services Services) ContextOption {
337344
}
338345
}
339346

347+
// WithSessionVariables creates a shallow copy of the session's BaseSession with a new systemVars map, then sets the
348+
// specified variables. The original session's systemVars remain unchanged because the context uses a copied BaseSession
349+
// with an independent systemVars map. If setting a variable fails (e.g., read-only), the error is ignored and processing
350+
// continues with the remaining variables.
351+
func WithSessionVariables(sessVars map[string]interface{}) ContextOption {
352+
return func(ctx *Context) {
353+
if ctx == nil || ctx.Session == nil {
354+
return
355+
}
356+
357+
if len(sessVars) == 0 {
358+
return
359+
}
360+
361+
// Get the BaseSession and create a copy with new systemVars map
362+
baseSess := ctx.Session.GetBaseSession()
363+
if baseSess == nil {
364+
return
365+
}
366+
367+
// Create a shallow copy of the BaseSession with a new systemVars map
368+
newBaseSess := *baseSess
369+
newBaseSess.systemVars = make(map[string]SystemVarValue, len(baseSess.systemVars))
370+
for k, v := range baseSess.systemVars {
371+
newBaseSess.systemVars[k] = v
372+
}
373+
374+
// Copy the session and set the new BaseSession (preserves outer session type)
375+
newSess := ctx.Session.Copy()
376+
newSess.SetBaseSession(&newBaseSess)
377+
ctx.Session = newSess
378+
379+
// Set the modified variables
380+
for k, v := range sessVars {
381+
if err := newSess.SetSessionVariable(ctx, k, v); err != nil {
382+
// If setting a variable fails (e.g., read-only), continue with other variables
383+
// Errors are not returned because ContextOption doesn't support error returns
384+
continue
385+
}
386+
}
387+
}
388+
}
389+
340390
var ctxNowFunc = time.Now
341391
var ctxNowFuncMutex = &sync.Mutex{}
342392

0 commit comments

Comments
 (0)