@@ -37,14 +37,19 @@ type ProtocolListener interface {
37
37
}
38
38
39
39
// ProtocolListenerFunc returns a ProtocolListener based on the configuration it was given.
40
- type ProtocolListenerFunc func (cfg mysql.ListenerConfig , sel ServerEventListener ) (ProtocolListener , error )
41
-
42
- // DefaultProtocolListenerFunc is the protocol listener, which defaults to Vitess' protocol listener. Changing
43
- // this function will change the protocol listener used when creating all servers. If multiple servers are needed
44
- // with different protocols, then create each server after changing this function. Servers retain the protocol that
45
- // they were created with.
46
- var DefaultProtocolListenerFunc ProtocolListenerFunc = func (cfg mysql.ListenerConfig , sel ServerEventListener ) (ProtocolListener , error ) {
47
- return mysql .NewListenerWithConfig (cfg )
40
+ type ProtocolListenerFunc func (cfg Config , listenerCfg mysql.ListenerConfig , sel ServerEventListener ) (ProtocolListener , error )
41
+
42
+ func MySQLProtocolListenerFactory (cfg Config , listenerCfg mysql.ListenerConfig , sel ServerEventListener ) (ProtocolListener , error ) {
43
+ vtListener , err := mysql .NewListenerWithConfig (listenerCfg )
44
+ if err != nil {
45
+ return nil , err
46
+ }
47
+ if cfg .Version != "" {
48
+ vtListener .ServerVersion = cfg .Version
49
+ }
50
+ vtListener .TLSConfig = cfg .TLSConfig
51
+ vtListener .RequireSecureTransport = cfg .RequireSecureTransport
52
+ return vtListener , nil
48
53
}
49
54
50
55
type ServerEventListener interface {
@@ -114,10 +119,6 @@ func portInUse(hostPort string) bool {
114
119
}
115
120
116
121
func newServerFromHandler (cfg Config , e * sqle.Engine , sm * SessionManager , handler mysql.Handler , sel ServerEventListener ) (* Server , error ) {
117
- for _ , option := range cfg .Options {
118
- option (e , sm , handler )
119
- }
120
-
121
122
if cfg .ConnReadTimeout < 0 {
122
123
cfg .ConnReadTimeout = 0
123
124
}
@@ -128,6 +129,10 @@ func newServerFromHandler(cfg Config, e *sqle.Engine, sm *SessionManager, handle
128
129
cfg .MaxConnections = 0
129
130
}
130
131
132
+ for _ , opt := range cfg .Options {
133
+ e , sm , handler = opt (e , sm , handler )
134
+ }
135
+
131
136
l := cfg .Listener
132
137
var unixSocketInUse error
133
138
if l == nil {
@@ -156,19 +161,15 @@ func newServerFromHandler(cfg Config, e *sqle.Engine, sm *SessionManager, handle
156
161
ConnReadBufferSize : mysql .DefaultConnBufferSize ,
157
162
AllowClearTextWithoutTLS : cfg .AllowClearTextWithoutTLS ,
158
163
}
159
- protocolListener , err := DefaultProtocolListenerFunc (listenerCfg , sel )
164
+ plf := cfg .ProtocolListenerFactory
165
+ if plf == nil {
166
+ plf = MySQLProtocolListenerFactory
167
+ }
168
+ protocolListener , err := plf (cfg , listenerCfg , sel )
160
169
if err != nil {
161
170
return nil , err
162
171
}
163
172
164
- if vtListener , ok := protocolListener .(* mysql.Listener ); ok {
165
- if cfg .Version != "" {
166
- vtListener .ServerVersion = cfg .Version
167
- }
168
- vtListener .TLSConfig = cfg .TLSConfig
169
- vtListener .RequireSecureTransport = cfg .RequireSecureTransport
170
- }
171
-
172
173
return & Server {
173
174
Listener : protocolListener ,
174
175
handler : handler ,
0 commit comments