@@ -9,13 +9,16 @@ import (
99 "io"
1010 "log/slog"
1111 "net"
12+ "strings"
1213 "sync"
14+ "sync/atomic"
1315 "time"
1416
1517 "github.com/absmach/mgate"
1618 "github.com/absmach/mgate/pkg/session"
1719 mptls "github.com/absmach/mgate/pkg/tls"
1820 "github.com/pion/dtls/v3"
21+ "github.com/plgd-dev/go-coap/v3/message"
1922 "github.com/plgd-dev/go-coap/v3/message/codes"
2023 "github.com/plgd-dev/go-coap/v3/message/pool"
2124 "github.com/plgd-dev/go-coap/v3/udp/coder"
@@ -25,11 +28,13 @@ import (
2528const (
2629 bufferSize uint64 = 1280
2730 startObserve uint32 = 0
31+ authQuery = "auth"
2832)
2933
3034type Conn struct {
3135 clientAddr * net.UDPAddr
3236 serverConn * net.UDPConn
37+ started atomic.Bool
3338}
3439
3540type Proxy struct {
@@ -58,38 +63,30 @@ func (p *Proxy) proxyUDP(ctx context.Context, l *net.UDPConn) {
5863 default :
5964 n , clientAddr , err := l .ReadFromUDP (buffer )
6065 if err != nil {
61- p .logger .Error ("Failed to read from UDP" , slog .Any ("error" , err ))
66+ p .logger .Error ("failed to read from UDP" , slog .String ("error" , err . Error () ))
6267 return
6368 }
64- p .mutex .Lock ()
65- conn , ok := p .connMap [clientAddr .String ()]
66- if ! ok {
67- conn , err = p .newConn (clientAddr )
68- if err != nil {
69- p .mutex .Unlock ()
70- p .logger .Error ("Failed to create new connection" , slog .Any ("error" , err ))
71- return
72- }
73- p .connMap [clientAddr .String ()] = conn
74- go p .downUDP (ctx , l , conn )
69+ conn , err := p .newConn (clientAddr )
70+ if err != nil {
71+ p .logger .Error ("failed to create new connection" , slog .String ("error" , err .Error ()))
72+ continue
7573 }
76- p .mutex .Unlock ()
7774 //nolint:contextcheck // upUDP does not need context
78- p .upUDP (conn , buffer [:n ])
75+ p .upUDP (conn , buffer [:n ], l )
7976 }
8077 }
8178}
8279
8380func (p * Proxy ) Listen (ctx context.Context ) error {
84- addr , err := net .ResolveUDPAddr ("udp6 " , net .JoinHostPort (p .config .Host , p .config .Port ))
81+ addr , err := net .ResolveUDPAddr ("udp " , net .JoinHostPort (p .config .Host , p .config .Port ))
8582 if err != nil {
86- p .logger .Error ("Failed to resolve UDP address" , slog .Any ("error" , err ))
83+ p .logger .Error ("failed to resolve UDP address" , slog .String ("error" , err . Error () ))
8784 return err
8885 }
8986 g , ctx := errgroup .WithContext (ctx )
9087 switch {
9188 case p .config .DTLSConfig != nil :
92- l , err := dtls .Listen ("udp6 " , addr , p .config .DTLSConfig )
89+ l , err := dtls .Listen ("udp " , addr , p .config .DTLSConfig )
9390 if err != nil {
9491 return err
9592 }
@@ -134,30 +131,44 @@ func (p *Proxy) Listen(ctx context.Context) error {
134131}
135132
136133func (p * Proxy ) newConn (clientAddr * net.UDPAddr ) (* Conn , error ) {
137- conn := new (Conn )
138- conn .clientAddr = clientAddr
139- addr , err := net .ResolveUDPAddr ("udp" , net .JoinHostPort (p .config .TargetHost , p .config .TargetPort ))
140- if err != nil {
141- return nil , err
142- }
143- t , err := net .DialUDP ("udp" , nil , addr )
144- if err != nil {
145- return nil , err
134+ p .mutex .Lock ()
135+ defer p .mutex .Unlock ()
136+ conn , ok := p .connMap [clientAddr .String ()]
137+ if ! ok {
138+ conn = & Conn {clientAddr : clientAddr }
139+ addr , err := net .ResolveUDPAddr ("udp" , net .JoinHostPort (p .config .TargetHost , p .config .TargetPort ))
140+ if err != nil {
141+ return nil , err
142+ }
143+ t , err := net .DialUDP ("udp" , nil , addr )
144+ if err != nil {
145+ return nil , err
146+ }
147+ conn .serverConn = t
148+ p .connMap [clientAddr .String ()] = conn
146149 }
147- conn .serverConn = t
148150 return conn , nil
149151}
150152
151- func (p * Proxy ) upUDP (conn * Conn , buffer []byte ) {
152- err := p .handleCoAPMessage (context .Background (), buffer )
153- if err != nil {
154- p .logger .Error ("Failed to handle CoAP message" , slog .Any ("err" , err ))
153+ func (p * Proxy ) upUDP (conn * Conn , buffer []byte , l * net.UDPConn ) {
154+ if msg , err := p .handleCoAPMessage (context .Background (), buffer ); err != nil {
155+ data := p .encodeErrorResponse (context .Background (), msg , err )
156+ if len (data ) > 0 {
157+ if _ , werr := l .WriteToUDP (data , conn .clientAddr ); werr != nil {
158+ p .logger .Error ("failed to send error response" , slog .String ("err" , werr .Error ()))
159+ }
160+ }
155161 return
156162 }
157- _ , err = conn . serverConn . Write ( buffer )
158- if err != nil {
163+
164+ if _ , err := conn . serverConn . Write ( buffer ); err != nil {
159165 return
160166 }
167+
168+ // Start the downstream reader once the first upstream write succeeds.
169+ if conn .started .CompareAndSwap (false , true ) {
170+ go p .downUDP (context .Background (), l , conn )
171+ }
161172}
162173
163174func (p * Proxy ) downUDP (ctx context.Context , l * net.UDPConn , conn * Conn ) {
@@ -169,7 +180,7 @@ func (p *Proxy) downUDP(ctx context.Context, l *net.UDPConn, conn *Conn) {
169180 return
170181 default :
171182 }
172- err := conn .serverConn .SetReadDeadline (time .Now ().Add (10 * time .Second ))
183+ err := conn .serverConn .SetReadDeadline (time .Now ().Add (30 * time .Second ))
173184 if err != nil {
174185 return
175186 }
@@ -198,28 +209,28 @@ func (p *Proxy) proxyDTLS(ctx context.Context, l net.Listener) {
198209 case <- ctx .Done ():
199210 return
200211 default :
201- conn , err := l .Accept ()
202- if err != nil {
203- p .logger .Warn ("Accept error " + err .Error ())
204- continue
205- }
206- p .logger .Info ("Accepted new client" )
207- go p .handleDTLS (ctx , conn )
208212 }
213+ conn , err := l .Accept ()
214+ if err != nil {
215+ p .logger .Warn ("Accept error " + err .Error ())
216+ continue
217+ }
218+ p .logger .Info ("Accepted new client" )
219+ go p .handleDTLS (ctx , conn )
209220 }
210221}
211222
212223func (p * Proxy ) handleDTLS (ctx context.Context , inbound net.Conn ) {
213224 defer inbound .Close ()
214225 outboundAddr , err := net .ResolveUDPAddr ("udp" , net .JoinHostPort (p .config .TargetHost , p .config .TargetPort ))
215226 if err != nil {
216- p .logger .Error ("Cannot resolve remote broker address " + net .JoinHostPort (p .config .TargetHost , p .config .TargetPort ) + " due to: " + err .Error ())
227+ p .logger .Error ("cannot resolve remote broker address " + net .JoinHostPort (p .config .TargetHost , p .config .TargetPort ) + " due to: " + err .Error ())
217228 return
218229 }
219230
220231 outbound , err := net .DialUDP ("udp" , nil , outboundAddr )
221232 if err != nil {
222- p .logger .Error ("Cannot connect to remote broker " + outboundAddr .String () + " due to: " + err .Error ())
233+ p .logger .Error ("cannot connect to remote broker " + outboundAddr .String () + " due to: " + err .Error ())
223234 return
224235 }
225236 defer outbound .Close ()
@@ -237,7 +248,7 @@ func (p *Proxy) handleDTLS(ctx context.Context, inbound net.Conn) {
237248 })
238249
239250 if err := g .Wait (); err != nil {
240- p .logger .Error ("DTLS proxy error" , slog .Any ("error" , err ))
251+ p .logger .Error ("DTLS proxy error" , slog .String ("error" , err . Error () ))
241252 }
242253}
243254
@@ -248,14 +259,17 @@ func (p *Proxy) dtlsUp(ctx context.Context, outbound *net.UDPConn, inbound net.C
248259 if err != nil {
249260 return
250261 }
251- err = p .handleCoAPMessage (ctx , buffer [:n ])
252- if err != nil {
253- p .logger .Error ("Failed to handle CoAP message" , slog .Any ("err" , err ))
262+ if msg , err := p .handleCoAPMessage (ctx , buffer [:n ]); err != nil {
263+ data := p .encodeErrorResponse (ctx , msg , err )
264+ if len (data ) > 0 {
265+ if _ , werr := inbound .Write (data ); werr != nil {
266+ p .logger .Error ("failed to send error response" , slog .String ("err" , werr .Error ()))
267+ }
268+ }
254269 return
255270 }
256271
257- _ , err = outbound .Write (buffer [:n ])
258- if err != nil {
272+ if _ , err = outbound .Write (buffer [:n ]); err != nil {
259273 return
260274 }
261275 }
@@ -273,63 +287,102 @@ func (p *Proxy) dtlsDown(inbound net.Conn, outbound *net.UDPConn) {
273287 return
274288 }
275289
276- _ , err = inbound .Write (buffer [:n ])
277- if err != nil {
290+ if _ , err = inbound .Write (buffer [:n ]); err != nil {
278291 return
279292 }
280293 }
281294}
282295
283- func (p * Proxy ) handleCoAPMessage (ctx context.Context , buffer []byte ) error {
296+ func (p * Proxy ) handleCoAPMessage (ctx context.Context , buffer []byte ) ( * pool. Message , error ) {
284297 var payload []byte
285298 var path string
286299 msg := pool .NewMessage (ctx )
287300 _ , err := msg .UnmarshalWithDecoder (coder .DefaultCoder , buffer )
288301 if err != nil {
289- return err
302+ return msg , err
290303 }
291- token := msg .Token ()
292- if msg .Code () != codes .Empty {
293- path , err = msg .Path ()
294- if err != nil {
295- return err
296- }
304+ if msg .Code () != codes .POST && msg .Code () != codes .GET {
305+ return msg , nil
297306 }
298- ctx = session .NewContext (ctx , & session.Session {Password : token })
307+
308+ authKey , err := parseKey (msg )
309+ if err != nil {
310+ return msg , err
311+ }
312+
313+ path , err = msg .Path ()
314+ if err != nil {
315+ return msg , err
316+ }
317+
318+ ctx = session .NewContext (ctx , & session.Session {Password : []byte (authKey )})
299319
300320 if msg .Body () != nil {
301321 payload , err = io .ReadAll (msg .Body ())
302322 if err != nil {
303- return err
323+ return msg , err
304324 }
305325 }
306326
307327 switch msg .Code () {
308328 case codes .POST :
309329 if err := p .session .AuthConnect (ctx ); err != nil {
310- return err
330+ return msg , err
311331 }
312332 if err := p .session .AuthPublish (ctx , & path , & payload ); err != nil {
313- return err
333+ return msg , err
314334 }
315335 if err := p .session .Publish (ctx , & path , & payload ); err != nil {
316- return err
336+ return msg , err
317337 }
318338 case codes .GET :
319339 if err := p .session .AuthConnect (ctx ); err != nil {
320- return err
340+ return msg , err
321341 }
322342 if obs , err := msg .Options ().Observe (); err == nil {
323343 if obs == startObserve {
324344 if err := p .session .AuthSubscribe (ctx , & []string {path }); err != nil {
325- return err
345+ return msg , err
326346 }
327347 if err := p .session .Subscribe (ctx , & []string {path }); err != nil {
328- return err
348+ return msg , err
329349 }
330350 }
331351 }
332352 }
333353
334- return nil
354+ return msg , nil
355+ }
356+
357+ func (p * Proxy ) encodeErrorResponse (ctx context.Context , msg * pool.Message , err error ) []byte {
358+ resp := pool .NewMessage (ctx )
359+ resp .SetToken (msg .Token ())
360+ resp .SetMessageID (msg .MessageID ())
361+ resp .SetType (msg .Type ())
362+ for _ , opt := range msg .Options () {
363+ resp .AddOptionBytes (opt .ID , opt .Value )
364+ }
365+ cpe , ok := err .(COAPProxyError )
366+ if ! ok {
367+ cpe = NewCOAPProxyError (codes .BadRequest , err )
368+ }
369+ resp .SetCode (cpe .StatusCode ())
370+ data , err := resp .MarshalWithEncoder (coder .DefaultCoder )
371+ if err != nil {
372+ p .logger .Error ("failed to marshal error response message" , slog .String ("err" , err .Error ()))
373+ return nil
374+ }
375+ return data
376+ }
377+
378+ func parseKey (msg * pool.Message ) (string , error ) {
379+ authKey , err := msg .Options ().GetString (message .URIQuery )
380+ if err != nil {
381+ return "" , NewCOAPProxyError (codes .BadRequest , err )
382+ }
383+ vars := strings .Split (authKey , "=" )
384+ if len (vars ) != 2 || vars [0 ] != authQuery {
385+ return "" , nil
386+ }
387+ return vars [1 ], nil
335388}
0 commit comments