Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 44 additions & 31 deletions intra/dnscrypt/multiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -425,44 +425,57 @@ func (proxy *DcMulti) start() error {
curve25519.ScalarBaseMult(&proxy.proxyPublicKey, &proxy.proxySecretKey)

_, err := proxy.Refresh()
_ = core.Periodic("dcmulti.start", proxy.ctx, certRefreshDelay, func() {
maxtries := 10
i := 0

// This goroutine periodically refreshes the certificates.
// It uses a single timer that adjusts its delay based on success or failure,
// making the retry mechanism more reliable and resilient.
go func() {
var delay time.Duration
if len(proxy.liveServers) > 0 {
delay = certRefreshDelay
} else {
delay = certRefreshDelayAfterFailure
}

timer := time.NewTimer(delay)
defer timer.Stop()

for {
i++
if i > maxtries {
log.E("dnscrypt: cert refresh failed after %d tries", maxtries)
return
}
select {
case <-proxy.ctx.Done():
log.I("dnscrypt: cert refresh stopped")
return
default:
}

hasServers := proxy.serversInfo.len() > 0
if !hasServers {
log.D("dnscrypt: no servers; next check after %v", certRefreshDelayAfterFailure)
return
}
proxy.liveServers, _ = proxy.serversInfo.refresh(proxy)
if someAlive := len(proxy.liveServers) > 0; someAlive {
log.I("dnscrypt: some servers alive; retry #%d; next check after",
i, certRefreshDelayAfterFailure)
proxy.certIgnoreTimestamp = false
return
case <-timer.C:
// If there are no registered servers, wait for the full refresh delay.
hasRegisteredServers := proxy.serversInfo.len() > 0
if !hasRegisteredServers {
log.D("dnscrypt: no registered servers; next check after %v", certRefreshDelay)
timer.Reset(certRefreshDelay)
continue
}

// Attempt to refresh the certificates for all registered servers.
live, refreshErr := proxy.serversInfo.refresh(proxy)

proxy.Lock()
proxy.liveServers = live
if len(proxy.liveServers) > 0 {
// If the refresh is successful, use the standard refresh delay.
log.I("dnscrypt: cert refresh success, next check in %v", certRefreshDelay)
proxy.certIgnoreTimestamp = false
delay = certRefreshDelay
} else {
// If the refresh fails, use a shorter delay to retry sooner.
log.W("dnscrypt: all servers dead; retry in %v, err: %v", certRefreshDelayAfterFailure, refreshErr)
proxy.certIgnoreTimestamp = true
delay = certRefreshDelayAfterFailure
}
proxy.Unlock()
timer.Reset(delay)
}
proxy.certIgnoreTimestamp = true
backoff := time.Duration(i) * time.Second
wait := certRefreshDelayAfterFailure * backoff
log.W("dnscrypt: all servers dead; retry #%d in %v", i, wait)
time.Sleep(wait)
continue

}
})
// todo: on error: context.AfterFunc(refreshCtx, proxy.notifyRestart)
}()

return err
}

Expand Down
88 changes: 83 additions & 5 deletions intra/dnscrypt/servers.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@ import (
"math/rand"
"net"
"net/netip"
"sort"
"strconv"
"strings"
"sync"
"time"

x "github.com/celzero/firestack/intra/backend"
"github.com/celzero/firestack/intra/core"
Expand Down Expand Up @@ -68,6 +70,8 @@ type serverinfo struct {
RelayUDPAddrs *core.Volatile[[]*net.UDPAddr] // anonymous relays, if any
RelayTCPAddrs *core.Volatile[[]*net.TCPAddr] // anonymous relays, if any
status *core.Volatile[int] // status of the last query
unhealthy *core.Volatile[bool]
lastErr *core.Volatile[int64] // unix timestamp of the last error
}

var _ dnsx.Transport = (*serverinfo)(nil)
Expand Down Expand Up @@ -117,13 +121,39 @@ func (serversInfo *ServersInfo) getOne() (serverInfo *serverinfo) {
if serversCount <= 0 {
return nil
}

// Create a slice of healthy servers. A server is considered healthy if it
// has not been marked as unhealthy and its status is not DEnd or Paused.
var healthyServers []*serverinfo
for _, si := range serversInfo.inner {
if si != nil && dnsx.WillErr(si) == nil {
healthyServers = append(healthyServers, si)
}
}

// If there are healthy servers, sort them by latency and return the best one.
if len(healthyServers) > 0 {
// Sort healthy servers by latency (p50) in ascending order.
sort.Slice(healthyServers, func(i, j int) bool {
return healthyServers[i].P50() < healthyServers[j].P50()
})
serverInfo = healthyServers[0]
if settings.Debug {
log.V("dnscrypt: selected candidate [%v] with p50 [%d]", serverInfo, serverInfo.P50())
}
return serverInfo
}

// If no healthy servers are found, fallback to the original random selection.
// This ensures that the client can still attempt to resolve DNS queries
// even if all servers are marked as unhealthy.
selectAny := false
candidate := rand.Intn(serversCount)
retry:
i := 0
for _, si := range serversInfo.inner {
if i == candidate || selectAny {
if si != nil && dnsx.WillErr(si) == nil {
if si != nil {
if settings.Debug {
log.V("dnscrypt: candidate [%v]", si) // may be nil?
}
Expand Down Expand Up @@ -282,7 +312,26 @@ func fetchDNSCryptServerInfo(proxy *DcMulti, name string, stamp stamps.ServerSta
relay: relay,
est: core.NewP50Estimator(ctx),
status: core.NewVolatile(dnsx.Start),
}
unhealthy: core.NewVolatile(false),
lastErr: core.NewVolatile[int64](0),
}

// This goroutine periodically checks the health of the server, updating
// its latency and marking it as healthy or unhealthy.
go func() {
timer := time.NewTicker(60 * time.Second) // Periodically check health
defer timer.Stop()

for {
select {
case <-si.ctx.Done():
return
case <-timer.C:
si.checkHealth()
}
}
}()

log.I("dnscrypt: (%s) setup: %s; anonrelay? %t, proxy? %t", name, si.HostName, len(relay) > 0)
return si, nil
}
Expand Down Expand Up @@ -386,11 +435,19 @@ func (s *serverinfo) Query(network string, q *dns.Msg, smm *x.DNSSummary) (r *dn
r, err = resolve(network, q, s, smm)
s.status.Store(smm.Status)

if s.est != nil {
s.est.Add(smm.Latency)
}
// If the query fails, mark the server as unhealthy and record the time.
// This will prevent the server from being selected for new queries until
// it is marked as healthy again by the periodic health check.
if err != nil {
s.unhealthy.Store(true)
s.lastErr.Store(time.Now().Unix())
smm.Msg = err.Error()
} else {
s.unhealthy.Store(false)
}

if s.est != nil {
s.est.Add(smm.Latency)
}

return
Expand Down Expand Up @@ -432,6 +489,9 @@ func (s *serverinfo) IPPorts() []netip.AddrPort {
}

func (s *serverinfo) Status() int {
if s.unhealthy.Load() {
return dnsx.DEnd
}
if px := s.getRelay(); px != nil {
if px.Status() == ipn.TPU {
return dnsx.Paused
Expand All @@ -440,6 +500,24 @@ func (s *serverinfo) Status() int {
return s.status.Load()
}

// checkHealth probes this server once and records whether it answered.
//
// It builds a root (".") NS question and sends it through s.Query over
// dnsx.NetTypeUDP with a fresh, throwaway DNSSummary. On failure the server
// is flagged unhealthy and the current Unix time is saved in lastErr; on
// success the unhealthy flag is cleared. Either outcome is logged.
//
// NOTE(review): Query appears to update unhealthy/lastErr on its own for
// every query, so the stores below may be redundant; confirm before removing.
func (s *serverinfo) checkHealth() {
	probe := new(dns.Msg)
	probe.SetQuestion(".", dns.TypeNS)
	summary := &x.DNSSummary{}

	// The answer itself is irrelevant; only the error decides the verdict.
	if _, qerr := s.Query(dnsx.NetTypeUDP, probe, summary); qerr != nil {
		s.unhealthy.Store(true)
		s.lastErr.Store(time.Now().Unix())
		log.W("dnscrypt: health check for %s failed: %v", s.Name, qerr)
		return
	}

	s.unhealthy.Store(false)
	// Reports the server's P50 estimate, not the latency of this one probe.
	log.I("dnscrypt: health check for %s successful, latency: %dms", s.Name, s.P50())
}

func (s *serverinfo) Stop() error {
if s != nil {
s.status.Store(dnsx.DEnd)
Expand Down