Skip to content

Commit 3d4b502

Browse files
authored
[server] Add health check HTTP endpoint for Relay server (#4297)
The health check endpoint is served by a dedicated HTTP server. By default it is available at 0.0.0.0:9000/health; the listen address can be changed with the --health-listen-address flag. Results are cached for 3 seconds to avoid excessive probing. Each health check performs the following: (1) it checks the number of active listeners, and (2) it validates each listener by dialing it over WebSocket or QUIC, including TLS certificate verification.
1 parent a4e8647 commit 3d4b502

File tree

14 files changed

+354
-18
lines changed

14 files changed

+354
-18
lines changed

relay/cmd/root.go

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"net/http"
1010
"os"
1111
"os/signal"
12+
"sync"
1213
"syscall"
1314
"time"
1415

@@ -17,8 +18,9 @@ import (
1718
"github.com/spf13/cobra"
1819

1920
"github.com/netbirdio/netbird/encryption"
20-
"github.com/netbirdio/netbird/shared/relay/auth"
21+
"github.com/netbirdio/netbird/relay/healthcheck"
2122
"github.com/netbirdio/netbird/relay/server"
23+
"github.com/netbirdio/netbird/shared/relay/auth"
2224
"github.com/netbirdio/netbird/signal/metrics"
2325
"github.com/netbirdio/netbird/util"
2426
)
@@ -34,12 +36,13 @@ type Config struct {
3436
LetsencryptDomains []string
3537
// in case of using Route 53 for DNS challenge the credentials should be provided in the environment variables or
3638
// in the AWS credentials file
37-
LetsencryptAWSRoute53 bool
38-
TlsCertFile string
39-
TlsKeyFile string
40-
AuthSecret string
41-
LogLevel string
42-
LogFile string
39+
LetsencryptAWSRoute53 bool
40+
TlsCertFile string
41+
TlsKeyFile string
42+
AuthSecret string
43+
LogLevel string
44+
LogFile string
45+
HealthcheckListenAddress string
4346
}
4447

4548
func (c Config) Validate() error {
@@ -87,6 +90,7 @@ func init() {
8790
rootCmd.PersistentFlags().StringVarP(&cobraConfig.AuthSecret, "auth-secret", "s", "", "auth secret")
8891
rootCmd.PersistentFlags().StringVar(&cobraConfig.LogLevel, "log-level", "info", "log level")
8992
rootCmd.PersistentFlags().StringVar(&cobraConfig.LogFile, "log-file", "console", "log file")
93+
rootCmd.PersistentFlags().StringVarP(&cobraConfig.HealthcheckListenAddress, "health-listen-address", "H", ":9000", "listen address of healthcheck server")
9094

9195
setFlagsFromEnvVars(rootCmd)
9296
}
@@ -102,6 +106,7 @@ func waitForExitSignal() {
102106
}
103107

104108
func execute(cmd *cobra.Command, args []string) error {
109+
wg := sync.WaitGroup{}
105110
err := cobraConfig.Validate()
106111
if err != nil {
107112
log.Debugf("invalid config: %s", err)
@@ -120,7 +125,9 @@ func execute(cmd *cobra.Command, args []string) error {
120125
return fmt.Errorf("setup metrics: %v", err)
121126
}
122127

128+
wg.Add(1)
123129
go func() {
130+
defer wg.Done()
124131
log.Infof("running metrics server: %s%s", metricsServer.Addr, metricsServer.Endpoint)
125132
if err := metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
126133
log.Fatalf("Failed to start metrics server: %v", err)
@@ -154,19 +161,42 @@ func execute(cmd *cobra.Command, args []string) error {
154161
return fmt.Errorf("failed to create relay server: %v", err)
155162
}
156163
log.Infof("server will be available on: %s", srv.InstanceURL())
164+
wg.Add(1)
157165
go func() {
166+
defer wg.Done()
158167
if err := srv.Listen(srvListenerCfg); err != nil {
159168
log.Fatalf("failed to bind server: %s", err)
160169
}
161170
}()
162171

172+
hCfg := healthcheck.Config{
173+
ListenAddress: cobraConfig.HealthcheckListenAddress,
174+
ServiceChecker: srv,
175+
}
176+
httpHealthcheck, err := healthcheck.NewServer(hCfg)
177+
if err != nil {
178+
log.Debugf("failed to create healthcheck server: %v", err)
179+
return fmt.Errorf("failed to create healthcheck server: %v", err)
180+
}
181+
wg.Add(1)
182+
go func() {
183+
defer wg.Done()
184+
if err := httpHealthcheck.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
185+
log.Fatalf("Failed to start healthcheck server: %v", err)
186+
}
187+
}()
188+
163189
// it will block until exit signal
164190
waitForExitSignal()
165191

166192
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
167193
defer cancel()
168194

169195
var shutDownErrors error
196+
if err := httpHealthcheck.Shutdown(ctx); err != nil {
197+
shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close healthcheck server: %v", err))
198+
}
199+
170200
if err := srv.Shutdown(ctx); err != nil {
171201
shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close server: %s", err))
172202
}
@@ -175,6 +205,8 @@ func execute(cmd *cobra.Command, args []string) error {
175205
if err := metricsServer.Shutdown(ctx); err != nil {
176206
shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close metrics server: %v", err))
177207
}
208+
209+
wg.Wait()
178210
return shutDownErrors
179211
}
180212

relay/healthcheck/healthcheck.go

Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
package healthcheck
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"errors"
7+
"net"
8+
"net/http"
9+
"sync"
10+
"time"
11+
12+
log "github.com/sirupsen/logrus"
13+
14+
"github.com/netbirdio/netbird/relay/protocol"
15+
"github.com/netbirdio/netbird/relay/server/listener/quic"
16+
"github.com/netbirdio/netbird/relay/server/listener/ws"
17+
)
18+
19+
// Values reported in HealthStatus.Status.
const (
	statusHealthy   = "healthy"
	statusUnhealthy = "unhealthy"
)

// path is the URL path the healthcheck handler is registered on.
const path = "/health"

// cacheTTL is how long a computed health status is served from the cache
// before the listeners are probed again.
const cacheTTL = 3 * time.Second
27+
28+
type ServiceChecker interface {
29+
ListenerProtocols() []protocol.Protocol
30+
ListenAddress() string
31+
}
32+
33+
type HealthStatus struct {
34+
Status string `json:"status"`
35+
Timestamp time.Time `json:"timestamp"`
36+
Listeners []protocol.Protocol `json:"listeners"`
37+
CertificateValid bool `json:"certificate_valid"`
38+
}
39+
40+
type Config struct {
41+
ListenAddress string
42+
ServiceChecker ServiceChecker
43+
}
44+
45+
type Server struct {
46+
config Config
47+
httpServer *http.Server
48+
49+
cacheMu sync.Mutex
50+
cacheStatus *HealthStatus
51+
}
52+
53+
func NewServer(config Config) (*Server, error) {
54+
mux := http.NewServeMux()
55+
56+
if config.ServiceChecker == nil {
57+
return nil, errors.New("service checker is required")
58+
}
59+
60+
server := &Server{
61+
config: config,
62+
httpServer: &http.Server{
63+
Addr: config.ListenAddress,
64+
Handler: mux,
65+
ReadTimeout: 5 * time.Second,
66+
WriteTimeout: 10 * time.Second,
67+
IdleTimeout: 15 * time.Second,
68+
},
69+
}
70+
71+
mux.HandleFunc(path, server.handleHealthcheck)
72+
return server, nil
73+
}
74+
75+
func (s *Server) ListenAndServe() error {
76+
log.Infof("starting healthcheck server on: http://%s%s", dialAddress(s.config.ListenAddress), path)
77+
return s.httpServer.ListenAndServe()
78+
}
79+
80+
// Shutdown gracefully shuts down the healthcheck server
81+
func (s *Server) Shutdown(ctx context.Context) error {
82+
log.Info("Shutting down healthcheck server")
83+
return s.httpServer.Shutdown(ctx)
84+
}
85+
86+
func (s *Server) handleHealthcheck(w http.ResponseWriter, _ *http.Request) {
87+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
88+
defer cancel()
89+
90+
var (
91+
status *HealthStatus
92+
ok bool
93+
)
94+
// Cache check
95+
s.cacheMu.Lock()
96+
status = s.cacheStatus
97+
s.cacheMu.Unlock()
98+
99+
if status != nil && time.Since(status.Timestamp) <= cacheTTL {
100+
ok = status.Status == statusHealthy
101+
} else {
102+
status, ok = s.getHealthStatus(ctx)
103+
// Update cache
104+
s.cacheMu.Lock()
105+
s.cacheStatus = status
106+
s.cacheMu.Unlock()
107+
}
108+
109+
w.Header().Set("Content-Type", "application/json")
110+
111+
if ok {
112+
w.WriteHeader(http.StatusOK)
113+
} else {
114+
w.WriteHeader(http.StatusServiceUnavailable)
115+
}
116+
117+
encoder := json.NewEncoder(w)
118+
if err := encoder.Encode(status); err != nil {
119+
log.Errorf("Failed to encode healthcheck response: %v", err)
120+
}
121+
}
122+
123+
func (s *Server) getHealthStatus(ctx context.Context) (*HealthStatus, bool) {
124+
healthy := true
125+
status := &HealthStatus{
126+
Timestamp: time.Now(),
127+
Status: statusHealthy,
128+
CertificateValid: true,
129+
}
130+
131+
listeners, ok := s.validateListeners()
132+
if !ok {
133+
status.Status = statusUnhealthy
134+
healthy = false
135+
}
136+
status.Listeners = listeners
137+
138+
if ok := s.validateCertificate(ctx); !ok {
139+
status.Status = statusUnhealthy
140+
status.CertificateValid = false
141+
healthy = false
142+
}
143+
144+
return status, healthy
145+
}
146+
147+
func (s *Server) validateListeners() ([]protocol.Protocol, bool) {
148+
listeners := s.config.ServiceChecker.ListenerProtocols()
149+
if len(listeners) == 0 {
150+
return nil, false
151+
}
152+
return listeners, true
153+
}
154+
155+
func (s *Server) validateCertificate(ctx context.Context) bool {
156+
listenAddress := s.config.ServiceChecker.ListenAddress()
157+
if listenAddress == "" {
158+
log.Warn("listen address is empty")
159+
return false
160+
}
161+
162+
dAddr := dialAddress(listenAddress)
163+
164+
for _, proto := range s.config.ServiceChecker.ListenerProtocols() {
165+
switch proto {
166+
case ws.Proto:
167+
if err := dialWS(ctx, dAddr); err != nil {
168+
log.Errorf("failed to dial WebSocket listener: %v", err)
169+
return false
170+
}
171+
case quic.Proto:
172+
if err := dialQUIC(ctx, dAddr); err != nil {
173+
log.Errorf("failed to dial QUIC listener: %v", err)
174+
return false
175+
}
176+
default:
177+
log.Warnf("unknown protocol for healthcheck: %s", proto)
178+
return false
179+
}
180+
}
181+
return true
182+
}
183+
184+
// dialAddress converts a listen address into an address to dial. When the
// host part is empty, "::" or "0.0.0.0" it is replaced with "0.0.0.0"; any
// other host is kept. An address that cannot be split into host and port is
// returned unchanged and may not be dialable.
func dialAddress(listenAddress string) string {
	host, port, err := net.SplitHostPort(listenAddress)
	if err != nil {
		return listenAddress
	}

	switch host {
	case "", "::", "0.0.0.0":
		host = "0.0.0.0"
	}

	return net.JoinHostPort(host, port)
}

relay/healthcheck/quic.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package healthcheck
2+
3+
import (
4+
"context"
5+
"crypto/tls"
6+
"fmt"
7+
"time"
8+
9+
"github.com/quic-go/quic-go"
10+
11+
tlsnb "github.com/netbirdio/netbird/shared/relay/tls"
12+
)
13+
14+
func dialQUIC(ctx context.Context, address string) error {
15+
tlsConfig := &tls.Config{
16+
InsecureSkipVerify: false, // Keep certificate validation enabled
17+
NextProtos: []string{tlsnb.NBalpn},
18+
}
19+
20+
conn, err := quic.DialAddr(ctx, address, tlsConfig, &quic.Config{
21+
MaxIdleTimeout: 30 * time.Second,
22+
KeepAlivePeriod: 10 * time.Second,
23+
EnableDatagrams: true,
24+
})
25+
if err != nil {
26+
return fmt.Errorf("failed to connect to QUIC server: %w", err)
27+
}
28+
29+
_ = conn.CloseWithError(0, "availability check complete")
30+
return nil
31+
}

relay/healthcheck/ws.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package healthcheck
2+
3+
import (
4+
"context"
5+
"fmt"
6+
7+
"github.com/coder/websocket"
8+
9+
"github.com/netbirdio/netbird/shared/relay"
10+
)
11+
12+
func dialWS(ctx context.Context, address string) error {
13+
url := fmt.Sprintf("wss://%s%s", address, relay.WebSocketURLPath)
14+
15+
conn, resp, err := websocket.Dial(ctx, url, nil)
16+
if resp != nil {
17+
defer func() {
18+
_ = resp.Body.Close()
19+
}()
20+
21+
}
22+
if err != nil {
23+
return fmt.Errorf("failed to connect to websocket: %w", err)
24+
}
25+
26+
_ = conn.Close(websocket.StatusNormalClosure, "availability check complete")
27+
return nil
28+
}

relay/protocol/protocol.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
package protocol
2+
3+
// Protocol names the transport protocol a relay listener serves.
type Protocol string

relay/server/listener/listener.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,12 @@ package listener
33
import (
44
"context"
55
"net"
6+
7+
"github.com/netbirdio/netbird/relay/protocol"
68
)
79

810
type Listener interface {
911
Listen(func(conn net.Conn)) error
1012
Shutdown(ctx context.Context) error
13+
Protocol() protocol.Protocol
1114
}

0 commit comments

Comments
 (0)