Skip to content

Commit 9dac468

Browse files
authored
proxy: listen on additional addrs (#393)
Signed-off-by: xhe <[email protected]>
1 parent 8d0bf38 commit 9dac468

File tree

8 files changed

+77
-42
lines changed

8 files changed

+77
-42
lines changed

lib/config/proxy.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,6 @@ func (cfg *Config) Clone() *Config {
168168
}
169169

170170
func (cfg *Config) Check() error {
171-
172171
if cfg.Workdir == "" {
173172
d, err := os.Getwd()
174173
if err != nil {

pkg/manager/config/manager.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -136,16 +136,17 @@ func (e *ConfigManager) Close() error {
136136
e.cancel()
137137
e.cancel = nil
138138
}
139-
if e.wch != nil {
140-
wcherr = e.wch.Close()
141-
e.wch = nil
142-
}
143139
e.sts.Lock()
144140
for _, ch := range e.sts.listeners {
145141
close(ch)
146142
}
147143
e.sts.listeners = nil
148144
e.sts.Unlock()
149145
e.wg.Wait()
146+
// close after all goroutines are done
147+
if e.wch != nil {
148+
wcherr = e.wch.Close()
149+
e.wch = nil
150+
}
150151
return wcherr
151152
}

pkg/manager/infosync/info.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,8 @@ func (is *InfoSyncer) getTopologyInfo(cfg *config.Config) (*TopologyInfo, error)
160160
s = ""
161161
}
162162
dir := path.Dir(s)
163-
ip, port, err := net.SplitHostPort(cfg.Proxy.Addr)
163+
addrs := strings.Split(cfg.Proxy.Addr, ",")
164+
ip, port, err := net.SplitHostPort(addrs[0])
164165
if err != nil {
165166
return nil, errors.WithStack(err)
166167
}

pkg/metrics/metrics_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ import (
2020
// Test that the metrics are pushed or not pushed with different configurations.
2121
func TestPushMetrics(t *testing.T) {
2222
proxyAddr := "0.0.0.0:6000"
23-
labelName := fmt.Sprintf("%s_%s_connections", ModuleProxy, LabelServer)
23+
labelName := fmt.Sprintf("%s_%s_maxprocs", ModuleProxy, LabelServer)
2424
hostname, err := os.Hostname()
2525
require.NoError(t, err)
2626
expectedPath := fmt.Sprintf("/metrics/job/tiproxy/instance/%s_6000", hostname)

pkg/proxy/backend/handshake_handler.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ type ConnContextKey string
1818
const (
1919
ConnContextKeyTLSState ConnContextKey = "tls-state"
2020
ConnContextKeyConnID ConnContextKey = "conn-id"
21+
ConnContextKeyConnAddr ConnContextKey = "conn-addr"
2122
)
2223

2324
type ErrorSource int

pkg/proxy/client/client_conn.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@ type ClientConnection struct {
2323
}
2424

2525
func NewClientConnection(logger *zap.Logger, conn net.Conn, frontendTLSConfig *tls.Config, backendTLSConfig *tls.Config,
26-
hsHandler backend.HandshakeHandler, connID uint64, bcConfig *backend.BCConfig) *ClientConnection {
26+
hsHandler backend.HandshakeHandler, connID uint64, addr string, bcConfig *backend.BCConfig) *ClientConnection {
2727
bemgr := backend.NewBackendConnManager(logger.Named("be"), hsHandler, connID, bcConfig)
28+
bemgr.SetValue(backend.ConnContextKeyConnAddr, addr)
2829
opts := make([]pnet.PacketIOption, 0, 2)
2930
opts = append(opts, pnet.WithWrapError(backend.ErrClientConn))
3031
if bcConfig.ProxyProtocol {

pkg/proxy/proxy.go

Lines changed: 36 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package proxy
66
import (
77
"context"
88
"net"
9+
"strings"
910
"sync"
1011
"time"
1112

@@ -37,7 +38,8 @@ type serverState struct {
3738
}
3839

3940
type SQLServer struct {
40-
listener net.Listener
41+
listeners []net.Listener
42+
addrs []string
4143
logger *zap.Logger
4244
certMgr *cert.CertManager
4345
hsHandler backend.HandshakeHandler
@@ -65,9 +67,13 @@ func NewSQLServer(logger *zap.Logger, cfg config.ProxyServer, certMgr *cert.Cert
6567

6668
s.reset(&cfg.ProxyServerOnline)
6769

68-
s.listener, err = net.Listen("tcp", cfg.Addr)
69-
if err != nil {
70-
return nil, err
70+
s.addrs = strings.Split(cfg.Addr, ",")
71+
s.listeners = make([]net.Listener, len(s.addrs))
72+
for i, addr := range s.addrs {
73+
s.listeners[i], err = net.Listen("tcp", addr)
74+
if err != nil {
75+
return nil, err
76+
}
7177
}
7278

7379
return s, nil
@@ -104,31 +110,34 @@ func (s *SQLServer) Run(ctx context.Context, cfgch <-chan *config.Config) {
104110
}
105111
})
106112

107-
s.wg.Run(func() {
108-
for {
109-
select {
110-
case <-ctx.Done():
111-
return
112-
default:
113-
conn, err := s.listener.Accept()
114-
if err != nil {
115-
if errors.Is(err, net.ErrClosed) {
116-
return
113+
for i := range s.listeners {
114+
j := i
115+
s.wg.Run(func() {
116+
for {
117+
select {
118+
case <-ctx.Done():
119+
return
120+
default:
121+
conn, err := s.listeners[j].Accept()
122+
if err != nil {
123+
if errors.Is(err, net.ErrClosed) {
124+
return
125+
}
126+
127+
s.logger.Error("accept failed", zap.Error(err))
128+
continue
117129
}
118130

119-
s.logger.Error("accept failed", zap.Error(err))
120-
continue
131+
s.wg.Run(func() {
132+
util.WithRecovery(func() { s.onConn(ctx, conn, s.addrs[j]) }, nil, s.logger)
133+
})
121134
}
122-
123-
s.wg.Run(func() {
124-
util.WithRecovery(func() { s.onConn(ctx, conn) }, nil, s.logger)
125-
})
126135
}
127-
}
128-
})
136+
})
137+
}
129138
}
130139

131-
func (s *SQLServer) onConn(ctx context.Context, conn net.Conn) {
140+
func (s *SQLServer) onConn(ctx context.Context, conn net.Conn, addr string) {
132141
s.mu.Lock()
133142
conns := uint64(len(s.mu.clients))
134143
maxConns := s.mu.maxConnections
@@ -149,9 +158,9 @@ func (s *SQLServer) onConn(ctx context.Context, conn net.Conn) {
149158
connID := s.mu.connID
150159
s.mu.connID++
151160
logger := s.logger.With(zap.Uint64("connID", connID), zap.String("client_addr", conn.RemoteAddr().String()),
152-
zap.Bool("proxy-protocol", s.mu.proxyProtocol))
161+
zap.Bool("proxy-protocol", s.mu.proxyProtocol), zap.String("addr", addr))
153162
clientConn := client.NewClientConnection(logger.Named("conn"), conn, s.certMgr.ServerTLS(), s.certMgr.SQLTLS(),
154-
s.hsHandler, connID, &backend.BCConfig{
163+
s.hsHandler, connID, addr, &backend.BCConfig{
155164
ProxyProtocol: s.mu.proxyProtocol,
156165
RequireBackendTLS: s.requireBackendTLS,
157166
HealthyKeepAlive: s.mu.healthyKeepAlive,
@@ -232,8 +241,8 @@ func (s *SQLServer) Close() error {
232241
s.cancelFunc = nil
233242
}
234243
errs := make([]error, 0, 4)
235-
if s.listener != nil {
236-
errs = append(errs, s.listener.Close())
244+
for i := range s.listeners {
245+
errs = append(errs, s.listeners[i].Close())
237246
}
238247

239248
s.mu.RLock()

pkg/proxy/proxy_test.go

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package proxy
66
import (
77
"context"
88
"database/sql"
9+
"fmt"
910
"net"
1011
"strings"
1112
"testing"
@@ -48,13 +49,13 @@ func TestGracefulShutdown(t *testing.T) {
4849
createClientConn := func() *client.ClientConnection {
4950
server.mu.Lock()
5051
go func() {
51-
conn, err := net.Dial("tcp", server.listener.Addr().String())
52+
conn, err := net.Dial("tcp", server.listeners[0].Addr().String())
5253
require.NoError(t, err)
5354
require.NoError(t, conn.Close())
5455
}()
55-
conn, err := server.listener.Accept()
56+
conn, err := server.listeners[0].Accept()
5657
require.NoError(t, err)
57-
clientConn := client.NewClientConnection(lg, conn, nil, nil, hsHandler, 0, &backend.BCConfig{})
58+
clientConn := client.NewClientConnection(lg, conn, nil, nil, hsHandler, 0, "", &backend.BCConfig{})
5859
server.mu.clients[1] = clientConn
5960
server.mu.Unlock()
6061
return clientConn
@@ -107,18 +108,40 @@ func TestGracefulShutdown(t *testing.T) {
107108
}
108109
}
109110

110-
func TestRecoverPanic(t *testing.T) {
111-
lg, text := logger.CreateLoggerForTest(t)
111+
func TestMultiAddr(t *testing.T) {
112+
lg, _ := logger.CreateLoggerForTest(t)
112113
certManager := cert.NewCertManager()
113114
err := certManager.Init(&config.Config{}, lg, nil)
114115
require.NoError(t, err)
115116
server, err := NewSQLServer(lg, config.ProxyServer{
116-
Addr: "0.0.0.0:6000",
117+
Addr: "0.0.0.0:0,0.0.0.0:0",
117118
}, certManager, &panicHsHandler{})
118119
require.NoError(t, err)
119120
server.Run(context.Background(), nil)
120121

121-
mdb, err := sql.Open("mysql", "root@tcp(localhost:6000)/test")
122+
require.Len(t, server.listeners, 2)
123+
for _, listener := range server.listeners {
124+
conn, err := net.Dial("tcp", listener.Addr().String())
125+
require.NoError(t, err)
126+
require.NoError(t, conn.Close())
127+
}
128+
129+
require.NoError(t, server.Close())
130+
certManager.Close()
131+
}
132+
133+
func TestRecoverPanic(t *testing.T) {
134+
lg, text := logger.CreateLoggerForTest(t)
135+
certManager := cert.NewCertManager()
136+
err := certManager.Init(&config.Config{}, lg, nil)
137+
require.NoError(t, err)
138+
server, err := NewSQLServer(lg, config.ProxyServer{}, certManager, &panicHsHandler{})
139+
require.NoError(t, err)
140+
server.Run(context.Background(), nil)
141+
142+
_, port, err := net.SplitHostPort(server.listeners[0].Addr().String())
143+
require.NoError(t, err)
144+
mdb, err := sql.Open("mysql", fmt.Sprintf("root@tcp(localhost:%s)/test", port))
122145
require.NoError(t, err)
123146
// The first connection encounters panic.
124147
require.ErrorContains(t, mdb.Ping(), "invalid connection")

0 commit comments

Comments
 (0)