Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 121 additions & 68 deletions pkg/coap/coap.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,16 @@ import (
"io"
"log/slog"
"net"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/absmach/mgate"
"github.com/absmach/mgate/pkg/session"
mptls "github.com/absmach/mgate/pkg/tls"
"github.com/pion/dtls/v3"
"github.com/plgd-dev/go-coap/v3/message"
"github.com/plgd-dev/go-coap/v3/message/codes"
"github.com/plgd-dev/go-coap/v3/message/pool"
"github.com/plgd-dev/go-coap/v3/udp/coder"
Expand All @@ -25,11 +28,13 @@ import (
const (
bufferSize uint64 = 1280
startObserve uint32 = 0
authQuery = "auth"
)

type Conn struct {
clientAddr *net.UDPAddr
serverConn *net.UDPConn
started atomic.Bool
}

type Proxy struct {
Expand Down Expand Up @@ -58,38 +63,30 @@ func (p *Proxy) proxyUDP(ctx context.Context, l *net.UDPConn) {
default:
n, clientAddr, err := l.ReadFromUDP(buffer)
if err != nil {
p.logger.Error("Failed to read from UDP", slog.Any("error", err))
p.logger.Error("failed to read from UDP", slog.String("error", err.Error()))
return
}
p.mutex.Lock()
conn, ok := p.connMap[clientAddr.String()]
if !ok {
conn, err = p.newConn(clientAddr)
if err != nil {
p.mutex.Unlock()
p.logger.Error("Failed to create new connection", slog.Any("error", err))
return
}
p.connMap[clientAddr.String()] = conn
go p.downUDP(ctx, l, conn)
conn, err := p.newConn(clientAddr)
if err != nil {
p.logger.Error("failed to create new connection", slog.String("error", err.Error()))
continue
}
p.mutex.Unlock()
//nolint:contextcheck // upUDP does not need context
p.upUDP(conn, buffer[:n])
p.upUDP(conn, buffer[:n], l)
}
}
}

func (p *Proxy) Listen(ctx context.Context) error {
addr, err := net.ResolveUDPAddr("udp6", net.JoinHostPort(p.config.Host, p.config.Port))
addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(p.config.Host, p.config.Port))
if err != nil {
p.logger.Error("Failed to resolve UDP address", slog.Any("error", err))
p.logger.Error("failed to resolve UDP address", slog.String("error", err.Error()))
return err
}
g, ctx := errgroup.WithContext(ctx)
switch {
case p.config.DTLSConfig != nil:
l, err := dtls.Listen("udp6", addr, p.config.DTLSConfig)
l, err := dtls.Listen("udp", addr, p.config.DTLSConfig)
if err != nil {
return err
}
Expand Down Expand Up @@ -134,30 +131,44 @@ func (p *Proxy) Listen(ctx context.Context) error {
}

func (p *Proxy) newConn(clientAddr *net.UDPAddr) (*Conn, error) {
conn := new(Conn)
conn.clientAddr = clientAddr
addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(p.config.TargetHost, p.config.TargetPort))
if err != nil {
return nil, err
}
t, err := net.DialUDP("udp", nil, addr)
if err != nil {
return nil, err
p.mutex.Lock()
defer p.mutex.Unlock()
conn, ok := p.connMap[clientAddr.String()]
if !ok {
conn = &Conn{clientAddr: clientAddr}
addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(p.config.TargetHost, p.config.TargetPort))
if err != nil {
return nil, err
}
t, err := net.DialUDP("udp", nil, addr)
if err != nil {
return nil, err
}
conn.serverConn = t
p.connMap[clientAddr.String()] = conn
}
conn.serverConn = t
return conn, nil
}

func (p *Proxy) upUDP(conn *Conn, buffer []byte) {
err := p.handleCoAPMessage(context.Background(), buffer)
if err != nil {
p.logger.Error("Failed to handle CoAP message", slog.Any("err", err))
func (p *Proxy) upUDP(conn *Conn, buffer []byte, l *net.UDPConn) {
if msg, err := p.handleCoAPMessage(context.Background(), buffer); err != nil {
data := p.encodeErrorResponse(context.Background(), msg, err)
if len(data) > 0 {
if _, werr := l.WriteToUDP(data, conn.clientAddr); werr != nil {
p.logger.Error("failed to send error response", slog.String("err", werr.Error()))
}
}
return
}
_, err = conn.serverConn.Write(buffer)
if err != nil {

if _, err := conn.serverConn.Write(buffer); err != nil {
return
}

// Start the downstream reader once the first upstream write succeeds.
if conn.started.CompareAndSwap(false, true) {
go p.downUDP(context.Background(), l, conn)
}
}

func (p *Proxy) downUDP(ctx context.Context, l *net.UDPConn, conn *Conn) {
Expand All @@ -169,7 +180,7 @@ func (p *Proxy) downUDP(ctx context.Context, l *net.UDPConn, conn *Conn) {
return
default:
}
err := conn.serverConn.SetReadDeadline(time.Now().Add(10 * time.Second))
err := conn.serverConn.SetReadDeadline(time.Now().Add(30 * time.Second))
if err != nil {
return
}
Expand Down Expand Up @@ -198,28 +209,28 @@ func (p *Proxy) proxyDTLS(ctx context.Context, l net.Listener) {
case <-ctx.Done():
return
default:
conn, err := l.Accept()
if err != nil {
p.logger.Warn("Accept error " + err.Error())
continue
}
p.logger.Info("Accepted new client")
go p.handleDTLS(ctx, conn)
}
conn, err := l.Accept()
if err != nil {
p.logger.Warn("Accept error " + err.Error())
continue
}
p.logger.Info("Accepted new client")
go p.handleDTLS(ctx, conn)
}
}

func (p *Proxy) handleDTLS(ctx context.Context, inbound net.Conn) {
defer inbound.Close()
outboundAddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(p.config.TargetHost, p.config.TargetPort))
if err != nil {
p.logger.Error("Cannot resolve remote broker address " + net.JoinHostPort(p.config.TargetHost, p.config.TargetPort) + " due to: " + err.Error())
p.logger.Error("cannot resolve remote broker address " + net.JoinHostPort(p.config.TargetHost, p.config.TargetPort) + " due to: " + err.Error())
return
}

outbound, err := net.DialUDP("udp", nil, outboundAddr)
if err != nil {
p.logger.Error("Cannot connect to remote broker " + outboundAddr.String() + " due to: " + err.Error())
p.logger.Error("cannot connect to remote broker " + outboundAddr.String() + " due to: " + err.Error())
return
}
defer outbound.Close()
Expand All @@ -237,7 +248,7 @@ func (p *Proxy) handleDTLS(ctx context.Context, inbound net.Conn) {
})

if err := g.Wait(); err != nil {
p.logger.Error("DTLS proxy error", slog.Any("error", err))
p.logger.Error("DTLS proxy error", slog.String("error", err.Error()))
}
}

Expand All @@ -248,14 +259,17 @@ func (p *Proxy) dtlsUp(ctx context.Context, outbound *net.UDPConn, inbound net.C
if err != nil {
return
}
err = p.handleCoAPMessage(ctx, buffer[:n])
if err != nil {
p.logger.Error("Failed to handle CoAP message", slog.Any("err", err))
if msg, err := p.handleCoAPMessage(ctx, buffer[:n]); err != nil {
data := p.encodeErrorResponse(ctx, msg, err)
if len(data) > 0 {
if _, werr := inbound.Write(data); werr != nil {
p.logger.Error("failed to send error response", slog.String("err", werr.Error()))
}
}
return
}

_, err = outbound.Write(buffer[:n])
if err != nil {
if _, err = outbound.Write(buffer[:n]); err != nil {
return
}
}
Expand All @@ -273,63 +287,102 @@ func (p *Proxy) dtlsDown(inbound net.Conn, outbound *net.UDPConn) {
return
}

_, err = inbound.Write(buffer[:n])
if err != nil {
if _, err = inbound.Write(buffer[:n]); err != nil {
return
}
}
}

func (p *Proxy) handleCoAPMessage(ctx context.Context, buffer []byte) error {
func (p *Proxy) handleCoAPMessage(ctx context.Context, buffer []byte) (*pool.Message, error) {
var payload []byte
var path string
msg := pool.NewMessage(ctx)
_, err := msg.UnmarshalWithDecoder(coder.DefaultCoder, buffer)
if err != nil {
return err
return msg, err
}
token := msg.Token()
if msg.Code() != codes.Empty {
path, err = msg.Path()
if err != nil {
return err
}
if msg.Code() != codes.POST && msg.Code() != codes.GET {
return msg, nil
}
ctx = session.NewContext(ctx, &session.Session{Password: token})

authKey, err := parseKey(msg)
if err != nil {
return msg, err
}

path, err = msg.Path()
if err != nil {
return msg, err
}

ctx = session.NewContext(ctx, &session.Session{Password: []byte(authKey)})

if msg.Body() != nil {
payload, err = io.ReadAll(msg.Body())
if err != nil {
return err
return msg, err
}
}

switch msg.Code() {
case codes.POST:
if err := p.session.AuthConnect(ctx); err != nil {
return err
return msg, err
}
if err := p.session.AuthPublish(ctx, &path, &payload); err != nil {
return err
return msg, err
}
if err := p.session.Publish(ctx, &path, &payload); err != nil {
return err
return msg, err
}
case codes.GET:
if err := p.session.AuthConnect(ctx); err != nil {
return err
return msg, err
}
if obs, err := msg.Options().Observe(); err == nil {
if obs == startObserve {
if err := p.session.AuthSubscribe(ctx, &[]string{path}); err != nil {
return err
return msg, err
}
if err := p.session.Subscribe(ctx, &[]string{path}); err != nil {
return err
return msg, err
}
}
}
}

return nil
return msg, nil
}

func (p *Proxy) encodeErrorResponse(ctx context.Context, msg *pool.Message, err error) []byte {
resp := pool.NewMessage(ctx)
resp.SetToken(msg.Token())
resp.SetMessageID(msg.MessageID())
resp.SetType(msg.Type())
for _, opt := range msg.Options() {
resp.AddOptionBytes(opt.ID, opt.Value)
}
cpe, ok := err.(COAPProxyError)
if !ok {
cpe = NewCOAPProxyError(codes.BadRequest, err)
}
resp.SetCode(cpe.StatusCode())
data, err := resp.MarshalWithEncoder(coder.DefaultCoder)
if err != nil {
p.logger.Error("failed to marshal error response message", slog.String("err", err.Error()))
return nil
}
return data
}

func parseKey(msg *pool.Message) (string, error) {
authKey, err := msg.Options().GetString(message.URIQuery)
if err != nil {
return "", NewCOAPProxyError(codes.BadRequest, err)
}
vars := strings.Split(authKey, "=")
if len(vars) != 2 || vars[0] != authQuery {
return "", nil
}
return vars[1], nil
}
43 changes: 43 additions & 0 deletions pkg/coap/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0

package coap

import (
"encoding/json"

"github.com/plgd-dev/go-coap/v3/message/codes"
)

type coapProxyError struct {
statusCode codes.Code
err error
}

type COAPProxyError interface {
error
MarshalJSON() ([]byte, error)
StatusCode() codes.Code
}

var _ COAPProxyError = (*coapProxyError)(nil)

func (cpe *coapProxyError) Error() string {
return cpe.err.Error()
}

func (cpe *coapProxyError) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Error string `json:"message"`
}{
Error: cpe.err.Error(),
})
}

func (cpe *coapProxyError) StatusCode() codes.Code {
return cpe.statusCode
}

func NewCOAPProxyError(statusCode codes.Code, err error) COAPProxyError {
return &coapProxyError{statusCode: statusCode, err: err}
}