Skip to content

Commit b5b430b

Browse files
author
James Cor
committed
merge with main
2 parents 5f1affe + f73a318 commit b5b430b

File tree

16 files changed

+158
-99
lines changed

16 files changed

+158
-99
lines changed

enginetest/queries/information_schema_queries.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -874,6 +874,18 @@ var InfoSchemaScripts = []ScriptTest{
874874
},
875875
},
876876
},
877+
{
878+
Name: "issue 8930: connect to info schema",
879+
SetUpScript: []string{
880+
"use information_schema",
881+
},
882+
Assertions: []ScriptTestAssertion{
883+
{
884+
Query: "SELECT C.COLUMN_NAME AS label, 'connection.column' as \"type\", C.TABLE_NAME AS \"table\", C.DATA_TYPE AS \"dataType\", CAST(C.CHARACTER_MAXIMUM_LENGTH AS UNSIGNED) AS size, CAST(UPPER( CONCAT( C.DATA_TYPE, CASE WHEN C.DATA_TYPE = 'text' THEN '' ELSE ( CASE WHEN C.CHARACTER_MAXIMUM_LENGTH > 0 THEN ( CONCAT('(', C.CHARACTER_MAXIMUM_LENGTH, ')') ) ELSE '' END ) END ) ) AS CHAR CHARACTER SET utf8) AS \"detail\", C.TABLE_CATALOG AS \"catalog\", C.TABLE_SCHEMA AS \"database\", C.TABLE_SCHEMA AS \"schema\", C.COLUMN_DEFAULT AS \"defaultValue\", C.IS_NULLABLE AS \"isNullable\", (CASE WHEN C.COLUMN_KEY = 'PRI' THEN 1 ELSE 0 END) AS \"isPk\", (CASE WHEN KCU.REFERENCED_COLUMN_NAME IS NULL THEN 0 ELSE 1 END) AS \"isFk\" FROM INFORMATION_SCHEMA.COLUMNS AS C LEFT JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE AS KCU ON ( C.TABLE_NAME = KCU.TABLE_NAME AND C.TABLE_SCHEMA = KCU.TABLE_SCHEMA AND C.TABLE_CATALOG = KCU.TABLE_CATALOG AND C.COLUMN_NAME = KCU.COLUMN_NAME ) JOIN INFORMATION_SCHEMA.TABLES AS T ON C.TABLE_NAME = T.TABLE_NAME AND C.TABLE_SCHEMA = T.TABLE_SCHEMA AND C.TABLE_CATALOG = T.TABLE_CATALOG WHERE C.TABLE_SCHEMA = 'dev' AND C.TABLE_NAME = 'countries' AND C.TABLE_CATALOG = 'def' ORDER BY C.TABLE_NAME, C.ORDINAL_POSITION",
885+
Expected: []sql.Row{},
886+
},
887+
},
888+
},
877889
{
878890
Name: "query does not use optimization rule on LIKE clause because info_schema db charset is utf8mb3",
879891
SetUpScript: []string{

enginetest/server_engine_test.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ type serverScriptTestAssertion struct {
8383
expectedRows []any
8484

8585
// can't avoid writing custom comparator because of how gosql.Rows.Scan() works
86-
checkRows func(rows *gosql.Rows, expectedRows []any) (bool, error)
86+
checkRows func(t *testing.T, rows *gosql.Rows, expectedRows []any) (bool, error)
8787
}
8888

8989
type serverScriptTest struct {
@@ -108,7 +108,7 @@ func TestServerPreparedStatements(t *testing.T) {
108108
expectedRows: []any{
109109
[]float64{321.4},
110110
},
111-
checkRows: func(rows *gosql.Rows, expectedRows []any) (bool, error) {
111+
checkRows: func(t *testing.T, rows *gosql.Rows, expectedRows []any) (bool, error) {
112112
var i float64
113113
var rowNum int
114114
for rows.Next() {
@@ -133,7 +133,7 @@ func TestServerPreparedStatements(t *testing.T) {
133133
[]float64{213.4},
134134
[]float64{213.4},
135135
},
136-
checkRows: func(rows *gosql.Rows, expectedRows []any) (bool, error) {
136+
checkRows: func(t *testing.T, rows *gosql.Rows, expectedRows []any) (bool, error) {
137137
var i float64
138138
var rowNum int
139139
for rows.Next() {
@@ -197,7 +197,7 @@ func TestServerPreparedStatements(t *testing.T) {
197197
[]uint64{uint64(math.MaxInt64 + 1)},
198198
[]uint64{uint64(math.MaxUint64)},
199199
},
200-
checkRows: func(rows *gosql.Rows, expectedRows []any) (bool, error) {
200+
checkRows: func(t *testing.T, rows *gosql.Rows, expectedRows []any) (bool, error) {
201201
var i uint64
202202
var rowNum int
203203
for rows.Next() {
@@ -247,7 +247,7 @@ func TestServerPreparedStatements(t *testing.T) {
247247
[]int64{int64(-1)},
248248
[]int64{int64(math.MaxInt64)},
249249
},
250-
checkRows: func(rows *gosql.Rows, expectedRows []any) (bool, error) {
250+
checkRows: func(t *testing.T, rows *gosql.Rows, expectedRows []any) (bool, error) {
251251
var i int64
252252
var rowNum int
253253
for rows.Next() {
@@ -275,12 +275,12 @@ func TestServerPreparedStatements(t *testing.T) {
275275
},
276276
assertions: []serverScriptTestAssertion{
277277
{
278-
query: "select * from test where c0 = 2 and c1 = 3;",
278+
query: "select * from test where c0 = 2 and c1 = 3 order by pk;",
279279
expectedRows: []any{
280280
[]uint64{uint64(2), uint64(3), uint64(1)},
281281
[]uint64{uint64(2), uint64(3), uint64(7)},
282282
},
283-
checkRows: func(rows *gosql.Rows, expectedRows []any) (bool, error) {
283+
checkRows: func(t *testing.T, rows *gosql.Rows, expectedRows []any) (bool, error) {
284284
var c0, c1, pk uint64
285285
var rowNum int
286286
for rows.Next() {
@@ -363,7 +363,7 @@ func TestServerPreparedStatements(t *testing.T) {
363363
} else {
364364
require.NoError(t, err)
365365
}
366-
ok, err := assertion.checkRows(rows, assertion.expectedRows)
366+
ok, err := assertion.checkRows(t, rows, assertion.expectedRows)
367367
require.NoError(t, err)
368368
require.True(t, ok)
369369
})

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ require (
66
github.com/dolthub/go-icu-regex v0.0.0-20250303123116-549b8d7cad00
77
github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71
88
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81
9-
github.com/dolthub/vitess v0.0.0-20250303224041-5cc89c183bc4
9+
github.com/dolthub/vitess v0.0.0-20250304211657-920ca9ec2b9a
1010
github.com/go-kit/kit v0.10.0
1111
github.com/go-sql-driver/mysql v1.7.2-0.20231213112541-0004702b931d
1212
github.com/gocraft/dbr/v2 v2.7.2

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ github.com/dolthub/vitess v0.0.0-20250228011932-c4f6bba87730 h1:GtlMVB7+Z7fZZj7B
6262
github.com/dolthub/vitess v0.0.0-20250228011932-c4f6bba87730/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70=
6363
github.com/dolthub/vitess v0.0.0-20250303224041-5cc89c183bc4 h1:wtS9ZWEyEeYzLCcqdGUo+7i3hAV5MWuY9Z7tYbQa65A=
6464
github.com/dolthub/vitess v0.0.0-20250303224041-5cc89c183bc4/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70=
65+
github.com/dolthub/vitess v0.0.0-20250304211657-920ca9ec2b9a h1:HIH9g4z+yXr4DIFyT6L5qOIEGJ1zVtlj6baPyHAG4Yw=
66+
github.com/dolthub/vitess v0.0.0-20250304211657-920ca9ec2b9a/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70=
6567
github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
6668
github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs=
6769
github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU=

server/context.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@ type SessionManager struct {
4949
connections map[uint32]*mysql.Conn
5050
lastPid uint64
5151
ctxFactory sql.ContextFactory
52+
// Implements WaitForClosedConnections(), which is only used
53+
// at server shutdown to allow the integrator to ensure that
54+
// no connections are being handled by handlers.
55+
wg sync.WaitGroup
5256
}
5357

5458
// NewSessionManager creates a SessionManager with the given SessionBuilder.
@@ -82,6 +86,13 @@ func (s *SessionManager) nextPid() uint64 {
8286
return s.lastPid
8387
}
8488

89+
// Block the calling thread until all known connections are closed. It
90+
// is an error to call this concurrently while the server might still
91+
// be accepting new connections.
92+
func (s *SessionManager) WaitForClosedConnections() {
93+
s.wg.Wait()
94+
}
95+
8596
// AddConn adds a connection to be tracked by the SessionManager. Should be called as
8697
// soon as possible after the server has accepted the connection. Results in
8798
// the connection being tracked by ProcessList and being available through
@@ -93,6 +104,7 @@ func (s *SessionManager) AddConn(conn *mysql.Conn) {
93104
defer s.mu.Unlock()
94105
s.connections[conn.ConnectionID] = conn
95106
s.processlist.AddConnection(conn.ConnectionID, conn.RemoteAddr().String())
107+
s.wg.Add(1)
96108
}
97109

98110
// NewSession creates a Session for the given connection and saves it to the session pool.
@@ -270,6 +282,7 @@ func (s *SessionManager) KillConnection(connID uint32) error {
270282
func (s *SessionManager) RemoveConn(conn *mysql.Conn) {
271283
s.mu.Lock()
272284
defer s.mu.Unlock()
285+
s.wg.Done()
273286
if cur, ok := s.sessions[conn.ConnectionID]; ok {
274287
sql.SessionEnd(cur)
275288
}

server/extension.go

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -27,24 +27,36 @@ import (
2727
sqle "github.com/dolthub/go-mysql-server"
2828
)
2929

30-
func Intercept(h Interceptor) {
31-
inters = append(inters, h)
32-
sort.Slice(inters, func(i, j int) bool { return inters[i].Priority() < inters[j].Priority() })
33-
}
34-
35-
func WithChain() Option {
36-
return func(e *sqle.Engine, sm *SessionManager, handler mysql.Handler) {
37-
f := DefaultProtocolListenerFunc
38-
DefaultProtocolListenerFunc = func(cfg mysql.ListenerConfig, sel ServerEventListener) (ProtocolListener, error) {
39-
cfg.Handler = buildChain(cfg.Handler)
40-
return f(cfg, sel)
41-
}
42-
}
30+
// InterceptorChain allows an integrator to build a chain of
31+
// |Interceptor| instances which will wrap and intercept the server's
32+
// mysql.Handler.
33+
//
34+
// Example usage:
35+
//
36+
// var ic InterceptorChain
37+
// ic.WithInterceptor(metricsInterceptor)
38+
// ic.WithInterceptor(authInterceptor)
39+
// server, err := NewServer(Config{ ..., Options: []Option{ic.Option()}, ...}, ...)
40+
type InterceptorChain struct {
41+
inters []Interceptor
4342
}
4443

45-
var inters []Interceptor
44+
func (ic *InterceptorChain) WithInterceptor(h Interceptor) {
45+
ic.inters = append(ic.inters, h)
46+
}
4647

47-
func buildChain(h mysql.Handler) mysql.Handler {
48+
func (ic *InterceptorChain) Option() Option {
49+
return func(e *sqle.Engine, sm *SessionManager, handler mysql.Handler) (*sqle.Engine, *SessionManager, mysql.Handler) {
50+
chainHandler := buildChain(handler, ic.inters)
51+
return e, sm, chainHandler
52+
}
53+
}
54+
55+
func buildChain(h mysql.Handler, inters []Interceptor) mysql.Handler {
56+
// XXX: Mutates |inters|
57+
sort.Slice(inters, func(i, j int) bool {
58+
return inters[i].Priority() < inters[j].Priority()
59+
})
4860
var last Chain = h
4961
for i := len(inters) - 1; i >= 0; i-- {
5062
filter := inters[i]
@@ -55,7 +67,6 @@ func buildChain(h mysql.Handler) mysql.Handler {
5567
}
5668

5769
type Interceptor interface {
58-
5970
// Priority returns the priority of the interceptor.
6071
Priority() int
6172

@@ -88,7 +99,6 @@ type Interceptor interface {
8899
}
89100

90101
type Chain interface {
91-
92102
// ComQuery is called when a connection receives a query.
93103
// Note the contents of the query slice may change after
94104
// the first call to callback. So the Handler should not

server/server.go

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,19 @@ type ProtocolListener interface {
3737
}
3838

3939
// ProtocolListenerFunc returns a ProtocolListener based on the configuration it was given.
40-
type ProtocolListenerFunc func(cfg mysql.ListenerConfig, sel ServerEventListener) (ProtocolListener, error)
41-
42-
// DefaultProtocolListenerFunc is the protocol listener, which defaults to Vitess' protocol listener. Changing
43-
// this function will change the protocol listener used when creating all servers. If multiple servers are needed
44-
// with different protocols, then create each server after changing this function. Servers retain the protocol that
45-
// they were created with.
46-
var DefaultProtocolListenerFunc ProtocolListenerFunc = func(cfg mysql.ListenerConfig, sel ServerEventListener) (ProtocolListener, error) {
47-
return mysql.NewListenerWithConfig(cfg)
40+
type ProtocolListenerFunc func(cfg Config, listenerCfg mysql.ListenerConfig, sel ServerEventListener) (ProtocolListener, error)
41+
42+
func MySQLProtocolListenerFactory(cfg Config, listenerCfg mysql.ListenerConfig, sel ServerEventListener) (ProtocolListener, error) {
43+
vtListener, err := mysql.NewListenerWithConfig(listenerCfg)
44+
if err != nil {
45+
return nil, err
46+
}
47+
if cfg.Version != "" {
48+
vtListener.ServerVersion = cfg.Version
49+
}
50+
vtListener.TLSConfig = cfg.TLSConfig
51+
vtListener.RequireSecureTransport = cfg.RequireSecureTransport
52+
return vtListener, nil
4853
}
4954

5055
type ServerEventListener interface {
@@ -114,10 +119,6 @@ func portInUse(hostPort string) bool {
114119
}
115120

116121
func newServerFromHandler(cfg Config, e *sqle.Engine, sm *SessionManager, handler mysql.Handler, sel ServerEventListener) (*Server, error) {
117-
for _, option := range cfg.Options {
118-
option(e, sm, handler)
119-
}
120-
121122
if cfg.ConnReadTimeout < 0 {
122123
cfg.ConnReadTimeout = 0
123124
}
@@ -128,6 +129,10 @@ func newServerFromHandler(cfg Config, e *sqle.Engine, sm *SessionManager, handle
128129
cfg.MaxConnections = 0
129130
}
130131

132+
for _, opt := range cfg.Options {
133+
e, sm, handler = opt(e, sm, handler)
134+
}
135+
131136
l := cfg.Listener
132137
var unixSocketInUse error
133138
if l == nil {
@@ -156,19 +161,15 @@ func newServerFromHandler(cfg Config, e *sqle.Engine, sm *SessionManager, handle
156161
ConnReadBufferSize: mysql.DefaultConnBufferSize,
157162
AllowClearTextWithoutTLS: cfg.AllowClearTextWithoutTLS,
158163
}
159-
protocolListener, err := DefaultProtocolListenerFunc(listenerCfg, sel)
164+
plf := cfg.ProtocolListenerFactory
165+
if plf == nil {
166+
plf = MySQLProtocolListenerFactory
167+
}
168+
protocolListener, err := plf(cfg, listenerCfg, sel)
160169
if err != nil {
161170
return nil, err
162171
}
163172

164-
if vtListener, ok := protocolListener.(*mysql.Listener); ok {
165-
if cfg.Version != "" {
166-
vtListener.ServerVersion = cfg.Version
167-
}
168-
vtListener.TLSConfig = cfg.TLSConfig
169-
vtListener.RequireSecureTransport = cfg.RequireSecureTransport
170-
}
171-
172173
return &Server{
173174
Listener: protocolListener,
174175
handler: handler,

server/server_config.go

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,9 @@ import (
2323
"go.opentelemetry.io/otel/trace"
2424

2525
gms "github.com/dolthub/go-mysql-server"
26-
sqle "github.com/dolthub/go-mysql-server"
2726
"github.com/dolthub/go-mysql-server/sql"
2827
)
2928

30-
// Option is an option to customize server.
31-
type Option func(e *sqle.Engine, sm *SessionManager, handler mysql.Handler)
32-
3329
// Server is a MySQL server for SQLe engines.
3430
type Server struct {
3531
Listener ProtocolListener
@@ -38,6 +34,9 @@ type Server struct {
3834
Engine *gms.Engine
3935
}
4036

37+
// An option to customize the server.
38+
type Option func(e *gms.Engine, sm *SessionManager, handler mysql.Handler) (*gms.Engine, *SessionManager, mysql.Handler)
39+
4140
// Config for the mysql server.
4241
type Config struct {
4342
// Protocol for the connection.
@@ -82,8 +81,14 @@ type Config struct {
8281
// If true, queries will be logged as base64 encoded strings.
8382
// If false (default behavior), queries will be logged as strings, but newlines and tabs will be replaced with spaces.
8483
EncodeLoggedQuery bool
85-
// Options add additional options to customize the server.
84+
// Options gets a chance to visit and mutate the GMS *Engine,
85+
// *server.SessionManager and the mysql.Handler as the server
86+
// is being initialized, before the ProtocolListener is
87+
// constructed.
8688
Options []Option
89+
// Used to get the ProtocolListener on server start.
90+
// If unset, defaults to MySQLProtocolListenerFactory.
91+
ProtocolListenerFactory ProtocolListenerFunc
8792
}
8893

8994
func (c Config) NewConfig() (Config, error) {

sql/analyzer/analyzer.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ package analyzer
1616

1717
import (
1818
"fmt"
19+
"io"
1920
"os"
2021
"reflect"
2122
"runtime/trace"
@@ -176,6 +177,10 @@ func (ab *Builder) RemoveAfterAllRule(id RuleId) *Builder {
176177

177178
var log = logrus.New()
178179

180+
func SetOutput(w io.Writer) {
181+
log.SetOutput(w)
182+
}
183+
179184
func init() {
180185
// TODO: give the option for debug analyzer logging format to match the global one
181186
log.SetFormatter(simpleLogFormatter{})

sql/analyzer/costed_index_scan.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ func costedIndexScans(ctx *sql.Context, a *Analyzer, n sql.Node, qFlags *sql.Que
8282
}
8383
}
8484
if iat, ok := rt.UnderlyingTable().(sql.IndexAddressableTable); ok {
85-
return costedIndexLookup(ctx, n, a.Catalog, iat, rt, aliasName, filter.Expression, qFlags)
85+
return costedIndexLookup(ctx, n, a, iat, rt, aliasName, filter.Expression, qFlags)
8686
}
8787
return n, transform.SameTree, nil
8888
})
@@ -123,19 +123,24 @@ func indexSearchableLookup(n sql.Node, rt sql.TableNode, lookup sql.IndexLookup,
123123
return ret, transform.NewTree, nil
124124
}
125125

126-
func costedIndexLookup(ctx *sql.Context, n sql.Node, cat sql.Catalog, iat sql.IndexAddressableTable, rt sql.TableNode, aliasName string, oldFilter sql.Expression, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) {
126+
func costedIndexLookup(ctx *sql.Context, n sql.Node, a *Analyzer, iat sql.IndexAddressableTable, rt sql.TableNode, aliasName string, oldFilter sql.Expression, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) {
127127
indexes, err := iat.GetIndexes(ctx)
128128
if err != nil {
129129
return n, transform.SameTree, err
130130
}
131-
ita, _, filters, err := getCostedIndexScan(ctx, cat, rt, indexes, expression.SplitConjunction(oldFilter), qFlags)
131+
ita, stats, filters, err := getCostedIndexScan(ctx, a.Catalog, rt, indexes, expression.SplitConjunction(oldFilter), qFlags)
132132
if err != nil || ita == nil {
133133
return n, transform.SameTree, err
134134
}
135135
var ret sql.Node = ita
136136
if aliasName != "" {
137137
ret = plan.NewTableAlias(aliasName, ret)
138138
}
139+
140+
a.Log("new indexed table: %s/%s/%s", ita.Index().Database(), ita.Index().Table(), ita.Index().ID())
141+
a.Log("index stats cnt: %d", stats.RowCount())
142+
a.Log("index stats histogram: %s", stats.Histogram().DebugString())
143+
139144
// excluded from tree + not included in index scan => filter above scan
140145
if len(filters) > 0 {
141146
ret = plan.NewFilter(expression.JoinAnd(filters...), ret)

0 commit comments

Comments
 (0)