Skip to content

Commit 41cda15

Browse files
author
Dušan Borovčanin
authored
MPX-67 - Refactor mProxy (#114)
* Refactor code Signed-off-by: dusan <borovcanindusan1@gmail.com> * Add tests Signed-off-by: dusan <borovcanindusan1@gmail.com> * Fix tests Signed-off-by: dusan <borovcanindusan1@gmail.com> * Add conn limits Signed-off-by: dusan <borovcanindusan1@gmail.com> * Add pools and improve code quality Signed-off-by: dusan <borovcanindusan1@gmail.com> * Improve default values Signed-off-by: dusan <borovcanindusan1@gmail.com> * Update dependencies Signed-off-by: dusan <borovcanindusan1@gmail.com> --------- Signed-off-by: dusan <borovcanindusan1@gmail.com>
1 parent c63026c commit 41cda15

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

77 files changed

+7961
-2327
lines changed

cmd/main.go

Lines changed: 239 additions & 119 deletions
Large diffs are not rendered by default.

cmd/production/handlers.go

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
// Copyright (c) Abstract Machines
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package main
5+
6+
import (
7+
"context"
8+
"log/slog"
9+
"time"
10+
11+
"github.com/absmach/mproxy/pkg/handler"
12+
"github.com/absmach/mproxy/pkg/metrics"
13+
"github.com/absmach/mproxy/pkg/ratelimit"
14+
)
15+
16+
// RateLimitedHandler wraps a handler with rate limiting.
17+
type RateLimitedHandler struct {
18+
handler handler.Handler
19+
perClientLimiter *ratelimit.Limiter
20+
globalLimiter *ratelimit.TokenBucket
21+
metrics *metrics.Metrics
22+
logger *slog.Logger
23+
}
24+
25+
// AuthConnect implements handler.Handler with rate limiting.
26+
func (h *RateLimitedHandler) AuthConnect(ctx context.Context, hctx *handler.Context) error {
27+
// Check global rate limit
28+
if !h.globalLimiter.Allow() {
29+
h.metrics.RateLimitedRequests.WithLabelValues(hctx.Protocol, "global").Inc()
30+
h.logger.Warn("Global rate limit exceeded",
31+
slog.String("remote", hctx.RemoteAddr),
32+
slog.String("protocol", hctx.Protocol))
33+
return ratelimit.ErrRateLimitExceeded
34+
}
35+
36+
// Check per-client rate limit
37+
clientID := hctx.RemoteAddr
38+
if hctx.ClientID != "" {
39+
clientID = hctx.ClientID
40+
}
41+
42+
if !h.perClientLimiter.Allow(clientID) {
43+
h.metrics.RateLimitedRequests.WithLabelValues(hctx.Protocol, "per_client").Inc()
44+
h.logger.Warn("Per-client rate limit exceeded",
45+
slog.String("client", clientID),
46+
slog.String("protocol", hctx.Protocol))
47+
return ratelimit.ErrRateLimitExceeded
48+
}
49+
50+
return h.handler.AuthConnect(ctx, hctx)
51+
}
52+
53+
// AuthPublish implements handler.Handler with rate limiting.
54+
func (h *RateLimitedHandler) AuthPublish(ctx context.Context, hctx *handler.Context, topic *string, payload *[]byte) error {
55+
// Could add payload size rate limiting here
56+
return h.handler.AuthPublish(ctx, hctx, topic, payload)
57+
}
58+
59+
// AuthSubscribe implements handler.Handler.
60+
func (h *RateLimitedHandler) AuthSubscribe(ctx context.Context, hctx *handler.Context, topics *[]string) error {
61+
return h.handler.AuthSubscribe(ctx, hctx, topics)
62+
}
63+
64+
// OnConnect implements handler.Handler.
65+
func (h *RateLimitedHandler) OnConnect(ctx context.Context, hctx *handler.Context) error {
66+
return h.handler.OnConnect(ctx, hctx)
67+
}
68+
69+
// OnPublish implements handler.Handler.
70+
func (h *RateLimitedHandler) OnPublish(ctx context.Context, hctx *handler.Context, topic string, payload []byte) error {
71+
return h.handler.OnPublish(ctx, hctx, topic, payload)
72+
}
73+
74+
// OnSubscribe implements handler.Handler.
75+
func (h *RateLimitedHandler) OnSubscribe(ctx context.Context, hctx *handler.Context, topics []string) error {
76+
return h.handler.OnSubscribe(ctx, hctx, topics)
77+
}
78+
79+
// OnUnsubscribe implements handler.Handler.
80+
func (h *RateLimitedHandler) OnUnsubscribe(ctx context.Context, hctx *handler.Context, topics []string) error {
81+
return h.handler.OnUnsubscribe(ctx, hctx, topics)
82+
}
83+
84+
// OnDisconnect implements handler.Handler.
85+
func (h *RateLimitedHandler) OnDisconnect(ctx context.Context, hctx *handler.Context) error {
86+
return h.handler.OnDisconnect(ctx, hctx)
87+
}
88+
89+
// InstrumentedHandler wraps a handler with metrics instrumentation.
90+
type InstrumentedHandler struct {
91+
handler handler.Handler
92+
metrics *metrics.Metrics
93+
logger *slog.Logger
94+
}
95+
96+
// AuthConnect implements handler.Handler with metrics.
97+
func (h *InstrumentedHandler) AuthConnect(ctx context.Context, hctx *handler.Context) error {
98+
start := time.Now()
99+
h.metrics.AuthAttempts.WithLabelValues(hctx.Protocol, "connect").Inc()
100+
101+
err := h.handler.AuthConnect(ctx, hctx)
102+
103+
if err != nil {
104+
h.metrics.AuthFailures.WithLabelValues(hctx.Protocol, "connect", "unauthorized").Inc()
105+
}
106+
107+
duration := time.Since(start).Seconds()
108+
h.metrics.RequestDuration.WithLabelValues(hctx.Protocol, "connect").Observe(duration)
109+
110+
return err
111+
}
112+
113+
// AuthPublish implements handler.Handler with metrics.
114+
func (h *InstrumentedHandler) AuthPublish(ctx context.Context, hctx *handler.Context, topic *string, payload *[]byte) error {
115+
start := time.Now()
116+
h.metrics.AuthAttempts.WithLabelValues(hctx.Protocol, "publish").Inc()
117+
118+
if payload != nil {
119+
h.metrics.RequestSize.WithLabelValues(hctx.Protocol).Observe(float64(len(*payload)))
120+
}
121+
122+
err := h.handler.AuthPublish(ctx, hctx, topic, payload)
123+
124+
if err != nil {
125+
h.metrics.AuthFailures.WithLabelValues(hctx.Protocol, "publish", "unauthorized").Inc()
126+
}
127+
128+
duration := time.Since(start).Seconds()
129+
h.metrics.RequestDuration.WithLabelValues(hctx.Protocol, "publish").Observe(duration)
130+
131+
status := "success"
132+
if err != nil {
133+
status = "error"
134+
}
135+
h.metrics.RequestsTotal.WithLabelValues(hctx.Protocol, "publish", status).Inc()
136+
137+
return err
138+
}
139+
140+
// AuthSubscribe implements handler.Handler with metrics.
141+
func (h *InstrumentedHandler) AuthSubscribe(ctx context.Context, hctx *handler.Context, topics *[]string) error {
142+
start := time.Now()
143+
h.metrics.AuthAttempts.WithLabelValues(hctx.Protocol, "subscribe").Inc()
144+
145+
err := h.handler.AuthSubscribe(ctx, hctx, topics)
146+
147+
if err != nil {
148+
h.metrics.AuthFailures.WithLabelValues(hctx.Protocol, "subscribe", "unauthorized").Inc()
149+
}
150+
151+
duration := time.Since(start).Seconds()
152+
h.metrics.RequestDuration.WithLabelValues(hctx.Protocol, "subscribe").Observe(duration)
153+
154+
status := "success"
155+
if err != nil {
156+
status = "error"
157+
}
158+
h.metrics.RequestsTotal.WithLabelValues(hctx.Protocol, "subscribe", status).Inc()
159+
160+
return err
161+
}
162+
163+
// OnConnect implements handler.Handler with metrics.
164+
func (h *InstrumentedHandler) OnConnect(ctx context.Context, hctx *handler.Context) error {
165+
h.metrics.ActiveConnections.WithLabelValues(hctx.Protocol, "client").Inc()
166+
h.metrics.TotalConnections.WithLabelValues(hctx.Protocol, "client", "accepted").Inc()
167+
168+
return h.handler.OnConnect(ctx, hctx)
169+
}
170+
171+
// OnPublish implements handler.Handler with metrics.
172+
func (h *InstrumentedHandler) OnPublish(ctx context.Context, hctx *handler.Context, topic string, payload []byte) error {
173+
if hctx.Protocol == "mqtt" {
174+
h.metrics.MQTTPackets.WithLabelValues("publish", "upstream").Inc()
175+
}
176+
177+
return h.handler.OnPublish(ctx, hctx, topic, payload)
178+
}
179+
180+
// OnSubscribe implements handler.Handler with metrics.
181+
func (h *InstrumentedHandler) OnSubscribe(ctx context.Context, hctx *handler.Context, topics []string) error {
182+
if hctx.Protocol == "mqtt" {
183+
h.metrics.MQTTPackets.WithLabelValues("subscribe", "upstream").Inc()
184+
}
185+
186+
return h.handler.OnSubscribe(ctx, hctx, topics)
187+
}
188+
189+
// OnUnsubscribe implements handler.Handler with metrics.
190+
func (h *InstrumentedHandler) OnUnsubscribe(ctx context.Context, hctx *handler.Context, topics []string) error {
191+
if hctx.Protocol == "mqtt" {
192+
h.metrics.MQTTPackets.WithLabelValues("unsubscribe", "upstream").Inc()
193+
}
194+
195+
return h.handler.OnUnsubscribe(ctx, hctx, topics)
196+
}
197+
198+
// OnDisconnect implements handler.Handler with metrics.
199+
func (h *InstrumentedHandler) OnDisconnect(ctx context.Context, hctx *handler.Context) error {
200+
h.metrics.ActiveConnections.WithLabelValues(hctx.Protocol, "client").Dec()
201+
202+
return h.handler.OnDisconnect(ctx, hctx)
203+
}

0 commit comments

Comments
 (0)