Skip to content

Commit 4a9a677

Browse files
committed
all: close active listeners on error
1 parent f00be4d commit 4a9a677

File tree

9 files changed

+94
-75
lines changed

9 files changed

+94
-75
lines changed

proxy/config.go

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -266,35 +266,42 @@ type PendingRequestsConfig struct {
266266
func (p *Proxy) validateConfig() (err error) {
267267
err = p.UpstreamConfig.validate()
268268
if err != nil {
269-
return fmt.Errorf("validating general upstreams: %w", err)
269+
return fmt.Errorf("general upstreams: %w", err)
270270
}
271271

272272
err = ValidatePrivateConfig(p.PrivateRDNSUpstreamConfig, p.privateNets)
273273
if err != nil {
274274
if p.UsePrivateRDNS || errors.Is(err, upstream.ErrNoUpstreams) {
275-
return fmt.Errorf("validating private RDNS upstreams: %w", err)
275+
return fmt.Errorf("private RDNS upstreams: %w", err)
276276
}
277277
}
278278

279+
err = p.Fallbacks.validate()
279280
// Allow [Proxy.Fallbacks] to be nil, but not empty. nil means not to use
280281
// fallbacks at all.
281-
err = p.Fallbacks.validate()
282282
if errors.Is(err, upstream.ErrNoUpstreams) {
283-
return fmt.Errorf("validating fallbacks: %w", err)
283+
return fmt.Errorf("fallbacks: %w", err)
284284
}
285285

286286
err = p.validateRatelimit()
287287
if err != nil {
288-
return fmt.Errorf("validating ratelimit: %w", err)
288+
return fmt.Errorf("ratelimit: %w", err)
289289
}
290290

291291
switch p.UpstreamMode {
292-
case "":
293-
// Go on.
294-
case UpstreamModeFastestAddr, UpstreamModeLoadBalance, UpstreamModeParallel:
292+
case
293+
"",
294+
UpstreamModeFastestAddr,
295+
UpstreamModeLoadBalance,
296+
UpstreamModeParallel:
295297
// Go on.
296298
default:
297-
return fmt.Errorf("bad upstream mode: %q", p.UpstreamMode)
299+
return fmt.Errorf("upstream mode: %w: %q", errors.ErrBadEnumValue, p.UpstreamMode)
300+
}
301+
302+
err = p.validateBasicAuth()
303+
if err != nil {
304+
return fmt.Errorf("basic auth: %w", err)
298305
}
299306

300307
p.logConfigInfo()
@@ -368,9 +375,11 @@ func (p *Proxy) logConfigInfo() {
368375

369376
// validateListenAddrs returns an error if the addresses are not configured
370377
// properly.
378+
//
379+
// TODO(e.burkov): Move to configuration validation.
371380
func (p *Proxy) validateListenAddrs() (err error) {
372381
if !p.hasListenAddrs() {
373-
return errors.Error("no listen address specified")
382+
return fmt.Errorf("listen addresses: %w", errors.ErrNoValue)
374383
}
375384

376385
err = p.validateTLSConfig()

proxy/dns64.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ const (
3737
// is specified explicitly. Each prefix also validated to be a valid IPv6 CIDR
3838
// with a maximum length of 96 bits. The first specified prefix is then used to
3939
// synthesize AAAA records.
40+
//
41+
// TODO(e.burkov): Split validation and initialization.
4042
func (p *Proxy) setupDNS64() (err error) {
4143
if !p.Config.UseDNS64 {
4244
return nil

proxy/proxy.go

Lines changed: 56 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import (
2727
"github.com/AdguardTeam/golibs/service"
2828
"github.com/AdguardTeam/golibs/syncutil"
2929
"github.com/AdguardTeam/golibs/timeutil"
30+
"github.com/AdguardTeam/golibs/validate"
3031
"github.com/ameshkov/dnscrypt/v2"
3132
"github.com/miekg/dns"
3233
gocache "github.com/patrickmn/go-cache"
@@ -183,9 +184,9 @@ type Proxy struct {
183184
// udpOOBSize is the size of the out-of-band data for UDP connections.
184185
udpOOBSize int
185186

186-
// bindRetryNum is the number of retries for binding to an address for
187+
// bindRetryCount is the number of retries for binding to an address for
187188
// listening. Zero means one attempt and no retries.
188-
bindRetryNum uint
189+
bindRetryCount uint
189190

190191
// bindRetryIvl is the interval between attempts to bind to an address for
191192
// listening.
@@ -215,6 +216,8 @@ type Proxy struct {
215216
// New creates a new Proxy with the specified configuration. c must not be nil.
216217
//
217218
// TODO(e.burkov): Cover with tests.
219+
//
220+
// TODO(e.burkov): Add context.
218221
func New(c *Config) (p *Proxy, err error) {
219222
p = &Proxy{
220223
Config: *c,
@@ -256,12 +259,6 @@ func New(c *Config) (p *Proxy, err error) {
256259
return nil, err
257260
}
258261

259-
// TODO(s.chzhen): Consider moving to [Proxy.validateConfig].
260-
err = p.validateBasicAuth()
261-
if err != nil {
262-
return nil, fmt.Errorf("basic auth: %w", err)
263-
}
264-
265262
p.initCache()
266263

267264
if p.MaxGoroutines > 0 {
@@ -281,7 +278,7 @@ func New(c *Config) (p *Proxy, err error) {
281278
}
282279

283280
if bindRetries := c.BindRetryConfig; bindRetries != nil && bindRetries.Enabled {
284-
p.bindRetryNum = bindRetries.Count
281+
p.bindRetryCount = bindRetries.Count
285282
p.bindRetryIvl = bindRetries.Interval
286283
}
287284

@@ -290,6 +287,7 @@ func New(c *Config) (p *Proxy, err error) {
290287
return nil, fmt.Errorf("setting up DNS64: %w", err)
291288
}
292289

290+
// TODO(e.burkov): Clone all mutable fields of Config.
293291
p.RatelimitWhitelist = slices.Clone(p.RatelimitWhitelist)
294292
slices.SortFunc(p.RatelimitWhitelist, netip.Addr.Compare)
295293

@@ -324,11 +322,7 @@ func (p *Proxy) validateBasicAuth() (err error) {
324322
return nil
325323
}
326324

327-
if len(conf.HTTPSListenAddr) == 0 {
328-
return errors.Error("no https addrs")
329-
}
330-
331-
return nil
325+
return validate.NotEmptySlice("https listen addrs", conf.HTTPSListenAddr)
332326
}
333327

334328
// Returns true if proxy is started. It is safe for concurrent use.
@@ -359,12 +353,15 @@ func (p *Proxy) Start(ctx context.Context) (err error) {
359353
return err
360354
}
361355

362-
err = p.configureListeners(ctx)
356+
err = p.startListeners(ctx)
363357
if err != nil {
364-
return fmt.Errorf("configuring listeners: %w", err)
358+
closeErr := errors.Join(p.closeListeners(nil)...)
359+
360+
return fmt.Errorf("configuring listeners: %w", errors.WithDeferred(err, closeErr))
365361
}
366362

367-
p.startListeners()
363+
p.serveListeners()
364+
368365
p.started = true
369366

370367
return nil
@@ -390,7 +387,8 @@ func closeAll[C io.Closer](errs []error, closers ...C) (appended []error) {
390387
return errs
391388
}
392389

393-
// Shutdown implements the [service.Interface] for *Proxy.
390+
// Shutdown implements the [service.Interface] for *Proxy. It also closes the
391+
// configured upstream configurations.
394392
func (p *Proxy) Shutdown(ctx context.Context) (err error) {
395393
p.logger.InfoContext(ctx, "stopping server")
396394

@@ -404,65 +402,75 @@ func (p *Proxy) Shutdown(ctx context.Context) (err error) {
404402
return nil
405403
}
406404

407-
errs := closeAll(nil, p.tcpListen...)
405+
errs := p.closeListeners(nil)
406+
407+
for _, u := range []*UpstreamConfig{
408+
p.UpstreamConfig,
409+
p.PrivateRDNSUpstreamConfig,
410+
p.Fallbacks,
411+
} {
412+
if u != nil {
413+
errs = closeAll(errs, u)
414+
}
415+
}
416+
417+
p.started = false
418+
419+
p.logger.InfoContext(ctx, "stopped dns proxy server")
420+
421+
err = errors.Join(errs...)
422+
if err != nil {
423+
return fmt.Errorf("stopping dns proxy server: %w", err)
424+
}
425+
426+
return nil
427+
}
428+
429+
// closeListeners closes all acrive listeners and returns the occurred errors.
430+
func (p *Proxy) closeListeners(errs []error) (res []error) {
431+
res = errs
432+
433+
res = closeAll(res, p.tcpListen...)
408434
p.tcpListen = nil
409435

410-
errs = closeAll(errs, p.udpListen...)
436+
res = closeAll(res, p.udpListen...)
411437
p.udpListen = nil
412438

413-
errs = closeAll(errs, p.tlsListen...)
439+
res = closeAll(res, p.tlsListen...)
414440
p.tlsListen = nil
415441

416442
if p.httpsServer != nil {
417-
errs = closeAll(errs, p.httpsServer)
443+
res = closeAll(res, p.httpsServer)
418444
p.httpsServer = nil
419445

420446
// No need to close these since they're closed by httpsServer.Close().
421447
p.httpsListen = nil
422448
}
423449

424450
if p.h3Server != nil {
425-
errs = closeAll(errs, p.h3Server)
451+
res = closeAll(res, p.h3Server)
426452
p.h3Server = nil
427453
}
428454

429-
errs = closeAll(errs, p.h3Listen...)
455+
res = closeAll(res, p.h3Listen...)
430456
p.h3Listen = nil
431457

432-
errs = closeAll(errs, p.quicListen...)
458+
res = closeAll(res, p.quicListen...)
433459
p.quicListen = nil
434460

435-
errs = closeAll(errs, p.quicTransports...)
461+
res = closeAll(res, p.quicTransports...)
436462
p.quicTransports = nil
437463

438-
errs = closeAll(errs, p.quicConns...)
464+
res = closeAll(res, p.quicConns...)
439465
p.quicConns = nil
440466

441-
errs = closeAll(errs, p.dnsCryptUDPListen...)
467+
res = closeAll(res, p.dnsCryptUDPListen...)
442468
p.dnsCryptUDPListen = nil
443469

444-
errs = closeAll(errs, p.dnsCryptTCPListen...)
470+
res = closeAll(res, p.dnsCryptTCPListen...)
445471
p.dnsCryptTCPListen = nil
446472

447-
for _, u := range []*UpstreamConfig{
448-
p.UpstreamConfig,
449-
p.PrivateRDNSUpstreamConfig,
450-
p.Fallbacks,
451-
} {
452-
if u != nil {
453-
errs = closeAll(errs, u)
454-
}
455-
}
456-
457-
p.started = false
458-
459-
p.logger.InfoContext(ctx, "stopped dns proxy server")
460-
461-
if len(errs) > 0 {
462-
return fmt.Errorf("stopping dns proxy server: %w", errors.Join(errs...))
463-
}
464-
465-
return nil
473+
return res
466474
}
467475

468476
// addrFunc provides the address from the given A.

proxy/retry.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ func (p *Proxy) bindWithRetry(ctx context.Context, bindFunc func() (err error))
3232

3333
p.logger.WarnContext(ctx, "binding", "attempt", 1, slogutil.KeyError, err)
3434

35-
for attempt := uint(1); attempt <= p.bindRetryNum; attempt++ {
35+
for attempt := uint(1); attempt <= p.bindRetryCount; attempt++ {
3636
time.Sleep(p.bindRetryIvl)
3737

3838
retryErr := bindFunc()

proxy/retry_internal_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,9 @@ func TestWithRetry(t *testing.T) {
7676
}}
7777

7878
p := &Proxy{
79-
logger: slogutil.NewDiscardLogger(),
80-
bindRetryNum: 1,
81-
bindRetryIvl: 0,
79+
logger: slogutil.NewDiscardLogger(),
80+
bindRetryCount: 1,
81+
bindRetryIvl: 0,
8282
}
8383

8484
for _, tc := range testCases {

proxy/server.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@ import (
1515
"github.com/quic-go/quic-go"
1616
)
1717

18-
// configureListeners configures listeners.
19-
func (p *Proxy) configureListeners(ctx context.Context) (err error) {
18+
// startListeners configures listeners and starts listening each configured
19+
// address. If it returns an error, all listeners should be closed manually.
20+
func (p *Proxy) startListeners(ctx context.Context) (err error) {
2021
err = p.initUDPListeners(ctx)
2122
if err != nil {
2223
return err
@@ -50,8 +51,8 @@ func (p *Proxy) configureListeners(ctx context.Context) (err error) {
5051
return nil
5152
}
5253

53-
// startListeners starts listener loops.
54-
func (p *Proxy) startListeners() {
54+
// serveListeners starts serving the configured listeners.
55+
func (p *Proxy) serveListeners() {
5556
for _, l := range p.udpListen {
5657
go p.udpPacketLoop(l, p.requestsSema)
5758
}

proxy/upstreams.go

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -327,20 +327,15 @@ func (p *configParser) includeToReserved(dnsUpstream upstream.Upstream, domains
327327
// considered valid if it contains at least a single default upstream. Empty c
328328
// causes [upstream.ErrNoUpstreams].
329329
func (uc *UpstreamConfig) validate() (err error) {
330-
const (
331-
errNilConf errors.Error = "upstream config is nil"
332-
errNoDefault errors.Error = "no default upstreams specified"
333-
)
334-
335330
switch {
336331
case uc == nil:
337-
return errNilConf
332+
return errors.ErrNoValue
338333
case len(uc.Upstreams) > 0:
339334
return nil
340335
case len(uc.DomainReservedUpstreams) == 0 && len(uc.SpecifiedDomainUpstreams) == 0:
341336
return upstream.ErrNoUpstreams
342337
default:
343-
return errNoDefault
338+
return fmt.Errorf("default upstreams: %w", errors.ErrNoValue)
344339
}
345340
}
346341

proxy/upstreams_internal_test.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ func TestUpstreamConfig_Validate(t *testing.T) {
188188
},
189189
}, {
190190
name: "no_default",
191-
wantErr: errors.Error("no default upstreams specified"),
191+
wantErr: errors.ErrNoValue,
192192
in: []string{
193193
"[/domain.example/]udp://upstream.example:53",
194194
"[/another.domain.example/]#",
@@ -205,7 +205,9 @@ func TestUpstreamConfig_Validate(t *testing.T) {
205205
}
206206

207207
t.Run("actual_nil", func(t *testing.T) {
208-
assert.ErrorIs(t, (*UpstreamConfig)(nil).validate(), errors.Error("upstream config is nil"))
208+
var c *UpstreamConfig
209+
210+
assert.Equal(t, c.validate(), errors.ErrNoValue)
209211
})
210212
}
211213

upstream/parallel.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ import (
88
"github.com/miekg/dns"
99
)
1010

11+
// TODO(e.burkov): Consider using wrapped [errors.ErrNoValue] and
12+
// [errors.ErrEmptyValue] instead.
1113
const (
1214
// ErrNoUpstreams is returned from the methods that expect at least a single
1315
// upstream to work with when no upstreams specified.

0 commit comments

Comments
 (0)