Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
353 changes: 236 additions & 117 deletions README.md

Large diffs are not rendered by default.

10 changes: 6 additions & 4 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ const (

coapWithoutDTLS = "MPROXY_COAP_WITHOUT_DTLS_"
coapWithDTLS = "MPROXY_COAP_WITH_DTLS_"

defaultTargetHost = "localhost"
)

func main() {
Expand Down Expand Up @@ -136,7 +138,7 @@ func startMQTTProxy(g *errgroup.Group, ctx context.Context, envPrefix string, ha
}

if cfg.TargetHost == "" {
cfg.TargetHost = "localhost"
cfg.TargetHost = defaultTargetHost
}

if cfg.TargetPort == "" {
Expand Down Expand Up @@ -187,7 +189,7 @@ func startWebSocketProxy(g *errgroup.Group, ctx context.Context, envPrefix strin
}

if cfg.TargetHost == "" {
cfg.TargetHost = "localhost"
cfg.TargetHost = defaultTargetHost
}

if cfg.TargetPort == "" {
Expand Down Expand Up @@ -245,7 +247,7 @@ func startHTTPProxy(g *errgroup.Group, ctx context.Context, envPrefix string, ha
}

if cfg.TargetHost == "" {
cfg.TargetHost = "localhost"
cfg.TargetHost = defaultTargetHost
}

if cfg.TargetPort == "" {
Expand Down Expand Up @@ -300,7 +302,7 @@ func startCoAPProxy(g *errgroup.Group, ctx context.Context, envPrefix string, ha
}

if cfg.TargetHost == "" {
cfg.TargetHost = "localhost"
cfg.TargetHost = defaultTargetHost
}

if cfg.TargetPort == "" {
Expand Down
18 changes: 5 additions & 13 deletions cmd/production/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import (
"github.com/absmach/mproxy/pkg/ratelimit"
)

const protocolMQTT = "mqtt"

// RateLimitedHandler wraps a handler with rate limiting.
type RateLimitedHandler struct {
handler handler.Handler
Expand Down Expand Up @@ -97,13 +99,10 @@ type InstrumentedHandler struct {
func (h *InstrumentedHandler) AuthConnect(ctx context.Context, hctx *handler.Context) error {
start := time.Now()
h.metrics.AuthAttempts.WithLabelValues(hctx.Protocol, "connect").Inc()

err := h.handler.AuthConnect(ctx, hctx)

if err != nil {
h.metrics.AuthFailures.WithLabelValues(hctx.Protocol, "connect", "unauthorized").Inc()
}

duration := time.Since(start).Seconds()
h.metrics.RequestDuration.WithLabelValues(hctx.Protocol, "connect").Observe(duration)

Expand All @@ -114,17 +113,13 @@ func (h *InstrumentedHandler) AuthConnect(ctx context.Context, hctx *handler.Con
func (h *InstrumentedHandler) AuthPublish(ctx context.Context, hctx *handler.Context, topic *string, payload *[]byte) error {
start := time.Now()
h.metrics.AuthAttempts.WithLabelValues(hctx.Protocol, "publish").Inc()

if payload != nil {
h.metrics.RequestSize.WithLabelValues(hctx.Protocol).Observe(float64(len(*payload)))
}

err := h.handler.AuthPublish(ctx, hctx, topic, payload)

if err != nil {
h.metrics.AuthFailures.WithLabelValues(hctx.Protocol, "publish", "unauthorized").Inc()
}

duration := time.Since(start).Seconds()
h.metrics.RequestDuration.WithLabelValues(hctx.Protocol, "publish").Observe(duration)

Expand All @@ -141,13 +136,10 @@ func (h *InstrumentedHandler) AuthPublish(ctx context.Context, hctx *handler.Con
func (h *InstrumentedHandler) AuthSubscribe(ctx context.Context, hctx *handler.Context, topics *[]string) error {
start := time.Now()
h.metrics.AuthAttempts.WithLabelValues(hctx.Protocol, "subscribe").Inc()

err := h.handler.AuthSubscribe(ctx, hctx, topics)

if err != nil {
h.metrics.AuthFailures.WithLabelValues(hctx.Protocol, "subscribe", "unauthorized").Inc()
}

duration := time.Since(start).Seconds()
h.metrics.RequestDuration.WithLabelValues(hctx.Protocol, "subscribe").Observe(duration)

Expand All @@ -170,7 +162,7 @@ func (h *InstrumentedHandler) OnConnect(ctx context.Context, hctx *handler.Conte

// OnPublish implements handler.Handler with metrics.
func (h *InstrumentedHandler) OnPublish(ctx context.Context, hctx *handler.Context, topic string, payload []byte) error {
if hctx.Protocol == "mqtt" {
if hctx.Protocol == protocolMQTT {
h.metrics.MQTTPackets.WithLabelValues("publish", "upstream").Inc()
}

Expand All @@ -179,7 +171,7 @@ func (h *InstrumentedHandler) OnPublish(ctx context.Context, hctx *handler.Conte

// OnSubscribe implements handler.Handler with metrics.
func (h *InstrumentedHandler) OnSubscribe(ctx context.Context, hctx *handler.Context, topics []string) error {
if hctx.Protocol == "mqtt" {
if hctx.Protocol == protocolMQTT {
h.metrics.MQTTPackets.WithLabelValues("subscribe", "upstream").Inc()
}

Expand All @@ -188,7 +180,7 @@ func (h *InstrumentedHandler) OnSubscribe(ctx context.Context, hctx *handler.Con

// OnUnsubscribe implements handler.Handler with metrics.
func (h *InstrumentedHandler) OnUnsubscribe(ctx context.Context, hctx *handler.Context, topics []string) error {
if hctx.Protocol == "mqtt" {
if hctx.Protocol == protocolMQTT {
h.metrics.MQTTPackets.WithLabelValues("unsubscribe", "upstream").Inc()
}

Expand Down
47 changes: 28 additions & 19 deletions cmd/production/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,27 +30,29 @@ import (
"golang.org/x/sync/errgroup"
)

const logError = "error"

// Config holds the application configuration.
type Config struct {
// Observability
MetricsPort int `env:"METRICS_PORT" envDefault:"9090"`
HealthPort int `env:"HEALTH_PORT" envDefault:"8080"`
LogLevel string `env:"LOG_LEVEL" envDefault:"info"`
LogFormat string `env:"LOG_FORMAT" envDefault:"json"`
MetricsPort int `env:"METRICS_PORT" envDefault:"9090"`
HealthPort int `env:"HEALTH_PORT" envDefault:"8080"`
LogLevel string `env:"LOG_LEVEL" envDefault:"info"`
LogFormat string `env:"LOG_FORMAT" envDefault:"json"`

// Resource Limits
MaxConnections int `env:"MAX_CONNECTIONS" envDefault:"10000"`
MaxGoroutines int `env:"MAX_GOROUTINES" envDefault:"50000"`
MaxConnections int `env:"MAX_CONNECTIONS" envDefault:"10000"`
MaxGoroutines int `env:"MAX_GOROUTINES" envDefault:"50000"`

// Connection Pooling
PoolMaxIdle int `env:"POOL_MAX_IDLE" envDefault:"100"`
PoolMaxActive int `env:"POOL_MAX_ACTIVE" envDefault:"1000"`
PoolIdleTimeout time.Duration `env:"POOL_IDLE_TIMEOUT" envDefault:"5m"`

// Circuit Breaker
BreakerMaxFailures int `env:"BREAKER_MAX_FAILURES" envDefault:"5"`
BreakerResetTimeout time.Duration `env:"BREAKER_RESET_TIMEOUT" envDefault:"60s"`
BreakerTimeout time.Duration `env:"BREAKER_TIMEOUT" envDefault:"30s"`
BreakerMaxFailures int `env:"BREAKER_MAX_FAILURES" envDefault:"5"`
BreakerResetTimeout time.Duration `env:"BREAKER_RESET_TIMEOUT" envDefault:"60s"`
BreakerTimeout time.Duration `env:"BREAKER_TIMEOUT" envDefault:"30s"`

// Rate Limiting
RateLimitCapacity int64 `env:"RATE_LIMIT_CAPACITY" envDefault:"100"`
Expand All @@ -70,14 +72,20 @@ type Config struct {
}

func main() {
if err := run(); err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
}

func run() error {
// Load configuration
cfg := Config{}
if err := godotenv.Load(); err != nil {
// .env file is optional
}
if err := env.Parse(&cfg); err != nil {
fmt.Fprintf(os.Stderr, "Failed to parse config: %v\n", err)
os.Exit(1)
return fmt.Errorf("failed to parse config: %w", err)
}

// Setup logger
Expand Down Expand Up @@ -216,7 +224,7 @@ func main() {

mqttProxy, err := proxy.NewMQTT(mqttProxyConfig, instrumentedHandler)
if err != nil {
logger.Error("Failed to create MQTT proxy", slog.String("error", err.Error()))
logger.Error("Failed to create MQTT proxy", slog.String(logError, err.Error()))
} else {
g.Go(func() error {
address := net.JoinHostPort(mqttProxyConfig.Host, mqttProxyConfig.Port)
Expand Down Expand Up @@ -246,21 +254,22 @@ func main() {
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), cfg.ShutdownTimeout)
defer shutdownCancel()

done := make(chan error)
done := make(chan error, 1)
go func() {
done <- g.Wait()
}()

select {
case err := <-done:
if err != nil {
logger.Error("Shutdown error", slog.String("error", err.Error()))
os.Exit(1)
logger.Error("Shutdown error", slog.String(logError, err.Error()))
return err
}
logger.Info("Graceful shutdown completed")
return nil
case <-shutdownCtx.Done():
logger.Warn("Shutdown timeout exceeded, forcing exit")
os.Exit(1)
return shutdownCtx.Err()
}
}

Expand All @@ -274,7 +283,7 @@ func setupLogger(level, format string) *slog.Logger {
logLevel = slog.LevelInfo
case "warn":
logLevel = slog.LevelWarn
case "error":
case logError:
logLevel = slog.LevelError
default:
logLevel = slog.LevelInfo
Expand Down Expand Up @@ -311,7 +320,7 @@ func startMetricsServer(port int, logger *slog.Logger) {
}

if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
logger.Error("Metrics server error", slog.String("error", err.Error()))
logger.Error("Metrics server error", slog.String(logError, err.Error()))
}
}

Expand All @@ -334,6 +343,6 @@ func startHealthServer(port int, checker *health.Checker, logger *slog.Logger) {
}

if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
logger.Error("Health server error", slog.String("error", err.Error()))
logger.Error("Health server error", slog.String(logError, err.Error()))
}
}
22 changes: 10 additions & 12 deletions pkg/breaker/breaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,8 @@ import (
"time"
)

var (
// ErrCircuitOpen is returned when the circuit breaker is open.
ErrCircuitOpen = errors.New("circuit breaker is open")
)
// ErrCircuitOpen is returned when the circuit breaker is open.
var ErrCircuitOpen = errors.New("circuit breaker is open")

// State represents the circuit breaker state.
type State int
Expand Down Expand Up @@ -51,14 +49,14 @@ type Config struct {

// CircuitBreaker implements the circuit breaker pattern.
type CircuitBreaker struct {
mu sync.RWMutex
config Config
state State
failures int
successes int
lastFailureTime time.Time
lastStateChange time.Time
onStateChange func(from, to State)
mu sync.RWMutex
config Config
state State
failures int
successes int
lastFailureTime time.Time
lastStateChange time.Time
onStateChange func(from, to State)
}

// New creates a new circuit breaker.
Expand Down
18 changes: 9 additions & 9 deletions pkg/errors/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"fmt"
)

// Common error types
// Common error types.
var (
// ErrUnauthorized indicates authentication or authorization failure.
ErrUnauthorized = errors.New("unauthorized")
Expand Down Expand Up @@ -41,11 +41,11 @@ var (

// ProxyError wraps an error with additional context.
type ProxyError struct {
Op string // Operation that failed
Protocol string // Protocol (mqtt, http, coap, websocket)
SessionID string // Session identifier
Op string // Operation that failed
Protocol string // Protocol (mqtt, http, coap, websocket)
SessionID string // Session identifier
RemoteAddr string // Client address
Err error // Underlying error
Err error // Underlying error
}

// Error implements the error interface.
Expand All @@ -67,11 +67,11 @@ func New(op, protocol, sessionID, remoteAddr string, err error) error {
return nil
}
return &ProxyError{
Op: op,
Protocol: protocol,
SessionID: sessionID,
Op: op,
Protocol: protocol,
SessionID: sessionID,
RemoteAddr: remoteAddr,
Err: err,
Err: err,
}
}

Expand Down
18 changes: 11 additions & 7 deletions pkg/health/health.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,13 @@ func (c *Checker) HTTPHandler() http.HandlerFunc {
w.Header().Set("Content-Type", "application/json")
if status == StatusUnhealthy {
w.WriteHeader(http.StatusServiceUnavailable)
} else if status == StatusDegraded {
w.WriteHeader(http.StatusOK) // Still accept traffic
} else {
w.WriteHeader(http.StatusOK)
w.WriteHeader(http.StatusOK) // Still accept traffic
}

json.NewEncoder(w).Encode(response)
if err := json.NewEncoder(w).Encode(response); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}
}

Expand All @@ -135,9 +135,11 @@ func LivenessHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{
if err := json.NewEncoder(w).Encode(map[string]string{
"status": "alive",
})
}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}
}

Expand All @@ -161,6 +163,8 @@ func (c *Checker) ReadinessHandler() http.HandlerFunc {
w.WriteHeader(http.StatusOK)
}

json.NewEncoder(w).Encode(response)
if err := json.NewEncoder(w).Encode(response); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}
}
Loading
Loading