Skip to content

Commit a028ca9

Browse files
authored
Merge pull request #2319 from seanlaff/customNetListener
Add net.Listener as server config option
2 parents a9f896c + 7e31fec commit a028ca9

File tree

5 files changed

+129
-15
lines changed

5 files changed

+129
-15
lines changed

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ require (
2525
golang.org/x/sys v0.12.0
2626
golang.org/x/text v0.6.0
2727
golang.org/x/tools v0.13.0
28+
google.golang.org/grpc v1.53.0
2829
gopkg.in/src-d/go-errors.v1 v1.0.0
2930
gopkg.in/yaml.v3 v3.0.1
3031
)
@@ -38,7 +39,6 @@ require (
3839
github.com/tetratelabs/wazero v1.1.0 // indirect
3940
golang.org/x/mod v0.12.0 // indirect
4041
google.golang.org/genproto v0.0.0-20230110181048-76db0878b65f // indirect
41-
google.golang.org/grpc v1.53.0 // indirect
4242
google.golang.org/protobuf v1.28.1 // indirect
4343
gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b // indirect
4444
)

server/server.go

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -143,18 +143,21 @@ func newServerFromHandler(cfg Config, e *sqle.Engine, sm *SessionManager, handle
143143
cfg.MaxConnections = 0
144144
}
145145

146+
l := cfg.Listener
146147
var unixSocketInUse error
148+
if l == nil {
149+
if portInUse(cfg.Address) {
150+
unixSocketInUse = fmt.Errorf("Port %s already in use.", cfg.Address)
151+
}
147152

148-
if portInUse(cfg.Address) {
149-
unixSocketInUse = fmt.Errorf("Port %s already in use.", cfg.Address)
150-
}
151-
152-
l, err := NewListener(cfg.Protocol, cfg.Address, cfg.Socket)
153-
if err != nil {
154-
if errors.Is(err, UnixSocketInUseError) {
155-
unixSocketInUse = err
156-
} else {
157-
return nil, err
153+
var err error
154+
l, err = NewListener(cfg.Protocol, cfg.Address, cfg.Socket)
155+
if err != nil {
156+
if errors.Is(err, UnixSocketInUseError) {
157+
unixSocketInUse = err
158+
} else {
159+
return nil, err
160+
}
158161
}
159162
}
160163

server/server_config.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ package server
1616

1717
import (
1818
"crypto/tls"
19+
"net"
1920
"time"
2021

2122
"github.com/dolthub/vitess/go/mysql"
@@ -43,6 +44,9 @@ type Config struct {
4344
Protocol string
4445
// Address of the server.
4546
Address string
47+
// Custom listener for the mysql server. Use this if you don't want ports or unix sockets to be opened automatically.
48+
// This can be useful in testing by using a pure go net.Conn implementation.
49+
Listener net.Listener
4650
// Tracer to use in the server. By default, a noop tracer will be used if
4751
// no tracer is provided.
4852
Tracer trace.Tracer

server/server_test.go

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
package server_test
2+
3+
import (
4+
"context"
5+
"database/sql"
6+
"net"
7+
"testing"
8+
"time"
9+
10+
vsql "github.com/dolthub/vitess/go/mysql"
11+
"github.com/go-sql-driver/mysql"
12+
"github.com/stretchr/testify/require"
13+
14+
sqle "github.com/dolthub/go-mysql-server"
15+
"github.com/dolthub/go-mysql-server/memory"
16+
"github.com/dolthub/go-mysql-server/server"
17+
gsql "github.com/dolthub/go-mysql-server/sql"
18+
"github.com/dolthub/go-mysql-server/sql/mysql_db"
19+
"google.golang.org/grpc/test/bufconn"
20+
)
21+
22+
// TestSeverCustomListener verifies a caller can provide their own net.Conn implementation for the server to use
23+
func TestSeverCustomListener(t *testing.T) {
24+
dbName := "mydb"
25+
// create a net.Conn thats based on a golang buffer
26+
buffer := 1024
27+
listener := bufconn.Listen(buffer)
28+
29+
// create the memory database
30+
memdb := memory.NewDatabase(dbName)
31+
pro := memory.NewDBProvider(memdb)
32+
engine := sqle.NewDefault(pro)
33+
34+
// server config with custom listener
35+
cfg := server.Config{Listener: listener}
36+
// since we're using a memory db, we can't rely on server.DefaultSessionBuilder as it causes panics, so explicitly build a memorySessionBuilder
37+
sessionBuilder := func(ctx context.Context, c *vsql.Conn, addr string) (gsql.Session, error) {
38+
host := ""
39+
user := ""
40+
mysqlConnectionUser, ok := c.UserData.(mysql_db.MysqlConnectionUser)
41+
if ok {
42+
host = mysqlConnectionUser.Host
43+
user = mysqlConnectionUser.User
44+
}
45+
client := gsql.Client{Address: host, User: user, Capabilities: c.Capabilities}
46+
return memory.NewSession(gsql.NewBaseSessionWithClientServer(addr, client, c.ConnectionID), pro), nil
47+
}
48+
s, err := server.NewServer(cfg, engine, sessionBuilder, nil)
49+
require.NoError(t, err)
50+
51+
networkName := "testNetwork"
52+
// wire up go-mysql-driver to the listener
53+
mysql.RegisterDialContext(networkName, func(ctx context.Context, addr string) (net.Conn, error) {
54+
return listener.DialContext(ctx)
55+
})
56+
driver, err := mysql.NewConnector(&mysql.Config{
57+
DBName: dbName,
58+
Addr: "bufconn",
59+
Net: networkName,
60+
Passwd: "",
61+
User: "root",
62+
AllowNativePasswords: true,
63+
})
64+
require.NoError(t, err)
65+
66+
// start go-mysql-server
67+
go func() {
68+
err := s.Start()
69+
require.NoError(t, err)
70+
}()
71+
72+
// open the db, ping it, and run some execs/queries
73+
db := sql.OpenDB(driver)
74+
75+
var pingErr error
76+
for i := 0; i < 3; i++ {
77+
if pingErr = db.Ping(); pingErr == nil {
78+
break
79+
}
80+
time.Sleep(time.Second)
81+
}
82+
require.NoError(t, pingErr)
83+
84+
_, err = db.Exec("CREATE TABLE table1 (id int)")
85+
require.NoError(t, err)
86+
87+
row := db.QueryRow("SHOW TABLES")
88+
var tableName string
89+
err = row.Scan(&tableName)
90+
require.NoError(t, err)
91+
if tableName != "table1" {
92+
t.Fatalf("expected to find table1, but got %s", tableName)
93+
}
94+
}

sql/mysql_db/mysql_db.go

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -769,9 +769,14 @@ func (db *MySQLDb) AuthMethod(user, addr string) (string, error) {
769769
} else {
770770
splitHost, _, err := net.SplitHostPort(addr)
771771
if err != nil {
772-
return "", err
772+
if err.(*net.AddrError).Err == "missing port in address" {
773+
host = addr
774+
} else {
775+
return "", err
776+
}
777+
} else {
778+
host = splitHost
773779
}
774-
host = splitHost
775780
}
776781

777782
rd := db.Reader()
@@ -801,7 +806,11 @@ func (db *MySQLDb) ValidateHash(salt []byte, user string, authResponse []byte, a
801806
} else {
802807
host, _, err = net.SplitHostPort(addr.String())
803808
if err != nil {
804-
return nil, err
809+
if err.(*net.AddrError).Err == "missing port in address" {
810+
host = addr.String()
811+
} else {
812+
return nil, err
813+
}
805814
}
806815
}
807816

@@ -837,7 +846,11 @@ func (db *MySQLDb) Negotiate(c *mysql.Conn, user string, addr net.Addr) (mysql.G
837846
} else {
838847
host, _, err = net.SplitHostPort(addr.String())
839848
if err != nil {
840-
return nil, err
849+
if err.(*net.AddrError).Err == "missing port in address" {
850+
host = addr.String()
851+
} else {
852+
return nil, err
853+
}
841854
}
842855
}
843856

0 commit comments

Comments
 (0)