Skip to content

Commit 503150b

Browse files
authored
NOISSUE - Improve CoAP message handling (#108)
* refactor: improve coap message handling Signed-off-by: Felix Gateru <felix.gateru@gmail.com> * refactor: encapsulate mutex logic Signed-off-by: Felix Gateru <felix.gateru@gmail.com> --------- Signed-off-by: Felix Gateru <felix.gateru@gmail.com>
1 parent cf2fd82 commit 503150b

File tree

2 files changed

+164
-68
lines changed

2 files changed

+164
-68
lines changed

pkg/coap/coap.go

Lines changed: 121 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,16 @@ import (
99
"io"
1010
"log/slog"
1111
"net"
12+
"strings"
1213
"sync"
14+
"sync/atomic"
1315
"time"
1416

1517
"github.com/absmach/mgate"
1618
"github.com/absmach/mgate/pkg/session"
1719
mptls "github.com/absmach/mgate/pkg/tls"
1820
"github.com/pion/dtls/v3"
21+
"github.com/plgd-dev/go-coap/v3/message"
1922
"github.com/plgd-dev/go-coap/v3/message/codes"
2023
"github.com/plgd-dev/go-coap/v3/message/pool"
2124
"github.com/plgd-dev/go-coap/v3/udp/coder"
@@ -25,11 +28,13 @@ import (
2528
const (
2629
bufferSize uint64 = 1280
2730
startObserve uint32 = 0
31+
authQuery = "auth"
2832
)
2933

3034
type Conn struct {
3135
clientAddr *net.UDPAddr
3236
serverConn *net.UDPConn
37+
started atomic.Bool
3338
}
3439

3540
type Proxy struct {
@@ -58,38 +63,30 @@ func (p *Proxy) proxyUDP(ctx context.Context, l *net.UDPConn) {
5863
default:
5964
n, clientAddr, err := l.ReadFromUDP(buffer)
6065
if err != nil {
61-
p.logger.Error("Failed to read from UDP", slog.Any("error", err))
66+
p.logger.Error("failed to read from UDP", slog.String("error", err.Error()))
6267
return
6368
}
64-
p.mutex.Lock()
65-
conn, ok := p.connMap[clientAddr.String()]
66-
if !ok {
67-
conn, err = p.newConn(clientAddr)
68-
if err != nil {
69-
p.mutex.Unlock()
70-
p.logger.Error("Failed to create new connection", slog.Any("error", err))
71-
return
72-
}
73-
p.connMap[clientAddr.String()] = conn
74-
go p.downUDP(ctx, l, conn)
69+
conn, err := p.newConn(clientAddr)
70+
if err != nil {
71+
p.logger.Error("failed to create new connection", slog.String("error", err.Error()))
72+
continue
7573
}
76-
p.mutex.Unlock()
7774
//nolint:contextcheck // upUDP does not need context
78-
p.upUDP(conn, buffer[:n])
75+
p.upUDP(conn, buffer[:n], l)
7976
}
8077
}
8178
}
8279

8380
func (p *Proxy) Listen(ctx context.Context) error {
84-
addr, err := net.ResolveUDPAddr("udp6", net.JoinHostPort(p.config.Host, p.config.Port))
81+
addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(p.config.Host, p.config.Port))
8582
if err != nil {
86-
p.logger.Error("Failed to resolve UDP address", slog.Any("error", err))
83+
p.logger.Error("failed to resolve UDP address", slog.String("error", err.Error()))
8784
return err
8885
}
8986
g, ctx := errgroup.WithContext(ctx)
9087
switch {
9188
case p.config.DTLSConfig != nil:
92-
l, err := dtls.Listen("udp6", addr, p.config.DTLSConfig)
89+
l, err := dtls.Listen("udp", addr, p.config.DTLSConfig)
9390
if err != nil {
9491
return err
9592
}
@@ -134,30 +131,44 @@ func (p *Proxy) Listen(ctx context.Context) error {
134131
}
135132

136133
func (p *Proxy) newConn(clientAddr *net.UDPAddr) (*Conn, error) {
137-
conn := new(Conn)
138-
conn.clientAddr = clientAddr
139-
addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(p.config.TargetHost, p.config.TargetPort))
140-
if err != nil {
141-
return nil, err
142-
}
143-
t, err := net.DialUDP("udp", nil, addr)
144-
if err != nil {
145-
return nil, err
134+
p.mutex.Lock()
135+
defer p.mutex.Unlock()
136+
conn, ok := p.connMap[clientAddr.String()]
137+
if !ok {
138+
conn = &Conn{clientAddr: clientAddr}
139+
addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(p.config.TargetHost, p.config.TargetPort))
140+
if err != nil {
141+
return nil, err
142+
}
143+
t, err := net.DialUDP("udp", nil, addr)
144+
if err != nil {
145+
return nil, err
146+
}
147+
conn.serverConn = t
148+
p.connMap[clientAddr.String()] = conn
146149
}
147-
conn.serverConn = t
148150
return conn, nil
149151
}
150152

151-
func (p *Proxy) upUDP(conn *Conn, buffer []byte) {
152-
err := p.handleCoAPMessage(context.Background(), buffer)
153-
if err != nil {
154-
p.logger.Error("Failed to handle CoAP message", slog.Any("err", err))
153+
func (p *Proxy) upUDP(conn *Conn, buffer []byte, l *net.UDPConn) {
154+
if msg, err := p.handleCoAPMessage(context.Background(), buffer); err != nil {
155+
data := p.encodeErrorResponse(context.Background(), msg, err)
156+
if len(data) > 0 {
157+
if _, werr := l.WriteToUDP(data, conn.clientAddr); werr != nil {
158+
p.logger.Error("failed to send error response", slog.String("err", werr.Error()))
159+
}
160+
}
155161
return
156162
}
157-
_, err = conn.serverConn.Write(buffer)
158-
if err != nil {
163+
164+
if _, err := conn.serverConn.Write(buffer); err != nil {
159165
return
160166
}
167+
168+
// Start the downstream reader once the first upstream write succeeds.
169+
if conn.started.CompareAndSwap(false, true) {
170+
go p.downUDP(context.Background(), l, conn)
171+
}
161172
}
162173

163174
func (p *Proxy) downUDP(ctx context.Context, l *net.UDPConn, conn *Conn) {
@@ -169,7 +180,7 @@ func (p *Proxy) downUDP(ctx context.Context, l *net.UDPConn, conn *Conn) {
169180
return
170181
default:
171182
}
172-
err := conn.serverConn.SetReadDeadline(time.Now().Add(10 * time.Second))
183+
err := conn.serverConn.SetReadDeadline(time.Now().Add(30 * time.Second))
173184
if err != nil {
174185
return
175186
}
@@ -198,28 +209,28 @@ func (p *Proxy) proxyDTLS(ctx context.Context, l net.Listener) {
198209
case <-ctx.Done():
199210
return
200211
default:
201-
conn, err := l.Accept()
202-
if err != nil {
203-
p.logger.Warn("Accept error " + err.Error())
204-
continue
205-
}
206-
p.logger.Info("Accepted new client")
207-
go p.handleDTLS(ctx, conn)
208212
}
213+
conn, err := l.Accept()
214+
if err != nil {
215+
p.logger.Warn("Accept error " + err.Error())
216+
continue
217+
}
218+
p.logger.Info("Accepted new client")
219+
go p.handleDTLS(ctx, conn)
209220
}
210221
}
211222

212223
func (p *Proxy) handleDTLS(ctx context.Context, inbound net.Conn) {
213224
defer inbound.Close()
214225
outboundAddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(p.config.TargetHost, p.config.TargetPort))
215226
if err != nil {
216-
p.logger.Error("Cannot resolve remote broker address " + net.JoinHostPort(p.config.TargetHost, p.config.TargetPort) + " due to: " + err.Error())
227+
p.logger.Error("cannot resolve remote broker address " + net.JoinHostPort(p.config.TargetHost, p.config.TargetPort) + " due to: " + err.Error())
217228
return
218229
}
219230

220231
outbound, err := net.DialUDP("udp", nil, outboundAddr)
221232
if err != nil {
222-
p.logger.Error("Cannot connect to remote broker " + outboundAddr.String() + " due to: " + err.Error())
233+
p.logger.Error("cannot connect to remote broker " + outboundAddr.String() + " due to: " + err.Error())
223234
return
224235
}
225236
defer outbound.Close()
@@ -237,7 +248,7 @@ func (p *Proxy) handleDTLS(ctx context.Context, inbound net.Conn) {
237248
})
238249

239250
if err := g.Wait(); err != nil {
240-
p.logger.Error("DTLS proxy error", slog.Any("error", err))
251+
p.logger.Error("DTLS proxy error", slog.String("error", err.Error()))
241252
}
242253
}
243254

@@ -248,14 +259,17 @@ func (p *Proxy) dtlsUp(ctx context.Context, outbound *net.UDPConn, inbound net.C
248259
if err != nil {
249260
return
250261
}
251-
err = p.handleCoAPMessage(ctx, buffer[:n])
252-
if err != nil {
253-
p.logger.Error("Failed to handle CoAP message", slog.Any("err", err))
262+
if msg, err := p.handleCoAPMessage(ctx, buffer[:n]); err != nil {
263+
data := p.encodeErrorResponse(ctx, msg, err)
264+
if len(data) > 0 {
265+
if _, werr := inbound.Write(data); werr != nil {
266+
p.logger.Error("failed to send error response", slog.String("err", werr.Error()))
267+
}
268+
}
254269
return
255270
}
256271

257-
_, err = outbound.Write(buffer[:n])
258-
if err != nil {
272+
if _, err = outbound.Write(buffer[:n]); err != nil {
259273
return
260274
}
261275
}
@@ -273,63 +287,102 @@ func (p *Proxy) dtlsDown(inbound net.Conn, outbound *net.UDPConn) {
273287
return
274288
}
275289

276-
_, err = inbound.Write(buffer[:n])
277-
if err != nil {
290+
if _, err = inbound.Write(buffer[:n]); err != nil {
278291
return
279292
}
280293
}
281294
}
282295

283-
func (p *Proxy) handleCoAPMessage(ctx context.Context, buffer []byte) error {
296+
func (p *Proxy) handleCoAPMessage(ctx context.Context, buffer []byte) (*pool.Message, error) {
284297
var payload []byte
285298
var path string
286299
msg := pool.NewMessage(ctx)
287300
_, err := msg.UnmarshalWithDecoder(coder.DefaultCoder, buffer)
288301
if err != nil {
289-
return err
302+
return msg, err
290303
}
291-
token := msg.Token()
292-
if msg.Code() != codes.Empty {
293-
path, err = msg.Path()
294-
if err != nil {
295-
return err
296-
}
304+
if msg.Code() != codes.POST && msg.Code() != codes.GET {
305+
return msg, nil
297306
}
298-
ctx = session.NewContext(ctx, &session.Session{Password: token})
307+
308+
authKey, err := parseKey(msg)
309+
if err != nil {
310+
return msg, err
311+
}
312+
313+
path, err = msg.Path()
314+
if err != nil {
315+
return msg, err
316+
}
317+
318+
ctx = session.NewContext(ctx, &session.Session{Password: []byte(authKey)})
299319

300320
if msg.Body() != nil {
301321
payload, err = io.ReadAll(msg.Body())
302322
if err != nil {
303-
return err
323+
return msg, err
304324
}
305325
}
306326

307327
switch msg.Code() {
308328
case codes.POST:
309329
if err := p.session.AuthConnect(ctx); err != nil {
310-
return err
330+
return msg, err
311331
}
312332
if err := p.session.AuthPublish(ctx, &path, &payload); err != nil {
313-
return err
333+
return msg, err
314334
}
315335
if err := p.session.Publish(ctx, &path, &payload); err != nil {
316-
return err
336+
return msg, err
317337
}
318338
case codes.GET:
319339
if err := p.session.AuthConnect(ctx); err != nil {
320-
return err
340+
return msg, err
321341
}
322342
if obs, err := msg.Options().Observe(); err == nil {
323343
if obs == startObserve {
324344
if err := p.session.AuthSubscribe(ctx, &[]string{path}); err != nil {
325-
return err
345+
return msg, err
326346
}
327347
if err := p.session.Subscribe(ctx, &[]string{path}); err != nil {
328-
return err
348+
return msg, err
329349
}
330350
}
331351
}
332352
}
333353

334-
return nil
354+
return msg, nil
355+
}
356+
357+
func (p *Proxy) encodeErrorResponse(ctx context.Context, msg *pool.Message, err error) []byte {
358+
resp := pool.NewMessage(ctx)
359+
resp.SetToken(msg.Token())
360+
resp.SetMessageID(msg.MessageID())
361+
resp.SetType(msg.Type())
362+
for _, opt := range msg.Options() {
363+
resp.AddOptionBytes(opt.ID, opt.Value)
364+
}
365+
cpe, ok := err.(COAPProxyError)
366+
if !ok {
367+
cpe = NewCOAPProxyError(codes.BadRequest, err)
368+
}
369+
resp.SetCode(cpe.StatusCode())
370+
data, err := resp.MarshalWithEncoder(coder.DefaultCoder)
371+
if err != nil {
372+
p.logger.Error("failed to marshal error response message", slog.String("err", err.Error()))
373+
return nil
374+
}
375+
return data
376+
}
377+
378+
func parseKey(msg *pool.Message) (string, error) {
379+
authKey, err := msg.Options().GetString(message.URIQuery)
380+
if err != nil {
381+
return "", NewCOAPProxyError(codes.BadRequest, err)
382+
}
383+
vars := strings.Split(authKey, "=")
384+
if len(vars) != 2 || vars[0] != authQuery {
385+
return "", nil
386+
}
387+
return vars[1], nil
335388
}

pkg/coap/errors.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// Copyright (c) Abstract Machines
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package coap
5+
6+
import (
7+
"encoding/json"
8+
9+
"github.com/plgd-dev/go-coap/v3/message/codes"
10+
)
11+
12+
type coapProxyError struct {
13+
statusCode codes.Code
14+
err error
15+
}
16+
17+
type COAPProxyError interface {
18+
error
19+
MarshalJSON() ([]byte, error)
20+
StatusCode() codes.Code
21+
}
22+
23+
var _ COAPProxyError = (*coapProxyError)(nil)
24+
25+
func (cpe *coapProxyError) Error() string {
26+
return cpe.err.Error()
27+
}
28+
29+
func (cpe *coapProxyError) MarshalJSON() ([]byte, error) {
30+
return json.Marshal(struct {
31+
Error string `json:"message"`
32+
}{
33+
Error: cpe.err.Error(),
34+
})
35+
}
36+
37+
func (cpe *coapProxyError) StatusCode() codes.Code {
38+
return cpe.statusCode
39+
}
40+
41+
func NewCOAPProxyError(statusCode codes.Code, err error) COAPProxyError {
42+
return &coapProxyError{statusCode: statusCode, err: err}
43+
}

0 commit comments

Comments
 (0)