Skip to content

Commit 5b478d4

Browse files
authored
Merge pull request #2859 from dolthub/aaron/handler-ctxfactory
server: Add a ContextFactory parameter to the handler, giving integrators control over the *sql.Context creation.
2 parents d3abea2 + 9269751 commit 5b478d4

File tree

11 files changed

+66
-67
lines changed

11 files changed

+66
-67
lines changed

enginetest/engine_only_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1037,7 +1037,7 @@ func newDatabase() (*sql2.DB, func()) {
10371037
Protocol: "tcp",
10381038
Address: fmt.Sprintf("localhost:%d", port),
10391039
}
1040-
srv, err := server.NewServer(cfg, engine, harness.SessionBuilder(), nil)
1040+
srv, err := server.NewServer(cfg, engine, sql.NewContext, harness.SessionBuilder(), nil)
10411041
if err != nil {
10421042
panic(err)
10431043
}

enginetest/enginetests.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2142,7 +2142,7 @@ func TestUserAuthentication(t *testing.T, h Harness) {
21422142
require.FailNow(t, "harness must implement ServerHarness")
21432143
}
21442144

2145-
s, err := server.NewServer(serverConfig, engine, serverHarness.SessionBuilder(), nil)
2145+
s, err := server.NewServer(serverConfig, engine, sql.NewContext, serverHarness.SessionBuilder(), nil)
21462146
require.NoError(t, err)
21472147
go func() {
21482148
err := s.Start()
@@ -5695,7 +5695,7 @@ func testCharsetCollationWire(t *testing.T, h Harness, sessionBuilder server.Ses
56955695
defer engine.Close()
56965696
engine.EngineAnalyzer().Catalog.MySQLDb.AddRootAccount()
56975697

5698-
s, err := server.NewServer(serverConfig, engine, sessionBuilder, nil)
5698+
s, err := server.NewServer(serverConfig, engine, sql.NewContext, sessionBuilder, nil)
56995699
require.NoError(t, err)
57005700
go func() {
57015701
err := s.Start()
@@ -5811,7 +5811,7 @@ func TestTypesOverWire(t *testing.T, harness ClientHarness, sessionBuilder serve
58115811
Address: fmt.Sprintf("localhost:%d", port),
58125812
MaxConnections: 1000,
58135813
}
5814-
s, err := server.NewServer(serverConfig, engine, sessionBuilder, nil)
5814+
s, err := server.NewServer(serverConfig, engine, sql.NewContext, sessionBuilder, nil)
58155815
require.NoError(t, err)
58165816
go func() {
58175817
err := s.Start()

enginetest/server_engine.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ func NewServerQueryEngine(t *testing.T, engine *sqle.Engine, builder server.Sess
6969
Protocol: "tcp",
7070
Address: fmt.Sprintf("%s:%d", address, p),
7171
}
72-
s, err := server.NewServer(config, engine, builder, nil)
72+
s, err := server.NewServer(config, engine, sql.NewContext, builder, nil)
7373
if err != nil {
7474
return nil, err
7575
}

enginetest/server_engine_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ func initTestServer(port int) (*server.Server, error) {
4848
sessBuilder := func(ctx context.Context, conn *mysql.Conn, addr string) (sql.Session, error) {
4949
return memory.NewSession(sql.NewBaseSession(), pro), nil
5050
}
51-
s, err := server.NewServer(config, engine, sessBuilder, nil)
51+
s, err := server.NewServer(config, engine, sql.NewContext, sessBuilder, nil)
5252
if err != nil {
5353
return nil, err
5454
}

server/context.go

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,12 @@ type SessionManager struct {
4848
sessions map[uint32]sql.Session
4949
connections map[uint32]*mysql.Conn
5050
lastPid uint64
51+
ctxFactory sql.ContextFactory
5152
}
5253

5354
// NewSessionManager creates a SessionManager with the given SessionBuilder.
5455
func NewSessionManager(
56+
ctxFactory sql.ContextFactory,
5557
builder SessionBuilder,
5658
tracer trace.Tracer,
5759
getDbFunc func(ctx *sql.Context, db string) (sql.Database, error),
@@ -69,6 +71,7 @@ func NewSessionManager(
6971
builder: builder,
7072
sessions: make(map[uint32]sql.Session),
7173
connections: make(map[uint32]*mysql.Conn),
74+
ctxFactory: ctxFactory,
7275
}
7376
}
7477

@@ -125,28 +128,27 @@ func (s *SessionManager) NewSession(ctx context.Context, conn *mysql.Conn) error
125128

126129
// SetDB sets the current database of the given connection session.
127130
// If the session does not exist, it creates a new session with given connection.
128-
func (s *SessionManager) SetDB(conn *mysql.Conn, dbName string) error {
129-
sess, err := s.getOrCreateSession(context.Background(), conn)
131+
func (s *SessionManager) SetDB(ctx context.Context, conn *mysql.Conn, dbName string) error {
132+
sess, err := s.getOrCreateSession(ctx, conn)
130133
if err != nil {
131134
return err
132135
}
133136

134137
err = sql.SessionCommandBegin(sess)
135138
if err != nil {
136-
sql.SessionEnd(sess)
137139
return err
138140
}
139141
defer sql.SessionCommandEnd(sess)
140142

141-
ctx := sql.NewContext(context.Background(), sql.WithSession(sess))
142-
ctx, err = s.processlist.BeginOperation(ctx)
143+
sqlCtx := s.ctxFactory(ctx, sql.WithSession(sess))
144+
sqlCtx, err = s.processlist.BeginOperation(sqlCtx)
143145
if err != nil {
144146
return err
145147
}
146-
defer s.processlist.EndOperation(ctx)
148+
defer s.processlist.EndOperation(sqlCtx)
147149
var db sql.Database
148150
if dbName != "" {
149-
db, err = s.getDbFunc(ctx, dbName)
151+
db, err = s.getDbFunc(sqlCtx, dbName)
150152
if err != nil {
151153
return err
152154
}
@@ -157,7 +159,7 @@ func (s *SessionManager) SetDB(conn *mysql.Conn, dbName string) error {
157159
if pdb, ok := db.(mysql_db.PrivilegedDatabase); ok {
158160
db = pdb.Unwrap()
159161
}
160-
err = sess.UseDatabase(ctx, db)
162+
err = sess.UseDatabase(sqlCtx, db)
161163
if err != nil {
162164
return err
163165
}
@@ -200,11 +202,6 @@ func (s *SessionManager) session(conn *mysql.Conn) sql.Session {
200202
return s.sessions[conn.ConnectionID]
201203
}
202204

203-
// NewContext creates a new context for the session at the given conn.
204-
func (s *SessionManager) NewContext(ctx context.Context, conn *mysql.Conn, query string) (*sql.Context, error) {
205-
return s.NewContextWithQuery(ctx, conn, query)
206-
}
207-
208205
func (s *SessionManager) getOrCreateSession(ctx context.Context, conn *mysql.Conn) (sql.Session, error) {
209206
s.mu.Lock()
210207
sess, ok := s.sessions[conn.ConnectionID]
@@ -236,7 +233,7 @@ func (s *SessionManager) NewContextWithQuery(ctx context.Context, conn *mysql.Co
236233

237234
ctx, span := s.tracer.Start(ctx, "query")
238235

239-
context := sql.NewContext(
236+
context := s.ctxFactory(
240237
ctx,
241238
sql.WithSession(sess),
242239
sql.WithTracer(s.tracer),

server/handler.go

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@ type Handler struct {
8080
var _ mysql.Handler = (*Handler)(nil)
8181
var _ mysql.ExtendedHandler = (*Handler)(nil)
8282
var _ mysql.BinlogReplicaHandler = (*Handler)(nil)
83-
var _ sql.ContextProvider = (*Handler)(nil)
8483

8584
// NewConnection reports that a new connection has been established.
8685
func (h *Handler) NewConnection(c *mysql.Conn) {
@@ -103,7 +102,7 @@ func (h *Handler) ConnectionAborted(_ *mysql.Conn, _ string) error {
103102

104103
func (h *Handler) ComInitDB(c *mysql.Conn, schemaName string) error {
105104
// SetDB itself handles session and processlist operation lifecycle callbacks.
106-
err := h.sm.SetDB(c, schemaName)
105+
err := h.sm.SetDB(context.Background(), c, schemaName)
107106
if err != nil {
108107
logrus.WithField("database", schemaName).Errorf("unable to process ComInitDB: %s", err.Error())
109108
err = sql.CastSQLError(err)
@@ -202,10 +201,6 @@ func (h *Handler) ComPrepareParsed(ctx context.Context, c *mysql.Conn, query str
202201
return analyzed, fields, nil
203202
}
204203

205-
func (h *Handler) NewContext(ctx context.Context, c *mysql.Conn, query string) (*sql.Context, error) {
206-
return h.sm.NewContext(ctx, c, query)
207-
}
208-
209204
func (h *Handler) ComBind(ctx context.Context, c *mysql.Conn, query string, parsedQuery mysql.ParsedQuery, prepare *mysql.PrepareData) (mysql.BoundQuery, []*querypb.Field, error) {
210205
sqlCtx, err := h.sm.NewContextWithQuery(ctx, c, query)
211206
if err != nil {
@@ -273,17 +268,19 @@ func (h *Handler) ComResetConnection(c *mysql.Conn) error {
273268
h.maybeReleaseAllLocks(c)
274269
h.e.CloseSession(c.ConnectionID)
275270

271+
ctx := context.Background()
272+
276273
// Create a new session and set the current database
277-
err := h.sm.NewSession(context.Background(), c)
274+
err := h.sm.NewSession(ctx, c)
278275
if err != nil {
279276
return err
280277
}
281278

282-
return h.sm.SetDB(c, db)
279+
return h.sm.SetDB(ctx, c, db)
283280
}
284281

285282
func (h *Handler) ParserOptionsForConnection(c *mysql.Conn) (sqlparser.ParserOptions, error) {
286-
ctx, err := h.sm.NewContext(context.Background(), c, "")
283+
ctx, err := h.sm.NewContextWithQuery(context.Background(), c, "")
287284
if err != nil {
288285
return sqlparser.ParserOptions{}, err
289286
}
@@ -406,7 +403,7 @@ func (h *Handler) doQuery(
406403
qFlags *sql.QueryFlags,
407404
) (remainder string, err error) {
408405
var sqlCtx *sql.Context
409-
sqlCtx, err = h.sm.NewContext(ctx, c, query)
406+
sqlCtx, err = h.sm.NewContextWithQuery(ctx, c, query)
410407
if err != nil {
411408
return "", err
412409
}

0 commit comments

Comments
 (0)