diff --git a/pkg/coap/coap.go b/pkg/coap/coap.go index d884e33e..87bdedde 100644 --- a/pkg/coap/coap.go +++ b/pkg/coap/coap.go @@ -9,13 +9,16 @@ import ( "io" "log/slog" "net" + "strings" "sync" + "sync/atomic" "time" "github.com/absmach/mgate" "github.com/absmach/mgate/pkg/session" mptls "github.com/absmach/mgate/pkg/tls" "github.com/pion/dtls/v3" + "github.com/plgd-dev/go-coap/v3/message" "github.com/plgd-dev/go-coap/v3/message/codes" "github.com/plgd-dev/go-coap/v3/message/pool" "github.com/plgd-dev/go-coap/v3/udp/coder" @@ -25,11 +28,13 @@ import ( const ( bufferSize uint64 = 1280 startObserve uint32 = 0 + authQuery = "auth" ) type Conn struct { clientAddr *net.UDPAddr serverConn *net.UDPConn + started atomic.Bool } type Proxy struct { @@ -58,38 +63,30 @@ func (p *Proxy) proxyUDP(ctx context.Context, l *net.UDPConn) { default: n, clientAddr, err := l.ReadFromUDP(buffer) if err != nil { - p.logger.Error("Failed to read from UDP", slog.Any("error", err)) + p.logger.Error("failed to read from UDP", slog.String("error", err.Error())) return } - p.mutex.Lock() - conn, ok := p.connMap[clientAddr.String()] - if !ok { - conn, err = p.newConn(clientAddr) - if err != nil { - p.mutex.Unlock() - p.logger.Error("Failed to create new connection", slog.Any("error", err)) - return - } - p.connMap[clientAddr.String()] = conn - go p.downUDP(ctx, l, conn) + conn, err := p.newConn(clientAddr) + if err != nil { + p.logger.Error("failed to create new connection", slog.String("error", err.Error())) + continue } - p.mutex.Unlock() //nolint:contextcheck // upUDP does not need context - p.upUDP(conn, buffer[:n]) + p.upUDP(conn, buffer[:n], l) } } } func (p *Proxy) Listen(ctx context.Context) error { - addr, err := net.ResolveUDPAddr("udp6", net.JoinHostPort(p.config.Host, p.config.Port)) + addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(p.config.Host, p.config.Port)) if err != nil { - p.logger.Error("Failed to resolve UDP address", slog.Any("error", err)) + p.logger.Error("failed to resolve UDP address", slog.String("error", err.Error())) return err } g, ctx := errgroup.WithContext(ctx) switch { case p.config.DTLSConfig != nil: - l, err := dtls.Listen("udp6", addr, p.config.DTLSConfig) + l, err := dtls.Listen("udp", addr, p.config.DTLSConfig) if err != nil { return err } @@ -134,30 +131,44 @@ func (p *Proxy) Listen(ctx context.Context) error { } func (p *Proxy) newConn(clientAddr *net.UDPAddr) (*Conn, error) { - conn := new(Conn) - conn.clientAddr = clientAddr - addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(p.config.TargetHost, p.config.TargetPort)) - if err != nil { - return nil, err - } - t, err := net.DialUDP("udp", nil, addr) - if err != nil { - return nil, err + p.mutex.Lock() + defer p.mutex.Unlock() + conn, ok := p.connMap[clientAddr.String()] + if !ok { + conn = &Conn{clientAddr: clientAddr} + addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(p.config.TargetHost, p.config.TargetPort)) + if err != nil { + return nil, err + } + t, err := net.DialUDP("udp", nil, addr) + if err != nil { + return nil, err + } + conn.serverConn = t + p.connMap[clientAddr.String()] = conn } - conn.serverConn = t return conn, nil } -func (p *Proxy) upUDP(conn *Conn, buffer []byte) { - err := p.handleCoAPMessage(context.Background(), buffer) - if err != nil { - p.logger.Error("Failed to handle CoAP message", slog.Any("err", err)) +func (p *Proxy) upUDP(conn *Conn, buffer []byte, l *net.UDPConn) { + if msg, err := p.handleCoAPMessage(context.Background(), buffer); err != nil { + data := p.encodeErrorResponse(context.Background(), msg, err) + if len(data) > 0 { + if _, werr := l.WriteToUDP(data, conn.clientAddr); werr != nil { + p.logger.Error("failed to send error response", slog.String("err", werr.Error())) + } + } return } - _, err = conn.serverConn.Write(buffer) - if err != nil { + + if _, err := conn.serverConn.Write(buffer); err != nil { return } + + // Start the downstream reader once the first upstream write succeeds. + if conn.started.CompareAndSwap(false, true) { + go p.downUDP(context.Background(), l, conn) + } } func (p *Proxy) downUDP(ctx context.Context, l *net.UDPConn, conn *Conn) { @@ -169,7 +180,7 @@ func (p *Proxy) downUDP(ctx context.Context, l *net.UDPConn, conn *Conn) { return default: } - err := conn.serverConn.SetReadDeadline(time.Now().Add(10 * time.Second)) + err := conn.serverConn.SetReadDeadline(time.Now().Add(30 * time.Second)) if err != nil { return } @@ -198,14 +209,14 @@ func (p *Proxy) proxyDTLS(ctx context.Context, l net.Listener) { case <-ctx.Done(): return default: - conn, err := l.Accept() - if err != nil { - p.logger.Warn("Accept error " + err.Error()) - continue - } - p.logger.Info("Accepted new client") - go p.handleDTLS(ctx, conn) } + conn, err := l.Accept() + if err != nil { + p.logger.Warn("Accept error " + err.Error()) + continue + } + p.logger.Info("Accepted new client") + go p.handleDTLS(ctx, conn) } } @@ -213,13 +224,13 @@ func (p *Proxy) handleDTLS(ctx context.Context, inbound net.Conn) { defer inbound.Close() outboundAddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(p.config.TargetHost, p.config.TargetPort)) if err != nil { - p.logger.Error("Cannot resolve remote broker address " + net.JoinHostPort(p.config.TargetHost, p.config.TargetPort) + " due to: " + err.Error()) + p.logger.Error("cannot resolve remote broker address " + net.JoinHostPort(p.config.TargetHost, p.config.TargetPort) + " due to: " + err.Error()) return } outbound, err := net.DialUDP("udp", nil, outboundAddr) if err != nil { - p.logger.Error("Cannot connect to remote broker " + outboundAddr.String() + " due to: " + err.Error()) + p.logger.Error("cannot connect to remote broker " + outboundAddr.String() + " due to: " + err.Error()) return } defer outbound.Close() @@ -237,7 +248,7 @@ func (p *Proxy) handleDTLS(ctx context.Context, inbound net.Conn) { }) if err := g.Wait(); err != nil { - p.logger.Error("DTLS proxy error", slog.Any("error", err)) + p.logger.Error("DTLS proxy error", slog.String("error", err.Error())) } } @@ -248,14 +259,17 @@ func (p *Proxy) dtlsUp(ctx context.Context, outbound *net.UDPConn, inbound net.C if err != nil { return } - err = p.handleCoAPMessage(ctx, buffer[:n]) - if err != nil { - p.logger.Error("Failed to handle CoAP message", slog.Any("err", err)) + if msg, err := p.handleCoAPMessage(ctx, buffer[:n]); err != nil { + data := p.encodeErrorResponse(ctx, msg, err) + if len(data) > 0 { + if _, werr := inbound.Write(data); werr != nil { + p.logger.Error("failed to send error response", slog.String("err", werr.Error())) + } + } return } - _, err = outbound.Write(buffer[:n]) - if err != nil { + if _, err = outbound.Write(buffer[:n]); err != nil { return } } @@ -273,63 +287,102 @@ func (p *Proxy) dtlsDown(inbound net.Conn, outbound *net.UDPConn) { return } - _, err = inbound.Write(buffer[:n]) - if err != nil { + if _, err = inbound.Write(buffer[:n]); err != nil { return } } } -func (p *Proxy) handleCoAPMessage(ctx context.Context, buffer []byte) error { +func (p *Proxy) handleCoAPMessage(ctx context.Context, buffer []byte) (*pool.Message, error) { var payload []byte var path string msg := pool.NewMessage(ctx) _, err := msg.UnmarshalWithDecoder(coder.DefaultCoder, buffer) if err != nil { - return err + return msg, err } - token := msg.Token() - if msg.Code() != codes.Empty { - path, err = msg.Path() - if err != nil { - return err - } + if msg.Code() != codes.POST && msg.Code() != codes.GET { + return msg, nil } - ctx = session.NewContext(ctx, &session.Session{Password: token}) + + authKey, err := parseKey(msg) + if err != nil { + return msg, err + } + + path, err = msg.Path() + if err != nil { + return msg, err + } + + ctx = session.NewContext(ctx, &session.Session{Password: []byte(authKey)}) if msg.Body() != nil { payload, err = io.ReadAll(msg.Body()) if err != nil { - return err + return msg, err } } switch msg.Code() { case codes.POST: if err := p.session.AuthConnect(ctx); err != nil { - return err + return msg, err } if err := p.session.AuthPublish(ctx, &path, &payload); err != nil { - return err + return msg, err } if err := p.session.Publish(ctx, &path, &payload); err != nil { - return err + return msg, err } case codes.GET: if err := p.session.AuthConnect(ctx); err != nil { - return err + return msg, err } if obs, err := msg.Options().Observe(); err == nil { if obs == startObserve { if err := p.session.AuthSubscribe(ctx, &[]string{path}); err != nil { - return err + return msg, err } if err := p.session.Subscribe(ctx, &[]string{path}); err != nil { - return err + return msg, err } } } } - return nil + return msg, nil +} + +func (p *Proxy) encodeErrorResponse(ctx context.Context, msg *pool.Message, err error) []byte { + resp := pool.NewMessage(ctx) + resp.SetToken(msg.Token()) + resp.SetMessageID(msg.MessageID()) + resp.SetType(msg.Type()) + for _, opt := range msg.Options() { + resp.AddOptionBytes(opt.ID, opt.Value) + } + cpe, ok := err.(COAPProxyError) + if !ok { + cpe = NewCOAPProxyError(codes.BadRequest, err) + } + resp.SetCode(cpe.StatusCode()) + data, err := resp.MarshalWithEncoder(coder.DefaultCoder) + if err != nil { + p.logger.Error("failed to marshal error response message", slog.String("err", err.Error())) + return nil + } + return data +} + +func parseKey(msg *pool.Message) (string, error) { + authKey, err := msg.Options().GetString(message.URIQuery) + if err != nil { + return "", NewCOAPProxyError(codes.BadRequest, err) + } + vars := strings.Split(authKey, "=") + if len(vars) != 2 || vars[0] != authQuery { + return "", nil + } + return vars[1], nil } diff --git a/pkg/coap/errors.go b/pkg/coap/errors.go new file mode 100644 index 00000000..3a8758f1 --- /dev/null +++ b/pkg/coap/errors.go @@ -0,0 +1,43 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package coap + +import ( + "encoding/json" + + "github.com/plgd-dev/go-coap/v3/message/codes" +) + +type coapProxyError struct { + statusCode codes.Code + err error +} + +type COAPProxyError interface { + error + MarshalJSON() ([]byte, error) + StatusCode() codes.Code +} + +var _ COAPProxyError = (*coapProxyError)(nil) + +func (cpe *coapProxyError) Error() string { + return cpe.err.Error() +} + +func (cpe *coapProxyError) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Error string `json:"message"` + }{ + Error: cpe.err.Error(), + }) +} + +func (cpe *coapProxyError) StatusCode() codes.Code { + return cpe.statusCode +} + +func NewCOAPProxyError(statusCode codes.Code, err error) COAPProxyError { + return &coapProxyError{statusCode: statusCode, err: err} +}