Skip to content

Commit ceaa26f

Browse files
author
James Cor
committed
merge
2 parents cd9ddef + e8ce0df commit ceaa26f

18 files changed

+205
-129
lines changed

enginetest/queries/information_schema_queries.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -874,6 +874,18 @@ var InfoSchemaScripts = []ScriptTest{
874874
},
875875
},
876876
},
877+
{
878+
Name: "issue 8930: connect to info schema",
879+
SetUpScript: []string{
880+
"use information_schema",
881+
},
882+
Assertions: []ScriptTestAssertion{
883+
{
884+
Query: "SELECT C.COLUMN_NAME AS label, 'connection.column' as \"type\", C.TABLE_NAME AS \"table\", C.DATA_TYPE AS \"dataType\", CAST(C.CHARACTER_MAXIMUM_LENGTH AS UNSIGNED) AS size, CAST(UPPER( CONCAT( C.DATA_TYPE, CASE WHEN C.DATA_TYPE = 'text' THEN '' ELSE ( CASE WHEN C.CHARACTER_MAXIMUM_LENGTH > 0 THEN ( CONCAT('(', C.CHARACTER_MAXIMUM_LENGTH, ')') ) ELSE '' END ) END ) ) AS CHAR CHARACTER SET utf8) AS \"detail\", C.TABLE_CATALOG AS \"catalog\", C.TABLE_SCHEMA AS \"database\", C.TABLE_SCHEMA AS \"schema\", C.COLUMN_DEFAULT AS \"defaultValue\", C.IS_NULLABLE AS \"isNullable\", (CASE WHEN C.COLUMN_KEY = 'PRI' THEN 1 ELSE 0 END) AS \"isPk\", (CASE WHEN KCU.REFERENCED_COLUMN_NAME IS NULL THEN 0 ELSE 1 END) AS \"isFk\" FROM INFORMATION_SCHEMA.COLUMNS AS C LEFT JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE AS KCU ON ( C.TABLE_NAME = KCU.TABLE_NAME AND C.TABLE_SCHEMA = KCU.TABLE_SCHEMA AND C.TABLE_CATALOG = KCU.TABLE_CATALOG AND C.COLUMN_NAME = KCU.COLUMN_NAME ) JOIN INFORMATION_SCHEMA.TABLES AS T ON C.TABLE_NAME = T.TABLE_NAME AND C.TABLE_SCHEMA = T.TABLE_SCHEMA AND C.TABLE_CATALOG = T.TABLE_CATALOG WHERE C.TABLE_SCHEMA = 'dev' AND C.TABLE_NAME = 'countries' AND C.TABLE_CATALOG = 'def' ORDER BY C.TABLE_NAME, C.ORDINAL_POSITION",
885+
Expected: []sql.Row{},
886+
},
887+
},
888+
},
877889
{
878890
Name: "query does not use optimization rule on LIKE clause because info_schema db charset is utf8mb3",
879891
SetUpScript: []string{

enginetest/server_engine_test.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ type serverScriptTestAssertion struct {
8383
expectedRows []any
8484

8585
// can't avoid writing custom comparator because of how gosql.Rows.Scan() works
86-
checkRows func(rows *gosql.Rows, expectedRows []any) (bool, error)
86+
checkRows func(t *testing.T, rows *gosql.Rows, expectedRows []any) (bool, error)
8787
}
8888

8989
type serverScriptTest struct {
@@ -108,7 +108,7 @@ func TestServerPreparedStatements(t *testing.T) {
108108
expectedRows: []any{
109109
[]float64{321.4},
110110
},
111-
checkRows: func(rows *gosql.Rows, expectedRows []any) (bool, error) {
111+
checkRows: func(t *testing.T, rows *gosql.Rows, expectedRows []any) (bool, error) {
112112
var i float64
113113
var rowNum int
114114
for rows.Next() {
@@ -133,7 +133,7 @@ func TestServerPreparedStatements(t *testing.T) {
133133
[]float64{213.4},
134134
[]float64{213.4},
135135
},
136-
checkRows: func(rows *gosql.Rows, expectedRows []any) (bool, error) {
136+
checkRows: func(t *testing.T, rows *gosql.Rows, expectedRows []any) (bool, error) {
137137
var i float64
138138
var rowNum int
139139
for rows.Next() {
@@ -197,7 +197,7 @@ func TestServerPreparedStatements(t *testing.T) {
197197
[]uint64{uint64(math.MaxInt64 + 1)},
198198
[]uint64{uint64(math.MaxUint64)},
199199
},
200-
checkRows: func(rows *gosql.Rows, expectedRows []any) (bool, error) {
200+
checkRows: func(t *testing.T, rows *gosql.Rows, expectedRows []any) (bool, error) {
201201
var i uint64
202202
var rowNum int
203203
for rows.Next() {
@@ -247,7 +247,7 @@ func TestServerPreparedStatements(t *testing.T) {
247247
[]int64{int64(-1)},
248248
[]int64{int64(math.MaxInt64)},
249249
},
250-
checkRows: func(rows *gosql.Rows, expectedRows []any) (bool, error) {
250+
checkRows: func(t *testing.T, rows *gosql.Rows, expectedRows []any) (bool, error) {
251251
var i int64
252252
var rowNum int
253253
for rows.Next() {
@@ -275,12 +275,12 @@ func TestServerPreparedStatements(t *testing.T) {
275275
},
276276
assertions: []serverScriptTestAssertion{
277277
{
278-
query: "select * from test where c0 = 2 and c1 = 3;",
278+
query: "select * from test where c0 = 2 and c1 = 3 order by pk;",
279279
expectedRows: []any{
280280
[]uint64{uint64(2), uint64(3), uint64(1)},
281281
[]uint64{uint64(2), uint64(3), uint64(7)},
282282
},
283-
checkRows: func(rows *gosql.Rows, expectedRows []any) (bool, error) {
283+
checkRows: func(t *testing.T, rows *gosql.Rows, expectedRows []any) (bool, error) {
284284
var c0, c1, pk uint64
285285
var rowNum int
286286
for rows.Next() {
@@ -363,7 +363,7 @@ func TestServerPreparedStatements(t *testing.T) {
363363
} else {
364364
require.NoError(t, err)
365365
}
366-
ok, err := assertion.checkRows(rows, assertion.expectedRows)
366+
ok, err := assertion.checkRows(t, rows, assertion.expectedRows)
367367
require.NoError(t, err)
368368
require.True(t, ok)
369369
})

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ require (
66
github.com/dolthub/go-icu-regex v0.0.0-20250303123116-549b8d7cad00
77
github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71
88
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81
9-
github.com/dolthub/vitess v0.0.0-20250303224041-5cc89c183bc4
9+
github.com/dolthub/vitess v0.0.0-20250304211657-920ca9ec2b9a
1010
github.com/go-kit/kit v0.10.0
1111
github.com/go-sql-driver/mysql v1.7.2-0.20231213112541-0004702b931d
1212
github.com/gocraft/dbr/v2 v2.7.2

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ github.com/dolthub/vitess v0.0.0-20250228011932-c4f6bba87730 h1:GtlMVB7+Z7fZZj7B
6262
github.com/dolthub/vitess v0.0.0-20250228011932-c4f6bba87730/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70=
6363
github.com/dolthub/vitess v0.0.0-20250303224041-5cc89c183bc4 h1:wtS9ZWEyEeYzLCcqdGUo+7i3hAV5MWuY9Z7tYbQa65A=
6464
github.com/dolthub/vitess v0.0.0-20250303224041-5cc89c183bc4/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70=
65+
github.com/dolthub/vitess v0.0.0-20250304211657-920ca9ec2b9a h1:HIH9g4z+yXr4DIFyT6L5qOIEGJ1zVtlj6baPyHAG4Yw=
66+
github.com/dolthub/vitess v0.0.0-20250304211657-920ca9ec2b9a/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70=
6567
github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
6668
github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs=
6769
github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU=

server/context.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@ type SessionManager struct {
4949
connections map[uint32]*mysql.Conn
5050
lastPid uint64
5151
ctxFactory sql.ContextFactory
52+
// Implements WaitForClosedConnections(), which is only used
53+
// at server shutdown to allow the integrator to ensure that
54+
// no connections are being handled by handlers.
55+
wg sync.WaitGroup
5256
}
5357

5458
// NewSessionManager creates a SessionManager with the given SessionBuilder.
@@ -82,6 +86,13 @@ func (s *SessionManager) nextPid() uint64 {
8286
return s.lastPid
8387
}
8488

89+
// Block the calling thread until all known connections are closed. It
90+
// is an error to call this concurrently while the server might still
91+
// be accepting new connections.
92+
func (s *SessionManager) WaitForClosedConnections() {
93+
s.wg.Wait()
94+
}
95+
8596
// AddConn adds a connection to be tracked by the SessionManager. Should be called as
8697
// soon as possible after the server has accepted the connection. Results in
8798
// the connection being tracked by ProcessList and being available through
@@ -93,6 +104,7 @@ func (s *SessionManager) AddConn(conn *mysql.Conn) {
93104
defer s.mu.Unlock()
94105
s.connections[conn.ConnectionID] = conn
95106
s.processlist.AddConnection(conn.ConnectionID, conn.RemoteAddr().String())
107+
s.wg.Add(1)
96108
}
97109

98110
// NewSession creates a Session for the given connection and saves it to the session pool.
@@ -270,6 +282,7 @@ func (s *SessionManager) KillConnection(connID uint32) error {
270282
func (s *SessionManager) RemoveConn(conn *mysql.Conn) {
271283
s.mu.Lock()
272284
defer s.mu.Unlock()
285+
s.wg.Done()
273286
if cur, ok := s.sessions[conn.ConnectionID]; ok {
274287
sql.SessionEnd(cur)
275288
}

server/extension.go

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -27,24 +27,36 @@ import (
2727
sqle "github.com/dolthub/go-mysql-server"
2828
)
2929

30-
func Intercept(h Interceptor) {
31-
inters = append(inters, h)
32-
sort.Slice(inters, func(i, j int) bool { return inters[i].Priority() < inters[j].Priority() })
33-
}
34-
35-
func WithChain() Option {
36-
return func(e *sqle.Engine, sm *SessionManager, handler mysql.Handler) {
37-
f := DefaultProtocolListenerFunc
38-
DefaultProtocolListenerFunc = func(cfg mysql.ListenerConfig, sel ServerEventListener) (ProtocolListener, error) {
39-
cfg.Handler = buildChain(cfg.Handler)
40-
return f(cfg, sel)
41-
}
42-
}
30+
// InterceptorChain allows an integrator to build a chain of
31+
// |Interceptor| instances which will wrap and intercept the server's
32+
// mysql.Handler.
33+
//
34+
// Example usage:
35+
//
36+
// var ic InterceptorChain
37+
// ic.WithInterceptor(metricsInterceptor)
38+
// ic.WithInterceptor(authInterceptor)
39+
// server, err := NewServer(Config{ ..., Options: []Option{ic.Option()}, ...}, ...)
40+
type InterceptorChain struct {
41+
inters []Interceptor
4342
}
4443

45-
var inters []Interceptor
44+
func (ic *InterceptorChain) WithInterceptor(h Interceptor) {
45+
ic.inters = append(ic.inters, h)
46+
}
4647

47-
func buildChain(h mysql.Handler) mysql.Handler {
48+
func (ic *InterceptorChain) Option() Option {
49+
return func(e *sqle.Engine, sm *SessionManager, handler mysql.Handler) (*sqle.Engine, *SessionManager, mysql.Handler) {
50+
chainHandler := buildChain(handler, ic.inters)
51+
return e, sm, chainHandler
52+
}
53+
}
54+
55+
func buildChain(h mysql.Handler, inters []Interceptor) mysql.Handler {
56+
// XXX: Mutates |inters|
57+
sort.Slice(inters, func(i, j int) bool {
58+
return inters[i].Priority() < inters[j].Priority()
59+
})
4860
var last Chain = h
4961
for i := len(inters) - 1; i >= 0; i-- {
5062
filter := inters[i]
@@ -55,7 +67,6 @@ func buildChain(h mysql.Handler) mysql.Handler {
5567
}
5668

5769
type Interceptor interface {
58-
5970
// Priority returns the priority of the interceptor.
6071
Priority() int
6172

@@ -88,7 +99,6 @@ type Interceptor interface {
8899
}
89100

90101
type Chain interface {
91-
92102
// ComQuery is called when a connection receives a query.
93103
// Note the contents of the query slice may change after
94104
// the first call to callback. So the Handler should not

server/handler.go

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ package server
1717
import (
1818
"context"
1919
"encoding/base64"
20+
goerrors "errors"
2021
"fmt"
2122
"io"
2223
"net"
@@ -609,31 +610,32 @@ func resultForMax1RowIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter,
609610

610611
// resultForDefaultIter reads batches of rows from the iterator
611612
// and writes results into the callback function.
612-
func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema sql.Schema, iter sql.RowIter, callback func(*sqltypes.Result, bool) error, resultFields []*querypb.Field, more bool, buf *sql.ByteBuffer) (r *sqltypes.Result, processedAtLeastOneBatch bool, returnErr error) {
613+
func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema sql.Schema, iter sql.RowIter, callback func(*sqltypes.Result, bool) error, resultFields []*querypb.Field, more bool, buf *sql.ByteBuffer) (*sqltypes.Result, bool, error) {
613614
defer trace.StartRegion(ctx, "Handler.resultForDefaultIter").End()
614615

615616
eg, ctx := ctx.NewErrgroup()
616-
617-
pan2err := func() {
617+
pan2err := func(err *error) {
618618
if recoveredPanic := recover(); recoveredPanic != nil {
619-
returnErr = fmt.Errorf("handler caught panic: %v", recoveredPanic)
619+
*err = goerrors.Join(*err, fmt.Errorf("handler caught panic: %v", recoveredPanic))
620620
}
621621
}
622-
623622
wg := sync.WaitGroup{}
624623
wg.Add(2)
625624

625+
var r *sqltypes.Result
626+
var processedAtLeastOneBatch bool
627+
626628
// Read rows off the row iterator and send them to the row channel.
627629
iter, projs := GetDeferredProjections(iter)
628630
var rowChan = make(chan sql.Row, 512)
629-
eg.Go(func() error {
630-
defer pan2err()
631+
eg.Go(func() (err error) {
632+
defer pan2err(&err)
631633
defer wg.Done()
632634
defer close(rowChan)
633635
for {
634636
select {
635637
case <-ctx.Done():
636-
return nil
638+
return context.Cause(ctx)
637639
default:
638640
row, err := iter.Next(ctx)
639641
if err == io.EOF {
@@ -651,9 +653,12 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s
651653
}
652654
})
653655

656+
// TODO: poll for closed connections should obviously also run even if
657+
// we're doing something with an OK result or a single row result, etc.
658+
// This should be in the caller.
654659
pollCtx, cancelF := ctx.NewSubContext()
655-
eg.Go(func() error {
656-
defer pan2err()
660+
eg.Go(func() (err error) {
661+
defer pan2err(&err)
657662
return h.pollForClosedConnection(pollCtx, c)
658663
})
659664

@@ -676,8 +681,8 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s
676681

677682
// Reads rows from the channel, converts them to wire format,
678683
// and calls |callback| to give them to vitess.
679-
eg.Go(func() error {
680-
defer pan2err()
684+
eg.Go(func() (err error) {
685+
defer pan2err(&err)
681686
defer cancelF()
682687
defer wg.Done()
683688
for {
@@ -695,7 +700,7 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s
695700

696701
select {
697702
case <-ctx.Done():
698-
return nil
703+
return context.Cause(ctx)
699704
case row, ok := <-rowChan:
700705
if !ok {
701706
return nil
@@ -716,6 +721,9 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s
716721
ctx.GetLogger().Tracef("spooling result row %s", outputRow)
717722
r.Rows = append(r.Rows, outputRow)
718723
r.RowsAffected++
724+
if !timer.Stop() {
725+
<-timer.C
726+
}
719727
case <-timer.C:
720728
// TODO: timer should probably go in its own thread, as rowChan is blocking
721729
if h.readTimeout != 0 {
@@ -724,17 +732,14 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s
724732
return ErrRowTimeout.New()
725733
}
726734
}
727-
if !timer.Stop() {
728-
<-timer.C
729-
}
730735
timer.Reset(waitTime)
731736
}
732737
})
733738

734739
// Close() kills this PID in the process list,
735740
// wait until all rows have be sent over the wire
736-
eg.Go(func() error {
737-
defer pan2err()
741+
eg.Go(func() (err error) {
742+
defer pan2err(&err)
738743
wg.Wait()
739744
return iter.Close(ctx)
740745
})
@@ -745,9 +750,9 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s
745750
if verboseErrorLogging {
746751
fmt.Printf("Err: %+v", err)
747752
}
748-
returnErr = err
753+
return nil, false, err
749754
}
750-
return
755+
return r, processedAtLeastOneBatch, nil
751756
}
752757

753758
// See https://dev.mysql.com/doc/internals/en/status-flags.html

0 commit comments

Comments
 (0)