Skip to content

Commit 14a57e0

Browse files
authored
Merge pull request #2879 from dolthub/aaron/server-protocol-listener-fixup
server: Get rid of globals for setting a protocol listener factory. Get rid of unused, global-ridden and complicated Interceptor and Option functionality.
2 parents d06023d + 1a65d1c commit 14a57e0

File tree

3 files changed

+59
-43
lines changed

3 files changed

+59
-43
lines changed

server/extension.go

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -27,24 +27,36 @@ import (
2727
sqle "github.com/dolthub/go-mysql-server"
2828
)
2929

30-
func Intercept(h Interceptor) {
31-
inters = append(inters, h)
32-
sort.Slice(inters, func(i, j int) bool { return inters[i].Priority() < inters[j].Priority() })
33-
}
34-
35-
func WithChain() Option {
36-
return func(e *sqle.Engine, sm *SessionManager, handler mysql.Handler) {
37-
f := DefaultProtocolListenerFunc
38-
DefaultProtocolListenerFunc = func(cfg mysql.ListenerConfig, sel ServerEventListener) (ProtocolListener, error) {
39-
cfg.Handler = buildChain(cfg.Handler)
40-
return f(cfg, sel)
41-
}
42-
}
30+
// InterceptorChain allows an integrator to build a chain of
31+
// |Interceptor| instances which will wrap and intercept the server's
32+
// mysql.Handler.
33+
//
34+
// Example usage:
35+
//
36+
// var ic InterceptorChain
37+
// ic.WithInterceptor(metricsInterceptor)
38+
// ic.WithInterceptor(authInterceptor)
39+
// server, err := NewServer(Config{ ..., Options: []Option{ic.Option()}, ...}, ...)
40+
type InterceptorChain struct {
41+
inters []Interceptor
4342
}
4443

45-
var inters []Interceptor
44+
func (ic *InterceptorChain) WithInterceptor(h Interceptor) {
45+
ic.inters = append(ic.inters, h)
46+
}
4647

47-
func buildChain(h mysql.Handler) mysql.Handler {
48+
func (ic *InterceptorChain) Option() Option {
49+
return func(e *sqle.Engine, sm *SessionManager, handler mysql.Handler) (*sqle.Engine, *SessionManager, mysql.Handler) {
50+
chainHandler := buildChain(handler, ic.inters)
51+
return e, sm, chainHandler
52+
}
53+
}
54+
55+
func buildChain(h mysql.Handler, inters []Interceptor) mysql.Handler {
56+
// XXX: Mutates |inters|
57+
sort.Slice(inters, func(i, j int) bool {
58+
return inters[i].Priority() < inters[j].Priority()
59+
})
4860
var last Chain = h
4961
for i := len(inters) - 1; i >= 0; i-- {
5062
filter := inters[i]
@@ -55,7 +67,6 @@ func buildChain(h mysql.Handler) mysql.Handler {
5567
}
5668

5769
type Interceptor interface {
58-
5970
// Priority returns the priority of the interceptor.
6071
Priority() int
6172

@@ -88,7 +99,6 @@ type Interceptor interface {
8899
}
89100

90101
type Chain interface {
91-
92102
// ComQuery is called when a connection receives a query.
93103
// Note the contents of the query slice may change after
94104
// the first call to callback. So the Handler should not

server/server.go

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,19 @@ type ProtocolListener interface {
3737
}
3838

3939
// ProtocolListenerFunc returns a ProtocolListener based on the configuration it was given.
40-
type ProtocolListenerFunc func(cfg mysql.ListenerConfig, sel ServerEventListener) (ProtocolListener, error)
41-
42-
// DefaultProtocolListenerFunc is the protocol listener, which defaults to Vitess' protocol listener. Changing
43-
// this function will change the protocol listener used when creating all servers. If multiple servers are needed
44-
// with different protocols, then create each server after changing this function. Servers retain the protocol that
45-
// they were created with.
46-
var DefaultProtocolListenerFunc ProtocolListenerFunc = func(cfg mysql.ListenerConfig, sel ServerEventListener) (ProtocolListener, error) {
47-
return mysql.NewListenerWithConfig(cfg)
40+
type ProtocolListenerFunc func(cfg Config, listenerCfg mysql.ListenerConfig, sel ServerEventListener) (ProtocolListener, error)
41+
42+
func MySQLProtocolListenerFactory(cfg Config, listenerCfg mysql.ListenerConfig, sel ServerEventListener) (ProtocolListener, error) {
43+
vtListener, err := mysql.NewListenerWithConfig(listenerCfg)
44+
if err != nil {
45+
return nil, err
46+
}
47+
if cfg.Version != "" {
48+
vtListener.ServerVersion = cfg.Version
49+
}
50+
vtListener.TLSConfig = cfg.TLSConfig
51+
vtListener.RequireSecureTransport = cfg.RequireSecureTransport
52+
return vtListener, nil
4853
}
4954

5055
type ServerEventListener interface {
@@ -114,10 +119,6 @@ func portInUse(hostPort string) bool {
114119
}
115120

116121
func newServerFromHandler(cfg Config, e *sqle.Engine, sm *SessionManager, handler mysql.Handler, sel ServerEventListener) (*Server, error) {
117-
for _, option := range cfg.Options {
118-
option(e, sm, handler)
119-
}
120-
121122
if cfg.ConnReadTimeout < 0 {
122123
cfg.ConnReadTimeout = 0
123124
}
@@ -128,6 +129,10 @@ func newServerFromHandler(cfg Config, e *sqle.Engine, sm *SessionManager, handle
128129
cfg.MaxConnections = 0
129130
}
130131

132+
for _, opt := range cfg.Options {
133+
e, sm, handler = opt(e, sm, handler)
134+
}
135+
131136
l := cfg.Listener
132137
var unixSocketInUse error
133138
if l == nil {
@@ -156,19 +161,15 @@ func newServerFromHandler(cfg Config, e *sqle.Engine, sm *SessionManager, handle
156161
ConnReadBufferSize: mysql.DefaultConnBufferSize,
157162
AllowClearTextWithoutTLS: cfg.AllowClearTextWithoutTLS,
158163
}
159-
protocolListener, err := DefaultProtocolListenerFunc(listenerCfg, sel)
164+
plf := cfg.ProtocolListenerFactory
165+
if plf == nil {
166+
plf = MySQLProtocolListenerFactory
167+
}
168+
protocolListener, err := plf(cfg, listenerCfg, sel)
160169
if err != nil {
161170
return nil, err
162171
}
163172

164-
if vtListener, ok := protocolListener.(*mysql.Listener); ok {
165-
if cfg.Version != "" {
166-
vtListener.ServerVersion = cfg.Version
167-
}
168-
vtListener.TLSConfig = cfg.TLSConfig
169-
vtListener.RequireSecureTransport = cfg.RequireSecureTransport
170-
}
171-
172173
return &Server{
173174
Listener: protocolListener,
174175
handler: handler,

server/server_config.go

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,9 @@ import (
2323
"go.opentelemetry.io/otel/trace"
2424

2525
gms "github.com/dolthub/go-mysql-server"
26-
sqle "github.com/dolthub/go-mysql-server"
2726
"github.com/dolthub/go-mysql-server/sql"
2827
)
2928

30-
// Option is an option to customize server.
31-
type Option func(e *sqle.Engine, sm *SessionManager, handler mysql.Handler)
32-
3329
// Server is a MySQL server for SQLe engines.
3430
type Server struct {
3531
Listener ProtocolListener
@@ -38,6 +34,9 @@ type Server struct {
3834
Engine *gms.Engine
3935
}
4036

37+
// An option to customize the server.
38+
type Option func(e *gms.Engine, sm *SessionManager, handler mysql.Handler) (*gms.Engine, *SessionManager, mysql.Handler)
39+
4140
// Config for the mysql server.
4241
type Config struct {
4342
// Protocol for the connection.
@@ -82,8 +81,14 @@ type Config struct {
8281
// If true, queries will be logged as base64 encoded strings.
8382
// If false (default behavior), queries will be logged as strings, but newlines and tabs will be replaced with spaces.
8483
EncodeLoggedQuery bool
85-
// Options add additional options to customize the server.
84+
// Options gets a chance to visit and mutate the GMS *Engine,
85+
// *server.SessionManager and the mysql.Handler as the server
86+
// is being initialized, before the ProtocolListener is
87+
// constructed.
8688
Options []Option
89+
// Used to get the ProtocolListener on server start.
90+
// If unset, defaults to MySQLProtocolListenerFactory.
91+
ProtocolListenerFactory ProtocolListenerFunc
8792
}
8893

8994
func (c Config) NewConfig() (Config, error) {

0 commit comments

Comments
 (0)