Skip to content

Commit 1a65d1c

Browse files
committed
server: Revive Option and Interceptor, but without the globals.
1 parent 2d4ebe5 commit 1a65d1c

File tree

3 files changed

+205
-0
lines changed

3 files changed

+205
-0
lines changed

server/extension.go

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
// Copyright 2020-2021 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package server
16+
17+
import (
18+
"context"
19+
"sort"
20+
21+
"github.com/dolthub/vitess/go/mysql"
22+
"github.com/dolthub/vitess/go/sqltypes"
23+
querypb "github.com/dolthub/vitess/go/vt/proto/query"
24+
"github.com/dolthub/vitess/go/vt/sqlparser"
25+
ast "github.com/dolthub/vitess/go/vt/sqlparser"
26+
27+
sqle "github.com/dolthub/go-mysql-server"
28+
)
29+
30+
// InterceptorChain allows an integrator to build a chain of
31+
// |Interceptor| instances which will wrap and intercept the server's
32+
// mysql.Handler.
33+
//
34+
// Example usage:
35+
//
36+
// var ic InterceptorChain
37+
// ic.WithInterceptor(metricsInterceptor)
38+
// ic.WithInterceptor(authInterceptor)
39+
// server, err := NewServer(Config{ ..., Options: []Option{ic.Option()}, ...}, ...)
40+
type InterceptorChain struct {
41+
inters []Interceptor
42+
}
43+
44+
func (ic *InterceptorChain) WithInterceptor(h Interceptor) {
45+
ic.inters = append(ic.inters, h)
46+
}
47+
48+
func (ic *InterceptorChain) Option() Option {
49+
return func(e *sqle.Engine, sm *SessionManager, handler mysql.Handler) (*sqle.Engine, *SessionManager, mysql.Handler) {
50+
chainHandler := buildChain(handler, ic.inters)
51+
return e, sm, chainHandler
52+
}
53+
}
54+
55+
func buildChain(h mysql.Handler, inters []Interceptor) mysql.Handler {
56+
// XXX: Mutates |inters|
57+
sort.Slice(inters, func(i, j int) bool {
58+
return inters[i].Priority() < inters[j].Priority()
59+
})
60+
var last Chain = h
61+
for i := len(inters) - 1; i >= 0; i-- {
62+
filter := inters[i]
63+
next := last
64+
last = &chainInterceptor{i: filter, c: next}
65+
}
66+
return &interceptorHandler{h: h, c: last}
67+
}
68+
69+
type Interceptor interface {
70+
// Priority returns the priority of the interceptor.
71+
Priority() int
72+
73+
// Query is called when a connection receives a query.
74+
// Note the contents of the query slice may change after
75+
// the first call to callback. So the Handler should not
76+
// hang on to the byte slice.
77+
Query(ctx context.Context, chain Chain, c *mysql.Conn, query string, callback func(res *sqltypes.Result, more bool) error) error
78+
79+
// ParsedQuery is called when a connection receives a
80+
// query that has already been parsed. Note the contents
81+
// of the query slice may change after the first call to
82+
// callback. So the Handler should not hang on to the byte
83+
// slice.
84+
ParsedQuery(chain Chain, c *mysql.Conn, query string, parsed sqlparser.Statement, callback func(res *sqltypes.Result, more bool) error) error
85+
86+
// MultiQuery is called when a connection receives a query and the
87+
// client supports MULTI_STATEMENT. It should process the first
88+
// statement in |query| and return the remainder. It will be called
89+
// multiple times until the remainder is |""|.
90+
MultiQuery(ctx context.Context, chain Chain, c *mysql.Conn, query string, callback func(res *sqltypes.Result, more bool) error) (string, error)
91+
92+
// Prepare is called when a connection receives a prepared
93+
// statement query.
94+
Prepare(ctx context.Context, chain Chain, c *mysql.Conn, query string, prepare *mysql.PrepareData) ([]*querypb.Field, error)
95+
96+
// StmtExecute is called when a connection receives a statement
97+
// execute query.
98+
StmtExecute(ctx context.Context, chain Chain, c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error
99+
}
100+
101+
type Chain interface {
102+
// ComQuery is called when a connection receives a query.
103+
// Note the contents of the query slice may change after
104+
// the first call to callback. So the Handler should not
105+
// hang on to the byte slice.
106+
ComQuery(ctx context.Context, c *mysql.Conn, query string, callback mysql.ResultSpoolFn) error
107+
108+
// ComMultiQuery is called when a connection receives a query and the
109+
// client supports MULTI_STATEMENT. It should process the first
110+
// statement in |query| and return the remainder. It will be called
111+
// multiple times until the remainder is |""|.
112+
ComMultiQuery(ctx context.Context, c *mysql.Conn, query string, callback mysql.ResultSpoolFn) (string, error)
113+
114+
// ComPrepare is called when a connection receives a prepared
115+
// statement query.
116+
ComPrepare(ctx context.Context, c *mysql.Conn, query string, prepare *mysql.PrepareData) ([]*querypb.Field, error)
117+
118+
// ComStmtExecute is called when a connection receives a statement
119+
// execute query.
120+
ComStmtExecute(ctx context.Context, c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error
121+
}
122+
123+
type chainInterceptor struct {
124+
i Interceptor
125+
c Chain
126+
}
127+
128+
func (ci *chainInterceptor) ComQuery(ctx context.Context, c *mysql.Conn, query string, callback mysql.ResultSpoolFn) error {
129+
return ci.i.Query(ctx, ci.c, c, query, callback)
130+
}
131+
132+
func (ci *chainInterceptor) ComMultiQuery(ctx context.Context, c *mysql.Conn, query string, callback mysql.ResultSpoolFn) (string, error) {
133+
return ci.i.MultiQuery(ctx, ci.c, c, query, callback)
134+
}
135+
136+
func (ci *chainInterceptor) ComPrepare(ctx context.Context, c *mysql.Conn, query string, prepare *mysql.PrepareData) ([]*querypb.Field, error) {
137+
return ci.i.Prepare(ctx, ci.c, c, query, prepare)
138+
}
139+
140+
func (ci *chainInterceptor) ComStmtExecute(ctx context.Context, c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error {
141+
return ci.i.StmtExecute(ctx, ci.c, c, prepare, callback)
142+
}
143+
144+
type interceptorHandler struct {
145+
c Chain
146+
h mysql.Handler
147+
}
148+
149+
var _ mysql.Handler = (*interceptorHandler)(nil)
150+
151+
func (ih *interceptorHandler) NewConnection(c *mysql.Conn) {
152+
ih.h.NewConnection(c)
153+
}
154+
155+
func (ih *interceptorHandler) ConnectionClosed(c *mysql.Conn) {
156+
ih.h.ConnectionClosed(c)
157+
}
158+
159+
func (ih *interceptorHandler) ConnectionAborted(c *mysql.Conn, reason string) error {
160+
return ih.h.ConnectionAborted(c, reason)
161+
}
162+
163+
func (ih *interceptorHandler) ComInitDB(c *mysql.Conn, schemaName string) error {
164+
return ih.h.ComInitDB(c, schemaName)
165+
}
166+
167+
func (ih *interceptorHandler) ComQuery(ctx context.Context, c *mysql.Conn, query string, callback mysql.ResultSpoolFn) error {
168+
return ih.c.ComQuery(ctx, c, query, callback)
169+
}
170+
171+
func (ih *interceptorHandler) ComMultiQuery(ctx context.Context, c *mysql.Conn, query string, callback mysql.ResultSpoolFn) (string, error) {
172+
return ih.c.ComMultiQuery(ctx, c, query, callback)
173+
}
174+
175+
func (ih *interceptorHandler) ComPrepare(ctx context.Context, c *mysql.Conn, query string, prepare *mysql.PrepareData) ([]*querypb.Field, error) {
176+
return ih.c.ComPrepare(ctx, c, query, prepare)
177+
}
178+
179+
func (ih *interceptorHandler) ComStmtExecute(ctx context.Context, c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error {
180+
return ih.c.ComStmtExecute(ctx, c, prepare, callback)
181+
}
182+
183+
func (ih *interceptorHandler) WarningCount(c *mysql.Conn) uint16 {
184+
return ih.h.WarningCount(c)
185+
}
186+
187+
func (ih *interceptorHandler) ComResetConnection(c *mysql.Conn) error {
188+
return ih.h.ComResetConnection(c)
189+
}
190+
191+
func (ih *interceptorHandler) ParserOptionsForConnection(c *mysql.Conn) (ast.ParserOptions, error) {
192+
return ih.h.ParserOptionsForConnection(c)
193+
}

server/server.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,10 @@ func newServerFromHandler(cfg Config, e *sqle.Engine, sm *SessionManager, handle
129129
cfg.MaxConnections = 0
130130
}
131131

132+
for _, opt := range cfg.Options {
133+
e, sm, handler = opt(e, sm, handler)
134+
}
135+
132136
l := cfg.Listener
133137
var unixSocketInUse error
134138
if l == nil {

server/server_config.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ type Server struct {
3434
Engine *gms.Engine
3535
}
3636

37+
// An option to customize the server.
38+
type Option func(e *gms.Engine, sm *SessionManager, handler mysql.Handler) (*gms.Engine, *SessionManager, mysql.Handler)
39+
3740
// Config for the mysql server.
3841
type Config struct {
3942
// Protocol for the connection.
@@ -78,6 +81,11 @@ type Config struct {
7881
// If true, queries will be logged as base64 encoded strings.
7982
// If false (default behavior), queries will be logged as strings, but newlines and tabs will be replaced with spaces.
8083
EncodeLoggedQuery bool
84+
// Options gets a chance to visit and mutate the GMS *Engine,
85+
// *server.SessionManager and the mysql.Handler as the server
86+
// is being initialized, before the ProtocolListener is
87+
// constructed.
88+
Options []Option
8189
// Used to get the ProtocolListener on server start.
8290
// If unset, defaults to MySQLProtocolListenerFactory.
8391
ProtocolListenerFactory ProtocolListenerFunc

0 commit comments

Comments
 (0)