Skip to content

Commit f670322

Browse files
committed
concurrent Serve calls
1 parent 202d132 commit f670322

File tree

4 files changed

+91
-17
lines changed

4 files changed

+91
-17
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ supports multiple concurrent `Accept()` calls, allowing you to reverse-proxy a s
1919

2020
## Server
2121

22+
The server can listen on multiple listeners concurrently.
23+
2224
The server provides two abstractions to customize it's behavior.
2325

2426
The `Authenticator` interface allows custom authentication methods, and comes with implementations for

server/listener.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ type listener struct {
1616

1717
func (l *listener) Close() (err error) {
1818
if refs := l.refs.Add(-1); refs < 1 {
19-
died := int64(time.Since(l.srv.Started))
19+
died := int64(time.Since(l.srv.started))
2020
l.died.Store(died)
2121
_ = l.srv.Debug && l.srv.LogDebug("listener deref", "key", l.key, "refs", refs, "died", died)
2222
}

server/server.go

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,13 @@ import (
55
"io"
66
"net"
77
"sync"
8-
"sync/atomic"
98
"time"
109

1110
"github.com/linkdata/socks5"
1211
)
1312

1413
// Server is a SOCKS5 proxy server.
1514
type Server struct {
16-
Started time.Time // time when Server.Serve() was called
17-
1815
// List of authentication providers. If nil, uses NoAuthAuthenticator.
1916
// Order matters; they are tried in the given order.
2017
Authenticators []Authenticator
@@ -26,9 +23,10 @@ type Server struct {
2623
Logger socks5.Logger // If not nil, use this Logger (compatible with log/slog)
2724
Debug bool // If true, output debug logging using Logger.Info
2825

29-
closed atomic.Bool
3026
mu sync.Mutex // protects following
27+
serving int
3128
listeners map[string]*listener
29+
started time.Time // time when Server.Serve() was called
3230
}
3331

3432
func listenKey(client net.Conn, address string) (key string) {
@@ -47,7 +45,7 @@ func listenKey(client net.Conn, address string) (key string) {
4745

4846
func (s *Server) getListener(ctx context.Context, client net.Conn, bindaddress string) (nl net.Listener, err error) {
4947
err = net.ErrClosed
50-
if !s.closed.Load() {
48+
if s.Serving() > 0 {
5149
err = nil
5250
key := listenKey(client, bindaddress)
5351
var lc net.ListenConfig
@@ -113,9 +111,10 @@ func (s *Server) maybeLogError(err error, msg string, keyvaluepairs ...any) {
113111
}
114112

115113
func (s *Server) close() {
116-
if !s.closed.Swap(true) {
117-
s.mu.Lock()
118-
defer s.mu.Unlock()
114+
s.mu.Lock()
115+
defer s.mu.Unlock()
116+
s.serving--
117+
if s.serving < 1 {
119118
for _, l := range s.listeners {
120119
_ = s.Debug && s.LogDebug("Server.close(): listener stop", "address", l.key)
121120
l.refs.Store(0)
@@ -125,14 +124,26 @@ func (s *Server) close() {
125124
}
126125
}
127126

127+
// Serving returns the number of active calls to Serve()
128+
func (s *Server) Serving() (n int) {
129+
s.mu.Lock()
130+
n = s.serving
131+
s.mu.Unlock()
132+
return
133+
}
134+
128135
// Serve accepts and handles incoming connections on the given listener.
129136
func (s *Server) Serve(ctx context.Context, l net.Listener) (err error) {
130-
defer l.Close()
131137
defer s.close()
132-
s.Started = time.Now()
138+
s.mu.Lock()
139+
s.serving++
140+
if s.listeners == nil {
141+
s.listeners = make(map[string]*listener)
142+
s.started = time.Now()
143+
}
144+
s.mu.Unlock()
133145
errchan := make(chan error, 1)
134146
s.LogInfo("listening", "addr", l.Addr())
135-
s.listeners = make(map[string]*listener)
136147
go s.listenerMaintenance(ctx)
137148
go s.listen(ctx, errchan, l)
138149
select {
@@ -145,7 +156,7 @@ func (s *Server) Serve(ctx context.Context, l net.Listener) (err error) {
145156
func (s *Server) listenerCleanup() {
146157
s.mu.Lock()
147158
defer s.mu.Unlock()
148-
deadline := int64(time.Since(s.Started) - socks5.ListenerTimeout)
159+
deadline := int64(time.Since(s.started) - socks5.ListenerTimeout)
149160
for k, l := range s.listeners {
150161
if refs := l.refs.Load(); refs < 1 {
151162
if died := l.died.Load(); died < deadline {

server/server_test.go

Lines changed: 65 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"log/slog"
99
"net"
1010
"testing"
11+
"time"
1112

1213
"github.com/linkdata/socks5"
1314
"github.com/linkdata/socks5/client"
@@ -54,12 +55,72 @@ func TestServer_Logging(t *testing.T) {
5455
t.Fatal(err)
5556
}
5657
defer listen.Close()
57-
proxy := server.Server{Logger: slog.Default(), Debug: true}
58-
proxy.LogDebug("debug")
59-
proxy.LogInfo("info")
60-
proxy.LogError("error")
58+
srv := server.Server{Logger: slog.Default(), Debug: true}
59+
srv.LogDebug("debug")
60+
srv.LogInfo("info")
61+
srv.LogError("error")
6162
}
6263

6364
func TestServer_DialerSelector(t *testing.T) {
6465
socks5test.InvalidCommand(t, srvfn, clifn)
6566
}
67+
68+
func TestServer_Serve_CancelContext(t *testing.T) {
69+
ctx, cancel := context.WithCancel(context.Background())
70+
defer cancel()
71+
72+
listen, err := net.Listen("tcp", ":0")
73+
if err != nil {
74+
t.Fatal(err)
75+
}
76+
defer listen.Close()
77+
srv := server.Server{Logger: slog.Default(), Debug: true}
78+
doneCh := make(chan struct{})
79+
go func() {
80+
defer close(doneCh)
81+
srv.Serve(ctx, listen)
82+
}()
83+
cancel()
84+
select {
85+
case <-time.NewTimer(time.Second).C:
86+
t.Error("timeout")
87+
case <-doneCh:
88+
}
89+
}
90+
91+
func TestServer_Serve_TwoListeners(t *testing.T) {
92+
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
93+
defer cancel()
94+
95+
listen1, err := net.Listen("tcp", ":0")
96+
if err != nil {
97+
t.Fatal(err)
98+
}
99+
defer listen1.Close()
100+
101+
listen2, err := net.Listen("tcp", ":0")
102+
if err != nil {
103+
t.Fatal(err)
104+
}
105+
defer listen2.Close()
106+
107+
srv := server.Server{Logger: slog.Default(), Debug: true}
108+
go srv.Serve(ctx, listen1)
109+
go srv.Serve(ctx, listen2)
110+
111+
for ctx.Err() == nil && srv.Serving() != 2 {
112+
time.Sleep(time.Millisecond)
113+
}
114+
115+
listen1.Close()
116+
117+
for ctx.Err() == nil && srv.Serving() != 1 {
118+
time.Sleep(time.Millisecond)
119+
}
120+
121+
listen2.Close()
122+
123+
for ctx.Err() == nil && srv.Serving() != 0 {
124+
time.Sleep(time.Millisecond)
125+
}
126+
}

0 commit comments

Comments
 (0)