diff --git a/README.md b/README.md
index 85f603f..02aa569 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,18 @@
+
+
# mProxy
-![Go Report Card][grc]
-[![License][LIC-BADGE]][LIC]
+### Lightweight multi‑protocol IoT proxy
+
+### Pluggable Auth • Observability • Packet Manipulation
+
+[](https://goreportcard.com/report/github.com/absmach/mproxy)
+[](https://github.com/absmach/mproxy/releases)
+[](LICENSE)
+
+Made with ❤️ by [Abstract Machines](https://www.absmach.eu)
+
+
mProxy is a lightweight, scalable, and customizable IoT API gateway designed to support seamless communication across multiple protocols. It enables real-time packet manipulation, features pluggable authentication mechanisms, and offers observability for monitoring and troubleshooting. Built for flexibility, mProxy can be deployed as a sidecar or standalone service and can also function as a library for easy integration into applications.
@@ -62,6 +73,74 @@ Built with Go programming language, it's optimized for low resource usage, makin
Can be deployed as a sidecar to enhance existing microservices or as a standalone service for direct IoT device interaction.
Available as a library for integration into existing applications.
+## Quickstart
+
+- Build and run the sample proxy:
+
+```bash
+make
+./build/mgate
+```
+
+- Alternatively, run directly for development:
+
+```bash
+go run cmd/main.go
+```
+
+## Try It Now
+
+Spin up local dependencies and run protocol proxies:
+
+```bash
+# Start MQTT broker (Mosquitto) with WS support
+examples/server/mosquitto/server.sh
+
+# Start HTTP echo server
+go run examples/server/http-echo/main.go &
+
+# Start OCSP/CRL mock responder
+go run examples/ocsp-crl-responder/main.go &
+
+# Start mProxy example servers
+go run cmd/main.go
+```
+
+Client examples:
+
+- MQTT (no TLS): `examples/client/mqtt/without_tls.sh`
+- MQTT (TLS): `examples/client/mqtt/with_tls.sh`
+- MQTT (mTLS): `examples/client/mqtt/with_mtls.sh`
+- MQTT over WebSocket (no TLS): `go run examples/client/websocket/without_tls/main.go`
+- MQTT over WebSocket (TLS): `go run examples/client/websocket/with_tls/main.go`
+- MQTT over WebSocket (mTLS): `go run examples/client/websocket/with_mtls/main.go`
+- HTTP (no TLS): `examples/client/http/without_tls.sh`
+- HTTP (TLS): `examples/client/http/with_tls.sh`
+- HTTP (mTLS): `examples/client/http/with_mtls.sh`
+- CoAP (no DTLS): `examples/client/coap/without_dtls.sh`
+- CoAP (DTLS): `examples/client/coap/with_dtls.sh`
+
+## Protocol Matrix & Examples
+
+| Protocol | Mode | Port | Path Prefix |
+|:---------|:------------|-------:|:------------|
+| `MQTT` | no TLS | `1884` | — |
+| `MQTT` | TLS | `8883` | — |
+| `MQTT` | mTLS | `8884` | — |
+| `MQTT/WS`| no TLS | `8083` | — |
+| `MQTT/WS`| TLS | `8084` | — |
+| `MQTT/WS`| mTLS | `8085` | /mqtt |
+| `HTTP` | no TLS | `8086` | /messages |
+| `HTTP` | TLS | `8087` | /messages |
+| `HTTP` | mTLS | `8088` | /messages |
+| `CoAP` | no DTLS | `5682` | — |
+| `CoAP` | DTLS | `5684` | — |
+
+Examples:
+
+- Servers: `examples/server/mosquitto`, `examples/server/http-echo`, `examples/ocsp-crl-responder`
+- Clients: `examples/client/mqtt`, `examples/client/websocket`, `examples/client/http`, `examples/client/coap`
+
## Usage
```bash
@@ -90,24 +169,24 @@ mProxy can parse and understand protocol packages, and upon their detection, it
type Handler interface {
// Authorization on client `CONNECT`
// Each of the params are passed by reference, so that it can be changed
- AuthConnect(ctx context.Context) error
+ AuthConnect(ctx context.Context) error
- // Authorization on client `PUBLISH`
- // Topic is passed by reference, so that it can be modified
- AuthPublish(ctx context.Context, topic *string, payload *[]byte) error
+ // Authorization on client `PUBLISH`
+ // Topic is passed by reference, so that it can be modified
+ AuthPublish(ctx context.Context, topic *string, payload *[]byte) error
- // Authorization on client `SUBSCRIBE`
- // Topics are passed by reference, so that they can be modified
- AuthSubscribe(ctx context.Context, topics *[]string) error
+ // Authorization on client `SUBSCRIBE`
+ // Topics are passed by reference, so that they can be modified
+ AuthSubscribe(ctx context.Context, topics *[]string) error
- // After client successfully connected
- Connect(ctx context.Context)
+ // After client successfully connected
+ Connect(ctx context.Context)
- // After client successfully published
- Publish(ctx context.Context, topic *string, payload *[]byte)
+ // After client successfully published
+ Publish(ctx context.Context, topic *string, payload *[]byte)
- // After client successfully subscribed
- Subscribe(ctx context.Context, topics *[]string)
+ // After client successfully subscribed
+ Subscribe(ctx context.Context, topics *[]string)
// After client unsubscribed
Unsubscribe(ctx context.Context, topics *[]string)
@@ -117,7 +196,7 @@ type Handler interface {
}
```
-The Handler interface is inspired by MQTT protocol control packets; if the underlying protocol does not support some of these actions, the implementation can simply omit them. An example of implementation is given [here](examples/simple/simple.go), alongside with it's [`main()` function](cmd/main.go).
+The Handler interface is inspired by MQTT protocol control packets; if the underlying protocol does not support some of these actions, the implementation can simply omit them. An example of implementation is given in this [file](examples/simple/simple.go), alongside with it's [`main()` function](cmd/main.go).
## Deployment
@@ -265,110 +344,136 @@ The script can be used alongside the simple go-coap server provided at `example/
examples/client/coap/with_dtls.sh
```
+## Production Mode
+
+For a production-ready run with observability, circuit breakers, rate limiting, and connection pooling, use the example in `cmd/production/main.go`:
+
+```bash
+go run cmd/production/main.go
+```
+
+Endpoints:
+
+- Metrics (Prometheus): `http://localhost:9090/metrics`
+- Health (JSON): `http://localhost:8080/health`
+
+Key environment variables:
+
+- `METRICS_PORT`, `HEALTH_PORT`, `LOG_LEVEL`, `LOG_FORMAT`
+- `MAX_CONNECTIONS`, `MAX_GOROUTINES`
+- `POOL_MAX_IDLE`, `POOL_MAX_ACTIVE`, `POOL_IDLE_TIMEOUT`
+- `BREAKER_MAX_FAILURES`, `BREAKER_RESET_TIMEOUT`, `BREAKER_TIMEOUT`
+- `RATE_LIMIT_CAPACITY`, `RATE_LIMIT_REFILL`, `GLOBAL_RATE_CAPACITY`, `GLOBAL_RATE_REFILL`
+- `READ_TIMEOUT`, `WRITE_TIMEOUT`, `IDLE_TIMEOUT`, `SHUTDOWN_TIMEOUT`
+
## Configuration
The service is configured using the environment variables presented in the following table. Note that any unset variables will be replaced with their default values.
| Variable | Description | Default |
| ------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------ | ---------------------------- |
-| MPROXY_MQTT_WITHOUT_TLS_ADDRESS | MQTT without TLS inbound (IN) connection listening address | :1884 |
-| MPROXY_MQTT_WITHOUT_TLS_TARGET | MQTT without TLS outbound (OUT) connection address | localhost:1883 |
-| MPROXY_MQTT_WITH_TLS_ADDRESS | MQTT with TLS inbound (IN) connection listening address | :8883 |
-| MPROXY_MQTT_WITH_TLS_TARGET | MQTT with TLS outbound (OUT) connection address | localhost:1883 |
-| MPROXY_MQTT_WITH_TLS_CERT_FILE | MQTT with TLS certificate file path | ssl/certs/server.crt |
-| MPROXY_MQTT_WITH_TLS_KEY_FILE | MQTT with TLS key file path | ssl/certs/server.key |
-| MPROXY_MQTT_WITH_TLS_SERVER_CA_FILE | MQTT with TLS server CA file path | ssl/certs/ca.crt |
-| MPROXY_MQTT_WITH_MTLS_ADDRESS | MQTT with mTLS inbound (IN) connection listening address | :8884 |
-| MPROXY_MQTT_WITH_MTLS_TARGET | MQTT with mTLS outbound (OUT) connection address | localhost:1883 |
-| MPROXY_MQTT_WITH_MTLS_CERT_FILE | MQTT with mTLS certificate file path | ssl/certs/server.crt |
-| MPROXY_MQTT_WITH_MTLS_KEY_FILE | MQTT with mTLS key file path | ssl/certs/server.key |
-| MPROXY_MQTT_WITH_MTLS_SERVER_CA_FILE | MQTT with mTLS server CA file path | ssl/certs/ca.crt |
-| MPROXY_MQTT_WITH_MTLS_CLIENT_CA_FILE | MQTT with mTLS client CA file path | ssl/certs/ca.crt |
-| MPROXY_MQTT_WITH_MTLS_CERT_VERIFICATION_METHODS | MQTT with mTLS certificate verification methods, if no value or unset then mProxy server will not do client validation | ocsp |
-| MPROXY_MQTT_WITH_MTLS_OCSP_RESPONDER_URL | MQTT with mTLS OCSP responder URL, it is used if OCSP responder URL is not available in client certificate AIA | |
-| MPROXY_MQTT_WS_WITHOUT_TLS_ADDRESS | MQTT over Websocket without TLS inbound (IN) connection listening address | :8083 |
-| MPROXY_MQTT_WS_WITHOUT_TLS_TARGET | MQTT over Websocket without TLS outbound (OUT) connection address | ws://localhost:8000/ |
-| MPROXY_MQTT_WS_WITH_TLS_ADDRESS | MQTT over Websocket with TLS inbound (IN) connection listening address | :8084 |
-| MPROXY_MQTT_WS_WITH_TLS_TARGET | MQTT over Websocket with TLS outbound (OUT) connection address | ws://localhost:8000/ |
-| MPROXY_MQTT_WS_WITH_TLS_CERT_FILE | MQTT over Websocket with TLS certificate file path | ssl/certs/server.crt |
-| MPROXY_MQTT_WS_WITH_TLS_KEY_FILE | MQTT over Websocket with TLS key file path | ssl/certs/server.key |
-| MPROXY_MQTT_WS_WITH_TLS_SERVER_CA_FILE | MQTT over Websocket with TLS server CA file path | ssl/certs/ca.crt |
-| MPROXY_MQTT_WS_WITH_MTLS_ADDRESS | MQTT over Websocket with mTLS inbound (IN) connection listening address | :8085 |
-| MPROXY_MQTT_WS_WITH_MTLS_PATH_PREFIX | MQTT over Websocket with mTLS inbound (IN) connection path | /mqtt |
-| MPROXY_MQTT_WS_WITH_MTLS_TARGET | MQTT over Websocket with mTLS outbound (OUT) connection address | ws://localhost:8000/ |
-| MPROXY_MQTT_WS_WITH_MTLS_CERT_FILE | MQTT over Websocket with mTLS certificate file path | ssl/certs/server.crt |
-| MPROXY_MQTT_WS_WITH_MTLS_KEY_FILE | MQTT over Websocket with mTLS key file path | ssl/certs/server.key |
-| MPROXY_MQTT_WS_WITH_MTLS_SERVER_CA_FILE | MQTT over Websocket with mTLS server CA file path | ssl/certs/ca.crt |
-| MPROXY_MQTT_WS_WITH_MTLS_CLIENT_CA_FILE | MQTT over Websocket with mTLS client CA file path | ssl/certs/ca.crt |
-| MPROXY_MQTT_WS_WITH_MTLS_CERT_VERIFICATION_METHODS | MQTT over Websocket with mTLS certificate verification methods, if no value or unset then mProxy server will not do client validation | ocsp |
-| MPROXY_MQTT_WS_WITH_MTLS_OCSP_RESPONDER_URL | MQTT over Websocket with mTLS OCSP responder URL, it is used if OCSP responder URL is not available in client certificate AIA | |
-| MPROXY_HTTP_WITHOUT_TLS_ADDRESS | HTTP without TLS inbound (IN) connection listening address | :8086 |
-| MPROXY_HTTP_WITHOUT_TLS_PATH_PREFIX | HTTP without TLS inbound (IN) connection path | /messages |
-| MPROXY_HTTP_WITHOUT_TLS_TARGET | HTTP without TLS outbound (OUT) connection address | |
-| MPROXY_HTTP_WITH_TLS_ADDRESS | HTTP with TLS inbound (IN) connection listening address | :8087 |
-| MPROXY_HTTP_WITH_TLS_PATH_PREFIX | HTTP with TLS inbound (IN) connection path | /messages |
-| MPROXY_HTTP_WITH_TLS_TARGET | HTTP with TLS outbound (OUT) connection address | |
-| MPROXY_HTTP_WITH_TLS_CERT_FILE | HTTP with TLS certificate file path | ssl/certs/server.crt |
-| MPROXY_HTTP_WITH_TLS_KEY_FILE | HTTP with TLS key file path | ssl/certs/server.key |
-| MPROXY_HTTP_WITH_TLS_SERVER_CA_FILE | HTTP with TLS server CA file path | ssl/certs/ca.crt |
-| MPROXY_HTTP_WITH_MTLS_ADDRESS | HTTP with mTLS inbound (IN) connection listening address | :8088 |
-| MPROXY_HTTP_WITH_MTLS_PATH_PREFIX | HTTP with mTLS inbound (IN) connection path | /messages |
-| MPROXY_HTTP_WITH_MTLS_TARGET | HTTP with mTLS outbound (OUT) connection address | |
-| MPROXY_HTTP_WITH_MTLS_CERT_FILE | HTTP with mTLS certificate file path | ssl/certs/server.crt |
-| MPROXY_HTTP_WITH_MTLS_KEY_FILE | HTTP with mTLS key file path | ssl/certs/server.key |
-| MPROXY_HTTP_WITH_MTLS_SERVER_CA_FILE | HTTP with mTLS server CA file path | ssl/certs/ca.crt |
-| MPROXY_HTTP_WITH_MTLS_CLIENT_CA_FILE | HTTP with mTLS client CA file path | ssl/certs/ca.crt |
-| MPROXY_HTTP_WITH_MTLS_CERT_VERIFICATION_METHODS | HTTP with mTLS certificate verification methods, if no value or unset then mProxy server will not do client validation | ocsp |
-| MPROXY_HTTP_WITH_MTLS_OCSP_RESPONDER_URL | HTTP with mTLS OCSP responder URL, it is used if OCSP responder URL is not available in client certificate AIA | |
-| MPROXY_COAP_WITHOUT_DTLS_HOST | CoAP without DTLS inbound (IN) connection listening host | localhost |
-| MPROXY_COAP_WITHOUT_DTLS_PORT | CoAP without DTLS inbound (IN) connection listening port | 5682 |
-| MPROXY_COAP_WITHOUT_DTLS_TARGET_HOST | CoAP without DTLS outbound (OUT) connection listening host | localhost |
-| MPROXY_COAP_WITHOUT_DTLS_TARGET_PORT | CoAP without DTLS outbound (OUT) connection listening port | 5683 |
-| MPROXY_COAP_WITH_DTLS_HOST | CoAP with DTLS inbound (IN) connection listening host | localhost |
-| MPROXY_COAP_WITH_DTLS_PORT | CoAP with DTLS inbound (IN) connection listening port | 5684 |
-| MPROXY_COAP_WITH_DTLS_TARGET_HOST | CoAP with DTLS outbound (OUT) connection listening host | localhost |
-| MPROXY_COAP_WITH_DTLS_TARGET_PORT | CoAP with DTLS outbound (OUT) connection listening port | 5683 |
-| MPROXY_COAP_WITH_DTLS_CERT_FILE | CoAP with DTLS certificate file path | ssl/certs/server.crt |
-| MPROXY_COAP_WITH_DTLS_KEY_FILE | CoAP with DTLS key file path | ssl/certs/server.key |
-| MPROXY_COAP_WITH_DTLS_SERVER_CA_FILE | CoAP with DTLS server CA file path | ssl/certs/ca.crt |
+| MGATE_MQTT_WITHOUT_TLS_ADDRESS | MQTT without TLS inbound (IN) connection listening address | :1884 |
+| MGATE_MQTT_WITHOUT_TLS_TARGET | MQTT without TLS outbound (OUT) connection address | localhost:1883 |
+| MGATE_MQTT_WITH_TLS_ADDRESS | MQTT with TLS inbound (IN) connection listening address | :8883 |
+| MGATE_MQTT_WITH_TLS_TARGET | MQTT with TLS outbound (OUT) connection address | localhost:1883 |
+| MGATE_MQTT_WITH_TLS_CERT_FILE | MQTT with TLS certificate file path | ssl/certs/server.crt |
+| MGATE_MQTT_WITH_TLS_KEY_FILE | MQTT with TLS key file path | ssl/certs/server.key |
+| MGATE_MQTT_WITH_TLS_SERVER_CA_FILE | MQTT with TLS server CA file path | ssl/certs/ca.crt |
+| MGATE_MQTT_WITH_MTLS_ADDRESS | MQTT with mTLS inbound (IN) connection listening address | :8884 |
+| MGATE_MQTT_WITH_MTLS_TARGET | MQTT with mTLS outbound (OUT) connection address | localhost:1883 |
+| MGATE_MQTT_WITH_MTLS_CERT_FILE | MQTT with mTLS certificate file path | ssl/certs/server.crt |
+| MGATE_MQTT_WITH_MTLS_KEY_FILE | MQTT with mTLS key file path | ssl/certs/server.key |
+| MGATE_MQTT_WITH_MTLS_SERVER_CA_FILE | MQTT with mTLS server CA file path | ssl/certs/ca.crt |
+| MGATE_MQTT_WITH_MTLS_CLIENT_CA_FILE | MQTT with mTLS client CA file path | ssl/certs/ca.crt |
+| MGATE_MQTT_WITH_MTLS_CERT_VERIFICATION_METHODS | MQTT with mTLS certificate verification methods, if no value or unset then mGate server will not do client validation | ocsp |
+| MGATE_MQTT_WITH_MTLS_OCSP_RESPONDER_URL | MQTT with mTLS OCSP responder URL, it is used if OCSP responder URL is not available in client certificate AIA | |
+| MGATE_MQTT_WS_WITHOUT_TLS_ADDRESS | MQTT over Websocket without TLS inbound (IN) connection listening address | :8083 |
+| MGATE_MQTT_WS_WITHOUT_TLS_TARGET | MQTT over Websocket without TLS outbound (OUT) connection address | ws://localhost:8000/ |
+| MGATE_MQTT_WS_WITH_TLS_ADDRESS | MQTT over Websocket with TLS inbound (IN) connection listening address | :8084 |
+| MGATE_MQTT_WS_WITH_TLS_TARGET | MQTT over Websocket with TLS outbound (OUT) connection address | ws://localhost:8000/ |
+| MGATE_MQTT_WS_WITH_TLS_CERT_FILE | MQTT over Websocket with TLS certificate file path | ssl/certs/server.crt |
+| MGATE_MQTT_WS_WITH_TLS_KEY_FILE | MQTT over Websocket with TLS key file path | ssl/certs/server.key |
+| MGATE_MQTT_WS_WITH_TLS_SERVER_CA_FILE | MQTT over Websocket with TLS server CA file path | ssl/certs/ca.crt |
+| MGATE_MQTT_WS_WITH_MTLS_ADDRESS | MQTT over Websocket with mTLS inbound (IN) connection listening address | :8085 |
+| MGATE_MQTT_WS_WITH_MTLS_PATH_PREFIX | MQTT over Websocket with mTLS inbound (IN) connection path | /mqtt |
+| MGATE_MQTT_WS_WITH_MTLS_TARGET | MQTT over Websocket with mTLS outbound (OUT) connection address | ws://localhost:8000/ |
+| MGATE_MQTT_WS_WITH_MTLS_CERT_FILE | MQTT over Websocket with mTLS certificate file path | ssl/certs/server.crt |
+| MGATE_MQTT_WS_WITH_MTLS_KEY_FILE | MQTT over Websocket with mTLS key file path | ssl/certs/server.key |
+| MGATE_MQTT_WS_WITH_MTLS_SERVER_CA_FILE | MQTT over Websocket with mTLS server CA file path | ssl/certs/ca.crt |
+| MGATE_MQTT_WS_WITH_MTLS_CLIENT_CA_FILE | MQTT over Websocket with mTLS client CA file path | ssl/certs/ca.crt |
+| MGATE_MQTT_WS_WITH_MTLS_CERT_VERIFICATION_METHODS | MQTT over Websocket with mTLS certificate verification methods, if no value or unset then mGate server will not do client validation | ocsp |
+| MGATE_MQTT_WS_WITH_MTLS_OCSP_RESPONDER_URL | MQTT over Websocket with mTLS OCSP responder URL, it is used if OCSP responder URL is not available in client certificate AIA | |
+| MGATE_HTTP_WITHOUT_TLS_ADDRESS | HTTP without TLS inbound (IN) connection listening address | :8086 |
+| MGATE_HTTP_WITHOUT_TLS_PATH_PREFIX | HTTP without TLS inbound (IN) connection path | /messages |
+| MGATE_HTTP_WITHOUT_TLS_TARGET | HTTP without TLS outbound (OUT) connection address | |
+| MGATE_HTTP_WITH_TLS_ADDRESS | HTTP with TLS inbound (IN) connection listening address | :8087 |
+| MGATE_HTTP_WITH_TLS_PATH_PREFIX | HTTP with TLS inbound (IN) connection path | /messages |
+| MGATE_HTTP_WITH_TLS_TARGET | HTTP with TLS outbound (OUT) connection address | |
+| MGATE_HTTP_WITH_TLS_CERT_FILE | HTTP with TLS certificate file path | ssl/certs/server.crt |
+| MGATE_HTTP_WITH_TLS_KEY_FILE | HTTP with TLS key file path | ssl/certs/server.key |
+| MGATE_HTTP_WITH_TLS_SERVER_CA_FILE | HTTP with TLS server CA file path | ssl/certs/ca.crt |
+| MGATE_HTTP_WITH_MTLS_ADDRESS | HTTP with mTLS inbound (IN) connection listening address | :8088 |
+| MGATE_HTTP_WITH_MTLS_PATH_PREFIX | HTTP with mTLS inbound (IN) connection path | /messages |
+| MGATE_HTTP_WITH_MTLS_TARGET | HTTP with mTLS outbound (OUT) connection address | |
+| MGATE_HTTP_WITH_MTLS_CERT_FILE | HTTP with mTLS certificate file path | ssl/certs/server.crt |
+| MGATE_HTTP_WITH_MTLS_KEY_FILE | HTTP with mTLS key file path | ssl/certs/server.key |
+| MGATE_HTTP_WITH_MTLS_SERVER_CA_FILE | HTTP with mTLS server CA file path | ssl/certs/ca.crt |
+| MGATE_HTTP_WITH_MTLS_CLIENT_CA_FILE | HTTP with mTLS client CA file path | ssl/certs/ca.crt |
+| MGATE_HTTP_WITH_MTLS_CERT_VERIFICATION_METHODS | HTTP with mTLS certificate verification methods, if no value or unset then mGate server will not do client validation | ocsp |
+| MGATE_HTTP_WITH_MTLS_OCSP_RESPONDER_URL | HTTP with mTLS OCSP responder URL, it is used if OCSP responder URL is not available in client certificate AIA | |
+| MGATE_COAP_WITHOUT_DTLS_HOST | CoAP without DTLS inbound (IN) connection listening host | localhost |
+| MGATE_COAP_WITHOUT_DTLS_PORT | CoAP without DTLS inbound (IN) connection listening port | 5682 |
+| MGATE_COAP_WITHOUT_DTLS_TARGET_HOST | CoAP without DTLS outbound (OUT) connection listening host | localhost |
+| MGATE_COAP_WITHOUT_DTLS_TARGET_PORT | CoAP without DTLS outbound (OUT) connection listening port | 5683 |
+| MGATE_COAP_WITH_DTLS_HOST | CoAP with DTLS inbound (IN) connection listening host | localhost |
+| MGATE_COAP_WITH_DTLS_PORT | CoAP with DTLS inbound (IN) connection listening port | 5684 |
+| MGATE_COAP_WITH_DTLS_TARGET_HOST | CoAP with DTLS outbound (OUT) connection listening host | localhost |
+| MGATE_COAP_WITH_DTLS_TARGET_PORT | CoAP with DTLS outbound (OUT) connection listening port | 5683 |
+| MGATE_COAP_WITH_DTLS_CERT_FILE | CoAP with DTLS certificate file path | ssl/certs/server.crt |
+| MGATE_COAP_WITH_DTLS_KEY_FILE | CoAP with DTLS key file path | ssl/certs/server.key |
+| MGATE_COAP_WITH_DTLS_SERVER_CA_FILE | CoAP with DTLS server CA file path | ssl/certs/ca.crt |
## mProxy Configuration Environment Variables
-### Server Configuration Environment Variables
+### Server Configuration Keys (used under a prefix)
-- `ADDRESS` : Specifies the address at which mProxy will listen. Supports MQTT, MQTT over WebSocket, and HTTP proxy connections.
-- `PATH_PREFIX` : Defines the path prefix when listening for MQTT over WebSocket or HTTP connections.
-- `TARGET` : Specifies the address of the target server, including any prefix path if available. The target server can be an MQTT server, MQTT over WebSocket, or an HTTP server.
+- `HOST`: Inbound bind host (empty binds all interfaces).
+- `PORT`: Inbound port (defaults per server as listed above).
+- `TARGET_HOST`: Backend host (default `localhost`).
+- `TARGET_PORT`: Backend port (defaults per protocol as listed above).
+- `TARGET_PROTOCOL`: Backend scheme (`http` or `ws` when applicable).
+- `TARGET_PATH`: Backend path suffix (useful for WebSocket/HTTP).
+- `PATH_PREFIX`: Optional inbound path prefix (present in config; may be unused in examples).
-### TLS Configuration Environment Variables
+### TLS Configuration Keys
-- `CERT_FILE` : Path to the TLS certificate file.
-- `KEY_FILE` : Path to the TLS certificate key file.
-- `SERVER_CA_FILE` : Path to the Server CA certificate file.
-- `CLIENT_CA_FILE` : Path to the Client CA certificate file.
-- `CERT_VERIFICATION_METHODS` : Methods for validating certificates. Accepted values are `ocsp` or `crl`.
- For the `ocsp` value, the `tls.Config` attempts to retrieve the OCSP responder/server URL from the Authority Information Access (AIA) section of the client certificate. If the client certificate lacks an OCSP responder URL or if an alternative URL is preferred, you can override it using the environmental variable `OCSP_RESPONDER_URL`.
- For the `crl` value, the `tls.Config` attempts to obtain the Certificate Revocation List (CRL) file from the CRL Distribution Point section in the client certificate. If the client certificate lacks a CRL distribution point section, or if you prefer to override it, you can use the environmental variables `CRL_DISTRIBUTION_POINTS` and `CRL_DISTRIBUTION_POINTS_ISSUER_CERT_FILE`. If no CRL distribution point server is available, you can specify an offline CRL file using the environmental variables `OFFLINE_CRL_FILE` and `OFFLINE_CRL_ISSUER_CERT_FILE`.
+- `CERT_FILE`: TLS certificate file.
+- `KEY_FILE`: TLS private key file.
+- `SERVER_CA_FILE`: Server CA bundle.
+- `CLIENT_CA_FILE`: Client CA bundle (for mTLS only).
+- `CERT_VERIFICATION_METHODS`: Comma-separated `ocsp` and/or `crl`.
+ - OCSP: If AIA lacks a responder URL or you prefer a custom endpoint, set `OCSP_RESPONDER_URL`.
+ - CRL: If distribution points are missing or unavailable, use overrides and/or offline files.
-#### OCSP Configuration Environment Variables
+#### OCSP Keys
-- `OCSP_DEPTH` : Depth of client certificate verification in the OCSP method. The default value is 0, meaning there is no limit, and all certificates are verified.
-- `OCSP_RESPONDER_URL` : Override value for the OCSP responder URL present in the Authority Information Access (AIA) section of the client certificate. If left empty, it expects the OCSP responder URL from the AIA section of the client certificate.
+- `OCSP_DEPTH`: Verification depth (0 = no limit; verify all certs).
+- `OCSP_RESPONDER_URL`: Override OCSP responder when AIA is missing or overridden.
-#### CRL Configuration Environment Variables
+#### CRL Keys
-- `CRL_DEPTH`: Depth of client certificate verification in the CRL method. The default value is 1, meaning only the leaf certificate is verified.
-- `CRL_DISTRIBUTION_POINTS` : Override for the CRL Distribution Point value present in the certificate's CRL Distribution Point section.
-- `CRL_DISTRIBUTION_POINTS_ISSUER_CERT_FILE` : Path to the issuer certificate file for verifying the CRL retrieved from `CRL_DISTRIBUTION_POINTS`.
-- `OFFLINE_CRL_FILE` : Path to the offline CRL file, which can be used if the CRL Distribution point is not available in either the environmental variable or the certificate's CRL Distribution Point section.
-- `OFFLINE_CRL_ISSUER_CERT_FILE` : Location of the issuer certificate file for verifying the offline CRL file specified in `OFFLINE_CRL_FILE`.
+- `CRL_DEPTH`: Verification depth (default `1`, leaf only).
+- `CRL_DISTRIBUTION_POINTS`: CRL URL override.
+- `CRL_DISTRIBUTION_POINTS_ISSUER_CERT_FILE`: Issuer cert to verify CRL signature.
+- `OFFLINE_CRL_FILE`: Offline CRL file.
+- `OFFLINE_CRL_ISSUER_CERT_FILE`: Issuer cert to verify offline CRL.
## Adding Prefix to Environmental Variables
mProxy relies on the [caarlos0/env](https://github.com/caarlos0/env) package to load environmental variables into its [configuration](https://github.com/arvindh123/mgate/blob/main/config.go#L15).
You can control how these variables are loaded by passing `env.Options` to the `config.EnvParse` function.
-To add a prefix to environmental variables, use `env.Options{Prefix: "MPROXY_"}` from the [caarlos0/env](https://github.com/caarlos0/env) package. For example:
+To add a prefix to environmental variables, use `env.Options{Prefix: "MGATE_"}` from the [caarlos0/env](https://github.com/caarlos0/env) package. For example:
```go
package main
@@ -378,35 +483,49 @@ import (
)
mqttConfig := mgate.Config{}
-if err := mqttConfig.EnvParse(env.Options{Prefix: "MPROXY_" }); err != nil {
+if err := mqttConfig.EnvParse(env.Options{Prefix: "MGATE_" }); err != nil {
panic(err)
}
fmt.Printf("%+v\n")
```
-In the above snippet, `mqttConfig.EnvParse` expects all environmental variables with the prefix `MPROXY_`.
+In the above snippet, `mqttConfig.EnvParse` expects all environmental variables with the prefix `MGATE_`.
For instance:
-- MPROXY_ADDRESS
-- MPROXY_PATH_PREFIX
-- MPROXY_TARGET
-- MPROXY_CERT_FILE
-- MPROXY_KEY_FILE
-- MPROXY_SERVER_CA_FILE
-- MPROXY_CLIENT_CA_FILE
-- MPROXY_CERT_VERIFICATION_METHODS
-- MPROXY_OCSP_DEPTH
-- MPROXY_OCSP_RESPONDER_URL
-- MPROXY_CRL_DEPTH
-- MPROXY_CRL_DISTRIBUTION_POINTS
-- MPROXY_CRL_DISTRIBUTION_POINTS_ISSUER_CERT_FILE
-- MPROXY_OFFLINE_CRL_FILE
-- MPROXY_OFFLINE_CRL_ISSUER_CERT_FILE
+- MGATE_ADDRESS
+- MGATE_PATH_PREFIX
+- MGATE_TARGET
+- MGATE_CERT_FILE
+- MGATE_KEY_FILE
+- MGATE_SERVER_CA_FILE
+- MGATE_CLIENT_CA_FILE
+- MGATE_CERT_VERIFICATION_METHODS
+- MGATE_OCSP_DEPTH
+- MGATE_OCSP_RESPONDER_URL
+- MGATE_CRL_DEPTH
+- MGATE_CRL_DISTRIBUTION_POINTS
+- MGATE_CRL_DISTRIBUTION_POINTS_ISSUER_CERT_FILE
+- MGATE_OFFLINE_CRL_FILE
+- MGATE_OFFLINE_CRL_ISSUER_CERT_FILE
+
+## Troubleshooting & FAQ
+
+- Cert chain errors: Ensure `SERVER_CA_FILE` (and `CLIENT_CA_FILE` for mTLS) include required intermediates.
+- OCSP responder missing/unreachable: Set `OCSP_RESPONDER_URL` to a reachable endpoint; verify firewall rules.
+- CRL retrieval failures: Use `CRL_DISTRIBUTION_POINTS` overrides or provide `OFFLINE_CRL_FILE` and `OFFLINE_CRL_ISSUER_CERT_FILE`.
+- WebSocket path mismatches: Confirm client path matches server prefix (e.g., `/mqtt`) and backend target path.
+- HTTP prefix issues: Verify inbound prefix `/messages` and backend routing.
+- Target connectivity: Confirm the backend (MQTT/HTTP/WS/CoAP) is listening and ports are open.
+- Metrics/Health not reachable: Check `METRICS_PORT`/`HEALTH_PORT` collisions and local firewall.
## License
[Apache-2.0](LICENSE)
-[grc]: https://goreportcard.com/badge/github.com/absmach/mgate
-[LIC]: LICENCE
+[grc]: https://goreportcard.com/badge/github.com/absmach/mproxy
+[LIC]: LICENSE
[LIC-BADGE]: https://img.shields.io/badge/License-Apache_2.0-blue.svg
+[PKG-BADGE]: https://pkg.go.dev/badge/github.com/absmach/mproxy
+[PKG]: https://pkg.go.dev/github.com/absmach/mproxy
+[RELEASE-BADGE]: https://img.shields.io/github/v/release/absmach/mproxy?display_name=tag&sort=semver
+[RELEASE]: https://github.com/absmach/mproxy/releases
diff --git a/cmd/main.go b/cmd/main.go
index 5c87a3a..e0e2145 100644
--- a/cmd/main.go
+++ b/cmd/main.go
@@ -36,6 +36,8 @@ const (
coapWithoutDTLS = "MPROXY_COAP_WITHOUT_DTLS_"
coapWithDTLS = "MPROXY_COAP_WITH_DTLS_"
+
+ defaultTargetHost = "localhost"
)
func main() {
@@ -136,7 +138,7 @@ func startMQTTProxy(g *errgroup.Group, ctx context.Context, envPrefix string, ha
}
if cfg.TargetHost == "" {
- cfg.TargetHost = "localhost"
+ cfg.TargetHost = defaultTargetHost
}
if cfg.TargetPort == "" {
@@ -187,7 +189,7 @@ func startWebSocketProxy(g *errgroup.Group, ctx context.Context, envPrefix strin
}
if cfg.TargetHost == "" {
- cfg.TargetHost = "localhost"
+ cfg.TargetHost = defaultTargetHost
}
if cfg.TargetPort == "" {
@@ -245,7 +247,7 @@ func startHTTPProxy(g *errgroup.Group, ctx context.Context, envPrefix string, ha
}
if cfg.TargetHost == "" {
- cfg.TargetHost = "localhost"
+ cfg.TargetHost = defaultTargetHost
}
if cfg.TargetPort == "" {
@@ -300,7 +302,7 @@ func startCoAPProxy(g *errgroup.Group, ctx context.Context, envPrefix string, ha
}
if cfg.TargetHost == "" {
- cfg.TargetHost = "localhost"
+ cfg.TargetHost = defaultTargetHost
}
if cfg.TargetPort == "" {
diff --git a/cmd/production/handlers.go b/cmd/production/handlers.go
index 50797f8..6e82f96 100644
--- a/cmd/production/handlers.go
+++ b/cmd/production/handlers.go
@@ -13,6 +13,8 @@ import (
"github.com/absmach/mproxy/pkg/ratelimit"
)
+const protocolMQTT = "mqtt"
+
// RateLimitedHandler wraps a handler with rate limiting.
type RateLimitedHandler struct {
handler handler.Handler
@@ -97,13 +99,10 @@ type InstrumentedHandler struct {
func (h *InstrumentedHandler) AuthConnect(ctx context.Context, hctx *handler.Context) error {
start := time.Now()
h.metrics.AuthAttempts.WithLabelValues(hctx.Protocol, "connect").Inc()
-
err := h.handler.AuthConnect(ctx, hctx)
-
if err != nil {
h.metrics.AuthFailures.WithLabelValues(hctx.Protocol, "connect", "unauthorized").Inc()
}
-
duration := time.Since(start).Seconds()
h.metrics.RequestDuration.WithLabelValues(hctx.Protocol, "connect").Observe(duration)
@@ -114,17 +113,13 @@ func (h *InstrumentedHandler) AuthConnect(ctx context.Context, hctx *handler.Con
func (h *InstrumentedHandler) AuthPublish(ctx context.Context, hctx *handler.Context, topic *string, payload *[]byte) error {
start := time.Now()
h.metrics.AuthAttempts.WithLabelValues(hctx.Protocol, "publish").Inc()
-
if payload != nil {
h.metrics.RequestSize.WithLabelValues(hctx.Protocol).Observe(float64(len(*payload)))
}
-
err := h.handler.AuthPublish(ctx, hctx, topic, payload)
-
if err != nil {
h.metrics.AuthFailures.WithLabelValues(hctx.Protocol, "publish", "unauthorized").Inc()
}
-
duration := time.Since(start).Seconds()
h.metrics.RequestDuration.WithLabelValues(hctx.Protocol, "publish").Observe(duration)
@@ -141,13 +136,10 @@ func (h *InstrumentedHandler) AuthPublish(ctx context.Context, hctx *handler.Con
func (h *InstrumentedHandler) AuthSubscribe(ctx context.Context, hctx *handler.Context, topics *[]string) error {
start := time.Now()
h.metrics.AuthAttempts.WithLabelValues(hctx.Protocol, "subscribe").Inc()
-
err := h.handler.AuthSubscribe(ctx, hctx, topics)
-
if err != nil {
h.metrics.AuthFailures.WithLabelValues(hctx.Protocol, "subscribe", "unauthorized").Inc()
}
-
duration := time.Since(start).Seconds()
h.metrics.RequestDuration.WithLabelValues(hctx.Protocol, "subscribe").Observe(duration)
@@ -170,7 +162,7 @@ func (h *InstrumentedHandler) OnConnect(ctx context.Context, hctx *handler.Conte
// OnPublish implements handler.Handler with metrics.
func (h *InstrumentedHandler) OnPublish(ctx context.Context, hctx *handler.Context, topic string, payload []byte) error {
- if hctx.Protocol == "mqtt" {
+ if hctx.Protocol == protocolMQTT {
h.metrics.MQTTPackets.WithLabelValues("publish", "upstream").Inc()
}
@@ -179,7 +171,7 @@ func (h *InstrumentedHandler) OnPublish(ctx context.Context, hctx *handler.Conte
// OnSubscribe implements handler.Handler with metrics.
func (h *InstrumentedHandler) OnSubscribe(ctx context.Context, hctx *handler.Context, topics []string) error {
- if hctx.Protocol == "mqtt" {
+ if hctx.Protocol == protocolMQTT {
h.metrics.MQTTPackets.WithLabelValues("subscribe", "upstream").Inc()
}
@@ -188,7 +180,7 @@ func (h *InstrumentedHandler) OnSubscribe(ctx context.Context, hctx *handler.Con
// OnUnsubscribe implements handler.Handler with metrics.
func (h *InstrumentedHandler) OnUnsubscribe(ctx context.Context, hctx *handler.Context, topics []string) error {
- if hctx.Protocol == "mqtt" {
+ if hctx.Protocol == protocolMQTT {
h.metrics.MQTTPackets.WithLabelValues("unsubscribe", "upstream").Inc()
}
diff --git a/cmd/production/main.go b/cmd/production/main.go
index 6a4cf0a..56f4304 100644
--- a/cmd/production/main.go
+++ b/cmd/production/main.go
@@ -30,17 +30,19 @@ import (
"golang.org/x/sync/errgroup"
)
+const logError = "error"
+
// Config holds the application configuration.
type Config struct {
// Observability
- MetricsPort int `env:"METRICS_PORT" envDefault:"9090"`
- HealthPort int `env:"HEALTH_PORT" envDefault:"8080"`
- LogLevel string `env:"LOG_LEVEL" envDefault:"info"`
- LogFormat string `env:"LOG_FORMAT" envDefault:"json"`
+ MetricsPort int `env:"METRICS_PORT" envDefault:"9090"`
+ HealthPort int `env:"HEALTH_PORT" envDefault:"8080"`
+ LogLevel string `env:"LOG_LEVEL" envDefault:"info"`
+ LogFormat string `env:"LOG_FORMAT" envDefault:"json"`
// Resource Limits
- MaxConnections int `env:"MAX_CONNECTIONS" envDefault:"10000"`
- MaxGoroutines int `env:"MAX_GOROUTINES" envDefault:"50000"`
+ MaxConnections int `env:"MAX_CONNECTIONS" envDefault:"10000"`
+ MaxGoroutines int `env:"MAX_GOROUTINES" envDefault:"50000"`
// Connection Pooling
PoolMaxIdle int `env:"POOL_MAX_IDLE" envDefault:"100"`
@@ -48,9 +50,9 @@ type Config struct {
PoolIdleTimeout time.Duration `env:"POOL_IDLE_TIMEOUT" envDefault:"5m"`
// Circuit Breaker
- BreakerMaxFailures int `env:"BREAKER_MAX_FAILURES" envDefault:"5"`
- BreakerResetTimeout time.Duration `env:"BREAKER_RESET_TIMEOUT" envDefault:"60s"`
- BreakerTimeout time.Duration `env:"BREAKER_TIMEOUT" envDefault:"30s"`
+ BreakerMaxFailures int `env:"BREAKER_MAX_FAILURES" envDefault:"5"`
+ BreakerResetTimeout time.Duration `env:"BREAKER_RESET_TIMEOUT" envDefault:"60s"`
+ BreakerTimeout time.Duration `env:"BREAKER_TIMEOUT" envDefault:"30s"`
// Rate Limiting
RateLimitCapacity int64 `env:"RATE_LIMIT_CAPACITY" envDefault:"100"`
@@ -70,14 +72,20 @@ type Config struct {
}
func main() {
+ if err := run(); err != nil {
+ fmt.Fprintln(os.Stderr, err)
+ os.Exit(1)
+ }
+}
+
+func run() error {
// Load configuration
cfg := Config{}
if err := godotenv.Load(); err != nil {
// .env file is optional
}
if err := env.Parse(&cfg); err != nil {
- fmt.Fprintf(os.Stderr, "Failed to parse config: %v\n", err)
- os.Exit(1)
+ return fmt.Errorf("failed to parse config: %w", err)
}
// Setup logger
@@ -216,7 +224,7 @@ func main() {
mqttProxy, err := proxy.NewMQTT(mqttProxyConfig, instrumentedHandler)
if err != nil {
- logger.Error("Failed to create MQTT proxy", slog.String("error", err.Error()))
+ logger.Error("Failed to create MQTT proxy", slog.String(logError, err.Error()))
} else {
g.Go(func() error {
address := net.JoinHostPort(mqttProxyConfig.Host, mqttProxyConfig.Port)
@@ -246,7 +254,7 @@ func main() {
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), cfg.ShutdownTimeout)
defer shutdownCancel()
- done := make(chan error)
+ done := make(chan error, 1)
go func() {
done <- g.Wait()
}()
@@ -254,13 +262,14 @@ func main() {
select {
case err := <-done:
if err != nil {
- logger.Error("Shutdown error", slog.String("error", err.Error()))
- os.Exit(1)
+ logger.Error("Shutdown error", slog.String(logError, err.Error()))
+ return err
}
logger.Info("Graceful shutdown completed")
+ return nil
case <-shutdownCtx.Done():
logger.Warn("Shutdown timeout exceeded, forcing exit")
- os.Exit(1)
+ return shutdownCtx.Err()
}
}
@@ -274,7 +283,7 @@ func setupLogger(level, format string) *slog.Logger {
logLevel = slog.LevelInfo
case "warn":
logLevel = slog.LevelWarn
- case "error":
+ case logError:
logLevel = slog.LevelError
default:
logLevel = slog.LevelInfo
@@ -311,7 +320,7 @@ func startMetricsServer(port int, logger *slog.Logger) {
}
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
- logger.Error("Metrics server error", slog.String("error", err.Error()))
+ logger.Error("Metrics server error", slog.String(logError, err.Error()))
}
}
@@ -334,6 +343,6 @@ func startHealthServer(port int, checker *health.Checker, logger *slog.Logger) {
}
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
- logger.Error("Health server error", slog.String("error", err.Error()))
+ logger.Error("Health server error", slog.String(logError, err.Error()))
}
}
diff --git a/pkg/breaker/breaker.go b/pkg/breaker/breaker.go
index f14d463..bbd6e72 100644
--- a/pkg/breaker/breaker.go
+++ b/pkg/breaker/breaker.go
@@ -10,10 +10,8 @@ import (
"time"
)
-var (
- // ErrCircuitOpen is returned when the circuit breaker is open.
- ErrCircuitOpen = errors.New("circuit breaker is open")
-)
+// ErrCircuitOpen is returned when the circuit breaker is open.
+var ErrCircuitOpen = errors.New("circuit breaker is open")
// State represents the circuit breaker state.
type State int
@@ -51,14 +49,14 @@ type Config struct {
// CircuitBreaker implements the circuit breaker pattern.
type CircuitBreaker struct {
- mu sync.RWMutex
- config Config
- state State
- failures int
- successes int
- lastFailureTime time.Time
- lastStateChange time.Time
- onStateChange func(from, to State)
+ mu sync.RWMutex
+ config Config
+ state State
+ failures int
+ successes int
+ lastFailureTime time.Time
+ lastStateChange time.Time
+ onStateChange func(from, to State)
}
// New creates a new circuit breaker.
diff --git a/pkg/errors/errors.go b/pkg/errors/errors.go
index b0dba1d..052547d 100644
--- a/pkg/errors/errors.go
+++ b/pkg/errors/errors.go
@@ -9,7 +9,7 @@ import (
"fmt"
)
-// Common error types
+// Common error types.
var (
// ErrUnauthorized indicates authentication or authorization failure.
ErrUnauthorized = errors.New("unauthorized")
@@ -41,11 +41,11 @@ var (
// ProxyError wraps an error with additional context.
type ProxyError struct {
- Op string // Operation that failed
- Protocol string // Protocol (mqtt, http, coap, websocket)
- SessionID string // Session identifier
+ Op string // Operation that failed
+ Protocol string // Protocol (mqtt, http, coap, websocket)
+ SessionID string // Session identifier
RemoteAddr string // Client address
- Err error // Underlying error
+ Err error // Underlying error
}
// Error implements the error interface.
@@ -67,11 +67,11 @@ func New(op, protocol, sessionID, remoteAddr string, err error) error {
return nil
}
return &ProxyError{
- Op: op,
- Protocol: protocol,
- SessionID: sessionID,
+ Op: op,
+ Protocol: protocol,
+ SessionID: sessionID,
RemoteAddr: remoteAddr,
- Err: err,
+ Err: err,
}
}
diff --git a/pkg/health/health.go b/pkg/health/health.go
index 3c780af..0e98766 100644
--- a/pkg/health/health.go
+++ b/pkg/health/health.go
@@ -120,13 +120,13 @@ func (c *Checker) HTTPHandler() http.HandlerFunc {
w.Header().Set("Content-Type", "application/json")
if status == StatusUnhealthy {
w.WriteHeader(http.StatusServiceUnavailable)
- } else if status == StatusDegraded {
- w.WriteHeader(http.StatusOK) // Still accept traffic
} else {
- w.WriteHeader(http.StatusOK)
+ w.WriteHeader(http.StatusOK) // Still accept traffic
}
- json.NewEncoder(w).Encode(response)
+ if err := json.NewEncoder(w).Encode(response); err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ }
}
}
@@ -135,9 +135,11 @@ func LivenessHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
- json.NewEncoder(w).Encode(map[string]string{
+ if err := json.NewEncoder(w).Encode(map[string]string{
"status": "alive",
- })
+ }); err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ }
}
}
@@ -161,6 +163,8 @@ func (c *Checker) ReadinessHandler() http.HandlerFunc {
w.WriteHeader(http.StatusOK)
}
- json.NewEncoder(w).Encode(response)
+ if err := json.NewEncoder(w).Encode(response); err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ }
}
}
diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go
index 99c716e..5aa8df4 100644
--- a/pkg/metrics/metrics.go
+++ b/pkg/metrics/metrics.go
@@ -5,36 +5,30 @@
package metrics
import (
- "sync"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
)
-var (
- once sync.Once
- reg *prometheus.Registry
-)
-
// Metrics holds all Prometheus metrics for mProxy.
type Metrics struct {
// Connection metrics
- ActiveConnections *prometheus.GaugeVec
- TotalConnections *prometheus.CounterVec
- ConnectionErrors *prometheus.CounterVec
+ ActiveConnections *prometheus.GaugeVec
+ TotalConnections *prometheus.CounterVec
+ ConnectionErrors *prometheus.CounterVec
ConnectionDuration *prometheus.HistogramVec
// Request metrics
- RequestsTotal *prometheus.CounterVec
- RequestDuration *prometheus.HistogramVec
- RequestSize *prometheus.HistogramVec
- ResponseSize *prometheus.HistogramVec
+ RequestsTotal *prometheus.CounterVec
+ RequestDuration *prometheus.HistogramVec
+ RequestSize *prometheus.HistogramVec
+ ResponseSize *prometheus.HistogramVec
// Backend metrics
- BackendRequestsTotal *prometheus.CounterVec
- BackendErrors *prometheus.CounterVec
- BackendDuration *prometheus.HistogramVec
+ BackendRequestsTotal *prometheus.CounterVec
+ BackendErrors *prometheus.CounterVec
+ BackendDuration *prometheus.HistogramVec
BackendActiveConnections *prometheus.GaugeVec
// Circuit breaker metrics
diff --git a/pkg/parser/http/parser.go b/pkg/parser/http/parser.go
index 0138bd2..f7d1156 100644
--- a/pkg/parser/http/parser.go
+++ b/pkg/parser/http/parser.go
@@ -82,7 +82,7 @@ func (p *Parser) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Read body for publish authorization with size limit
// Default: 10MB max body size to prevent memory exhaustion
- const maxBodySize = 10 * 1024 * 1024 // 10MB
+ const maxBodySize = 10 * 1024 * 1024 // 10MB
limitedReader := io.LimitReader(r.Body, maxBodySize+1) // +1 to detect if exceeded
payload, err := io.ReadAll(limitedReader)
diff --git a/pkg/parser/mqtt/parser_test.go b/pkg/parser/mqtt/parser_test.go
index f4643b7..d62f932 100644
--- a/pkg/parser/mqtt/parser_test.go
+++ b/pkg/parser/mqtt/parser_test.go
@@ -14,6 +14,11 @@ import (
"github.com/eclipse/paho.mqtt.golang/packets"
)
+const (
+ testClientID = "test-client"
+ testTopic = "test/topic"
+)
+
type mockHandler struct {
connectErr error
publishErr error
@@ -78,7 +83,7 @@ func TestMQTTParser_ParseConnect(t *testing.T) {
// Create CONNECT packet
connectPkt := packets.NewControlPacket(packets.Connect).(*packets.ConnectPacket)
- connectPkt.ClientIdentifier = "test-client"
+ connectPkt.ClientIdentifier = testClientID
connectPkt.Username = "testuser"
connectPkt.Password = []byte("testpass")
connectPkt.UsernameFlag = true
@@ -107,8 +112,8 @@ func TestMQTTParser_ParseConnect(t *testing.T) {
}
// Verify credentials were extracted and passed to handler
- if mock.lastHctx.ClientID != "test-client" {
- t.Errorf("Expected ClientID 'test-client', got '%s'", mock.lastHctx.ClientID)
+ if mock.lastHctx.ClientID != testClientID {
+ t.Errorf("Expected ClientID '%s', got '%s'", testClientID, mock.lastHctx.ClientID)
}
if mock.lastHctx.Username != "testuser" {
t.Errorf("Expected Username 'testuser', got '%s'", mock.lastHctx.Username)
@@ -129,7 +134,7 @@ func TestMQTTParser_ParsePublish(t *testing.T) {
// Create PUBLISH packet
publishPkt := packets.NewControlPacket(packets.Publish).(*packets.PublishPacket)
- publishPkt.TopicName = "test/topic"
+ publishPkt.TopicName = testTopic
publishPkt.Payload = []byte("test payload")
publishPkt.Qos = 0
@@ -156,8 +161,8 @@ func TestMQTTParser_ParsePublish(t *testing.T) {
}
// Verify topic and payload were captured
- if mock.lastTopic != "test/topic" {
- t.Errorf("Expected topic 'test/topic', got '%s'", mock.lastTopic)
+ if mock.lastTopic != testTopic {
+ t.Errorf("Expected topic '%s', got '%s'", testTopic, mock.lastTopic)
}
if string(mock.lastPayload) != "test payload" {
t.Errorf("Expected payload 'test payload', got '%s'", mock.lastPayload)
@@ -269,7 +274,7 @@ func TestMQTTParser_AuthError(t *testing.T) {
// Create CONNECT packet
connectPkt := packets.NewControlPacket(packets.Connect).(*packets.ConnectPacket)
- connectPkt.ClientIdentifier = "test-client"
+ connectPkt.ClientIdentifier = testClientID
connectPkt.Username = "baduser"
connectPkt.Password = []byte("badpass")
connectPkt.UsernameFlag = true
@@ -315,7 +320,7 @@ func TestMQTTParser_DownstreamPublish(t *testing.T) {
// Create PUBLISH packet from broker
publishPkt := packets.NewControlPacket(packets.Publish).(*packets.PublishPacket)
- publishPkt.TopicName = "test/topic"
+ publishPkt.TopicName = testTopic
publishPkt.Payload = []byte("broker message")
publishPkt.Qos = 0
diff --git a/pkg/parser/websocket/parser.go b/pkg/parser/websocket/parser.go
index 157da19..9067c47 100644
--- a/pkg/parser/websocket/parser.go
+++ b/pkg/parser/websocket/parser.go
@@ -48,15 +48,15 @@ func NewParser(targetURL string, underlyingParser parser.Parser, h handler.Handl
// Allow requests without Origin header (e.g., from native apps)
return true
}
- // TODO: Make allowed origins configurable
- // For now, only allow same-origin requests
+ // Note: Allowed origins should be configurable in production.
+ // For now, only allow same-origin requests.
return origin == "http://"+r.Host || origin == "https://"+r.Host
},
ReadBufferSize: 4096,
WriteBufferSize: 4096,
// Limit message size to prevent DoS
// Default: 10MB
- // TODO: Make this configurable
+ // Note: This should be configurable in production.
},
targetURL: targetURL,
underlyingParser: underlyingParser,
@@ -115,7 +115,7 @@ func (p *Parser) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
// Create context for this session
- ctx, cancel := context.WithCancel(context.Background())
+ ctx, cancel := context.WithCancel(r.Context())
defer cancel()
// Start bidirectional streaming with underlying protocol parser
@@ -143,7 +143,7 @@ func (p *Parser) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
// Notify disconnect
- if err := p.handler.OnDisconnect(context.Background(), hctx); err != nil {
+ if err := p.handler.OnDisconnect(ctx, hctx); err != nil {
p.logger.Error("disconnect handler error",
slog.String("session", sessionID),
slog.String("error", err.Error()))
diff --git a/pkg/pool/pool.go b/pkg/pool/pool.go
index 8f7b38d..4e81a13 100644
--- a/pkg/pool/pool.go
+++ b/pkg/pool/pool.go
@@ -55,13 +55,13 @@ type DialFunc func(ctx context.Context) (net.Conn, error)
// Pool is a connection pool.
type Pool struct {
- mu sync.Mutex
- idle []*Conn
- active int
- dialFunc DialFunc
- config Config
- closed bool
- waitChan chan struct{}
+ mu sync.Mutex
+ idle []*Conn
+ active int
+ dialFunc DialFunc
+ config Config
+ closed bool
+ waitChan chan struct{}
}
// New creates a new connection pool.
@@ -195,7 +195,7 @@ func (p *Pool) isValid(conn *Conn) bool {
return false
}
- // TODO: Add connection health check (send ping)
+ // Note: Add connection health check (send ping).
return true
}
diff --git a/pkg/proxy/http.go b/pkg/proxy/http.go
index 05dcf3b..af7f429 100644
--- a/pkg/proxy/http.go
+++ b/pkg/proxy/http.go
@@ -79,7 +79,7 @@ func (p *HTTPProxy) Listen(ctx context.Context) error {
p.logger.Info("shutdown signal received, closing HTTP server")
// Create shutdown context with timeout
- shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ shutdownCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 30*time.Second)
defer cancel()
// Graceful shutdown
diff --git a/pkg/proxy/websocket.go b/pkg/proxy/websocket.go
index 1d33a6f..7047916 100644
--- a/pkg/proxy/websocket.go
+++ b/pkg/proxy/websocket.go
@@ -78,7 +78,7 @@ func (p *WebSocketProxy) Listen(ctx context.Context) error {
p.logger.Info("shutdown signal received, closing WebSocket server")
// Create shutdown context with timeout
- shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ shutdownCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 30*time.Second)
defer cancel()
// Graceful shutdown
diff --git a/pkg/ratelimit/ratelimit.go b/pkg/ratelimit/ratelimit.go
index 99b4b79..2b38d10 100644
--- a/pkg/ratelimit/ratelimit.go
+++ b/pkg/ratelimit/ratelimit.go
@@ -10,18 +10,16 @@ import (
"time"
)
-var (
- // ErrRateLimitExceeded is returned when rate limit is exceeded.
- ErrRateLimitExceeded = errors.New("rate limit exceeded")
-)
+// ErrRateLimitExceeded is returned when rate limit is exceeded.
+var ErrRateLimitExceeded = errors.New("rate limit exceeded")
// TokenBucket implements the token bucket algorithm for rate limiting.
type TokenBucket struct {
- mu sync.Mutex
- capacity int64
- tokens int64
- refillRate int64 // tokens per second
- lastRefill time.Time
+ mu sync.Mutex
+ capacity int64
+ tokens int64
+ refillRate int64 // tokens per second
+ lastRefill time.Time
}
// NewTokenBucket creates a new token bucket rate limiter.
diff --git a/pkg/server/tcp/server.go b/pkg/server/tcp/server.go
index 8aa29c7..0b68869 100644
--- a/pkg/server/tcp/server.go
+++ b/pkg/server/tcp/server.go
@@ -77,7 +77,6 @@ type Server struct {
parser parser.Parser
handler handler.Handler
wg sync.WaitGroup
- mu sync.Mutex
bufferPool *sync.Pool
connSem chan struct{} // semaphore for connection limiting
}
@@ -147,7 +146,7 @@ func (s *Server) Listen(ctx context.Context) error {
// Create a separate context for active connections
// This allows us to control when to forcefully close connections
- connCtx, connCancel := context.WithCancel(context.Background())
+ connCtx, connCancel := context.WithCancel(ctx)
defer connCancel()
// Accept loop
@@ -323,7 +322,7 @@ func (s *Server) handleConn(ctx context.Context, inbound net.Conn) error {
}
// Notify disconnect
- if err := s.handler.OnDisconnect(context.Background(), hctx); err != nil {
+ if err := s.handler.OnDisconnect(ctx, hctx); err != nil {
s.config.Logger.Error("disconnect handler error",
slog.String("session", sessionID),
slog.String("error", err.Error()))
diff --git a/pkg/server/tcp/server_test.go b/pkg/server/tcp/server_test.go
index 81f1613..4260f8f 100644
--- a/pkg/server/tcp/server_test.go
+++ b/pkg/server/tcp/server_test.go
@@ -110,7 +110,9 @@ func TestTCPServer_ListenAndAccept(t *testing.T) {
defer conn.Close()
// Echo back
- io.Copy(conn, conn)
+ if _, err := io.Copy(conn, conn); err != nil {
+ return
+ }
}()
// Create server
@@ -266,7 +268,9 @@ func TestTCPServer_BackendDialFailure(t *testing.T) {
// Server might have shut down already
return
}
- conn.Write([]byte("test"))
+ if _, err := conn.Write([]byte("test")); err != nil {
+ t.Fatalf("Failed to write to server: %v", err)
+ }
conn.Close()
// Server should continue running despite failed backend dial
@@ -337,10 +341,24 @@ func TestTCPServer_ParseError(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
- go server.Listen(ctx)
+ serverErr := make(chan error, 1)
+ go func() {
+ serverErr <- server.Listen(ctx)
+ }()
time.Sleep(100 * time.Millisecond)
// Server should be running fine despite parse errors in connections
+ t.Cleanup(func() {
+ cancel()
+ select {
+ case err := <-serverErr:
+ if err != nil && err != context.Canceled {
+ t.Logf("Server stopped with error: %v", err)
+ }
+ case <-time.After(2 * time.Second):
+ t.Log("Server shutdown timeout")
+ }
+ })
}
func TestTCPServer_ContextCancellation(t *testing.T) {
@@ -454,7 +472,9 @@ func TestTCPServer_TCPOptions(t *testing.T) {
conn, _ := backendListener.Accept()
if conn != nil {
defer conn.Close()
- io.Copy(conn, conn)
+ if _, err := io.Copy(conn, conn); err != nil {
+ return
+ }
}
}()
@@ -477,10 +497,21 @@ func TestTCPServer_TCPOptions(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
- go server.Listen(ctx)
+ serverErr := make(chan error, 1)
+ go func() {
+ serverErr <- server.Listen(ctx)
+ }()
time.Sleep(100 * time.Millisecond)
cancel()
+ select {
+ case err := <-serverErr:
+ if err != nil && err != context.Canceled {
+ t.Logf("Server stopped with error: %v", err)
+ }
+ case <-time.After(2 * time.Second):
+ t.Log("Server shutdown timeout")
+ }
}
func TestTCPServer_BufferPool(t *testing.T) {
diff --git a/pkg/server/udp/server.go b/pkg/server/udp/server.go
index a0a077a..c6df893 100644
--- a/pkg/server/udp/server.go
+++ b/pkg/server/udp/server.go
@@ -257,7 +257,7 @@ func (s *Server) Listen(ctx context.Context) error {
s.config.Logger.Info("all workers stopped")
// Drain sessions with timeout
- return s.sessions.DrainAll(s.config.ShutdownTimeout, s.handler)
+ return s.sessions.DrainAll(ctx, s.config.ShutdownTimeout, s.handler)
}
// startWorkerPool starts the worker goroutines for packet processing.
@@ -323,18 +323,18 @@ func (s *Server) handlePacket(ctx context.Context, listener *net.UDPConn, client
// If this is a new session, start downstream reader
if isNew {
- go s.readDownstream(sess, listener)
+ go s.readDownstream(sess.ctx, sess, listener)
}
return nil
}
// readDownstream continuously reads packets from the backend and forwards to the client.
-func (s *Server) readDownstream(sess *Session, listener *net.UDPConn) {
+func (s *Server) readDownstream(ctx context.Context, sess *Session, listener *net.UDPConn) {
defer func() {
// Remove session when downstream reader exits
s.sessions.Remove(sess.RemoteAddr)
- if err := s.handler.OnDisconnect(context.Background(), sess.Context); err != nil {
+ if err := s.handler.OnDisconnect(ctx, sess.Context); err != nil {
s.config.Logger.Error("disconnect handler error",
slog.String("session", sess.ID),
slog.String("error", err.Error()))
@@ -346,7 +346,7 @@ func (s *Server) readDownstream(sess *Session, listener *net.UDPConn) {
for {
select {
- case <-sess.ctx.Done():
+ case <-ctx.Done():
return
default:
}
@@ -388,7 +388,7 @@ func (s *Server) readDownstream(sess *Session, listener *net.UDPConn) {
reader := bytes.NewReader(buffer[:n])
writer := &udpClientWriter{conn: listener, addr: sess.RemoteAddr}
- if err := s.parser.Parse(sess.ctx, reader, writer, parser.Downstream, s.handler, sess.Context); err != nil {
+ if err := s.parser.Parse(ctx, reader, writer, parser.Downstream, s.handler, sess.Context); err != nil {
s.config.Logger.Debug("parser error",
slog.String("session", sess.ID),
slog.String("direction", "downstream"),
diff --git a/pkg/server/udp/server_test.go b/pkg/server/udp/server_test.go
index 04024da..da3e45a 100644
--- a/pkg/server/udp/server_test.go
+++ b/pkg/server/udp/server_test.go
@@ -104,7 +104,9 @@ func TestUDPServer_ListenAndReceive(t *testing.T) {
if err != nil {
return
}
- backendConn.WriteToUDP(buf[:n], addr)
+ if _, err := backendConn.WriteToUDP(buf[:n], addr); err != nil {
+ return
+ }
}
}()
@@ -179,7 +181,10 @@ func TestUDPServer_SessionCreation(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
- go server.Listen(ctx)
+ serverErr := make(chan error, 1)
+ go func() {
+ serverErr <- server.Listen(ctx)
+ }()
time.Sleep(100 * time.Millisecond)
// Initially no sessions
@@ -190,6 +195,16 @@ func TestUDPServer_SessionCreation(t *testing.T) {
// Note: We can't easily test session creation without actually sending
// UDP packets to the server, which would require knowing the server's
// actual port. This is tested in integration tests.
+ t.Cleanup(func() {
+ select {
+ case err := <-serverErr:
+ if err != nil && err != context.Canceled {
+ t.Logf("Server stopped with error: %v", err)
+ }
+ case <-time.After(2 * time.Second):
+ t.Log("Server shutdown timeout")
+ }
+ })
}
func TestUDPServer_InvalidAddress(t *testing.T) {
@@ -360,7 +375,7 @@ func TestSessionManager_Cleanup(t *testing.T) {
sess.mu.Unlock()
// Run cleanup
- sm.cleanupExpired(1*time.Minute, mockH)
+ sm.cleanupExpired(context.Background(), 1*time.Minute, mockH)
// Session should be removed
if sm.Count() != 0 {
@@ -387,7 +402,9 @@ func TestSessionManager_ForceCloseAll(t *testing.T) {
// Create multiple sessions
for i := 0; i < 3; i++ {
addr, _ := net.ResolveUDPAddr("udp", fmt.Sprintf("127.0.0.1:%d", 50000+i))
- sm.GetOrCreate(context.Background(), addr, targetAddr)
+ if _, _, err := sm.GetOrCreate(context.Background(), addr, targetAddr); err != nil {
+ t.Fatalf("Failed to create session: %v", err)
+ }
}
if sm.Count() != 3 {
@@ -395,7 +412,7 @@ func TestSessionManager_ForceCloseAll(t *testing.T) {
}
// Force close all
- sm.ForceCloseAll(mockH)
+ sm.ForceCloseAll(context.Background(), mockH)
if sm.Count() != 0 {
t.Errorf("Expected 0 sessions after force close, got %d", sm.Count())
@@ -450,7 +467,9 @@ func TestUDPServer_ShutdownTimeout(t *testing.T) {
// Create a session manually
clientAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:54321")
- server.sessions.GetOrCreate(context.Background(), clientAddr, cfg.TargetAddress)
+ if _, _, err := server.sessions.GetOrCreate(context.Background(), clientAddr, cfg.TargetAddress); err != nil {
+ t.Fatalf("Failed to create session: %v", err)
+ }
// Trigger shutdown
cancel()
@@ -490,11 +509,24 @@ func TestUDPServer_ParseError(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
- go server.Listen(ctx)
+ serverErr := make(chan error, 1)
+ go func() {
+ serverErr <- server.Listen(ctx)
+ }()
time.Sleep(100 * time.Millisecond)
// Server should handle parse errors gracefully
// and continue running
+ t.Cleanup(func() {
+ select {
+ case err := <-serverErr:
+ if err != nil && err != context.Canceled {
+ t.Logf("Server stopped with error: %v", err)
+ }
+ case <-time.After(2 * time.Second):
+ t.Log("Server shutdown timeout")
+ }
+ })
}
func TestUDPServer_SessionLimit(t *testing.T) {
@@ -524,10 +556,21 @@ func TestUDPServer_SessionLimit(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
- go server.Listen(ctx)
+ serverErr := make(chan error, 1)
+ go func() {
+ serverErr <- server.Listen(ctx)
+ }()
time.Sleep(100 * time.Millisecond)
cancel()
+ select {
+ case err := <-serverErr:
+ if err != nil && err != context.Canceled {
+ t.Logf("Server stopped with error: %v", err)
+ }
+ case <-time.After(2 * time.Second):
+ t.Log("Server shutdown timeout")
+ }
}
func TestUDPServer_WorkerPool(t *testing.T) {
@@ -567,11 +610,22 @@ func TestUDPServer_WorkerPool(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
- go server.Listen(ctx)
+ serverErr := make(chan error, 1)
+ go func() {
+ serverErr <- server.Listen(ctx)
+ }()
time.Sleep(100 * time.Millisecond)
cancel()
time.Sleep(100 * time.Millisecond)
+ select {
+ case err := <-serverErr:
+ if err != nil && err != context.Canceled {
+ t.Logf("Server stopped with error: %v", err)
+ }
+ case <-time.After(2 * time.Second):
+ t.Log("Server shutdown timeout")
+ }
}
func TestUDPServer_BufferPool(t *testing.T) {
diff --git a/pkg/server/udp/session.go b/pkg/server/udp/session.go
index 0d78b3d..a6c35a2 100644
--- a/pkg/server/udp/session.go
+++ b/pkg/server/udp/session.go
@@ -69,7 +69,6 @@ type SessionManager struct {
sessions map[string]*Session
mu sync.RWMutex
logger *slog.Logger
- wg sync.WaitGroup
maxSessions int
}
@@ -179,13 +178,13 @@ func (sm *SessionManager) Cleanup(ctx context.Context, timeout time.Duration, h
case <-ctx.Done():
return
case <-ticker.C:
- sm.cleanupExpired(timeout, h)
+ sm.cleanupExpired(ctx, timeout, h)
}
}
}
// cleanupExpired removes sessions that haven't been active within the timeout.
-func (sm *SessionManager) cleanupExpired(timeout time.Duration, h handler.Handler) {
+func (sm *SessionManager) cleanupExpired(ctx context.Context, timeout time.Duration, h handler.Handler) {
now := time.Now()
var toRemove []string
@@ -209,7 +208,7 @@ func (sm *SessionManager) cleanupExpired(timeout time.Duration, h handler.Handle
slog.String("client", sess.RemoteAddr.String()))
// Notify disconnect
- if err := h.OnDisconnect(context.Background(), sess.Context); err != nil {
+ if err := h.OnDisconnect(ctx, sess.Context); err != nil {
sm.logger.Error("disconnect handler error",
slog.String("session", sess.ID),
slog.String("error", err.Error()))
@@ -225,7 +224,7 @@ func (sm *SessionManager) cleanupExpired(timeout time.Duration, h handler.Handle
}
// DrainAll waits for all sessions to complete or forces closure after timeout.
-func (sm *SessionManager) DrainAll(timeout time.Duration, h handler.Handler) error {
+func (sm *SessionManager) DrainAll(ctx context.Context, timeout time.Duration, h handler.Handler) error {
sm.logger.Info("draining all UDP sessions")
sm.mu.RLock()
@@ -261,13 +260,13 @@ func (sm *SessionManager) DrainAll(timeout time.Duration, h handler.Handler) err
return nil
case <-time.After(timeout):
sm.logger.Warn("drain timeout exceeded, forcing session closure")
- sm.ForceCloseAll(h)
+ sm.ForceCloseAll(ctx, h)
return ErrShutdownTimeout
}
}
// ForceCloseAll forcefully closes all sessions.
-func (sm *SessionManager) ForceCloseAll(h handler.Handler) {
+func (sm *SessionManager) ForceCloseAll(ctx context.Context, h handler.Handler) {
sm.mu.Lock()
defer sm.mu.Unlock()
@@ -276,7 +275,7 @@ func (sm *SessionManager) ForceCloseAll(h handler.Handler) {
slog.String("session", sess.ID))
// Notify disconnect
- if err := h.OnDisconnect(context.Background(), sess.Context); err != nil {
+ if err := h.OnDisconnect(ctx, sess.Context); err != nil {
sm.logger.Error("disconnect handler error",
slog.String("session", sess.ID),
slog.String("error", err.Error()))