Skip to content

Commit c690155

Browse files
committed
TUN-8822: Prevent concurrent usage of ICMPDecoder
## Summary Some description... Closes TUN-8822
1 parent 9bc6cbd commit c690155

File tree

2 files changed

+107
-6
lines changed

2 files changed

+107
-6
lines changed

quic/v3/muxer.go

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,11 @@ type datagramConn struct {
6565
icmpRouter ingress.ICMPRouter
6666
metrics Metrics
6767
logger *zerolog.Logger
68-
69-
datagrams chan []byte
70-
readErrors chan error
68+
datagrams chan []byte
69+
readErrors chan error
7170

7271
icmpEncoderPool sync.Pool // a pool of *packet.Encoder
73-
icmpDecoder *packet.ICMPDecoder
72+
icmpDecoderPool sync.Pool
7473
}
7574

7675
func NewDatagramConn(conn QuicConnection, sessionManager SessionManager, icmpRouter ingress.ICMPRouter, index uint8, metrics Metrics, logger *zerolog.Logger) DatagramConn {
@@ -89,7 +88,11 @@ func NewDatagramConn(conn QuicConnection, sessionManager SessionManager, icmpRou
8988
return packet.NewEncoder()
9089
},
9190
},
92-
icmpDecoder: packet.NewICMPDecoder(),
91+
icmpDecoderPool: sync.Pool{
92+
New: func() any {
93+
return packet.NewICMPDecoder()
94+
},
95+
},
9396
}
9497
}
9598

@@ -367,7 +370,16 @@ func (c *datagramConn) handleICMPPacket(datagram *ICMPDatagram) {
367370

368371
// Decode the provided ICMPDatagram as an ICMP packet
369372
rawPacket := packet.RawPacket{Data: datagram.Payload}
370-
icmp, err := c.icmpDecoder.Decode(rawPacket)
373+
cachedDecoder := c.icmpDecoderPool.Get()
374+
defer c.icmpDecoderPool.Put(cachedDecoder)
375+
decoder, ok := cachedDecoder.(*packet.ICMPDecoder)
376+
if !ok {
377+
c.logger.Error().Msg("Could not get ICMPDecoder from the pool. Dropping packet")
378+
return
379+
}
380+
381+
icmp, err := decoder.Decode(rawPacket)
382+
371383
if err != nil {
372384
c.logger.Err(err).Msgf("unable to marshal icmp packet")
373385
return

quic/v3/muxer_test.go

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,17 @@ import (
44
"bytes"
55
"context"
66
"errors"
7+
"fmt"
78
"net"
89
"net/netip"
910
"slices"
11+
"sort"
1012
"sync"
1113
"testing"
1214
"time"
1315

16+
"github.com/stretchr/testify/assert"
17+
1418
"github.com/google/gopacket/layers"
1519
"github.com/rs/zerolog"
1620
"golang.org/x/net/icmp"
@@ -304,6 +308,91 @@ func TestDatagramConnServe(t *testing.T) {
304308
assertContextClosed(t, ctx, done, cancel)
305309
}
306310

311+
// This test exists because decoding multiple packets in parallel with the same decoder
312+
// instances causes inteference resulting in multiple different raw packets being decoded
313+
// as the same decoded packet.
314+
func TestDatagramConnServeDecodeMultipleICMPInParallel(t *testing.T) {
315+
log := zerolog.Nop()
316+
quic := newMockQuicConn()
317+
session := newMockSession()
318+
sessionManager := mockSessionManager{session: &session}
319+
router := newMockICMPRouter()
320+
conn := v3.NewDatagramConn(quic, &sessionManager, router, 0, &noopMetrics{}, &log)
321+
322+
// Setup the muxer
323+
ctx, cancel := context.WithCancelCause(context.Background())
324+
defer cancel(errors.New("other error"))
325+
done := make(chan error, 1)
326+
go func() {
327+
done <- conn.Serve(ctx)
328+
}()
329+
330+
packetCount := 100
331+
packets := make([]*packet.ICMP, 100)
332+
ipTemplate := "10.0.0.%d"
333+
for i := 1; i <= packetCount; i++ {
334+
packets[i-1] = &packet.ICMP{
335+
IP: &packet.IP{
336+
Src: netip.MustParseAddr("192.168.1.1"),
337+
Dst: netip.MustParseAddr(fmt.Sprintf(ipTemplate, i)),
338+
Protocol: layers.IPProtocolICMPv4,
339+
TTL: 20,
340+
},
341+
Message: &icmp.Message{
342+
Type: ipv4.ICMPTypeEcho,
343+
Code: 0,
344+
Body: &icmp.Echo{
345+
ID: 25821,
346+
Seq: 58129,
347+
Data: []byte("test"),
348+
},
349+
},
350+
}
351+
}
352+
353+
wg := sync.WaitGroup{}
354+
var receivedPackets []*packet.ICMP
355+
go func() {
356+
for ctx.Err() == nil {
357+
select {
358+
case icmpPacket := <-router.recv:
359+
receivedPackets = append(receivedPackets, icmpPacket)
360+
wg.Done()
361+
}
362+
}
363+
}()
364+
365+
for _, p := range packets {
366+
// We increment here but only decrement when receiving the packet
367+
wg.Add(1)
368+
go func() {
369+
datagram := newICMPDatagram(p)
370+
quic.send <- datagram
371+
}()
372+
}
373+
374+
wg.Wait()
375+
376+
// If there were duplicates then we won't have the same number of IPs
377+
packetSet := make(map[netip.Addr]*packet.ICMP, 0)
378+
for _, p := range receivedPackets {
379+
packetSet[p.Dst] = p
380+
}
381+
assert.Equal(t, len(packetSet), len(packets))
382+
383+
// Sort the slice by last byte of IP address (the one we increment for each destination)
384+
// and then check that we have one match for each packet sent
385+
sort.Slice(receivedPackets, func(i, j int) bool {
386+
return receivedPackets[i].Dst.As4()[3] < receivedPackets[j].Dst.As4()[3]
387+
})
388+
for i, p := range receivedPackets {
389+
assert.Equal(t, p.Dst, packets[i].Dst)
390+
}
391+
392+
// Cancel the muxer Serve context and make sure it closes with the expected error
393+
assertContextClosed(t, ctx, done, cancel)
394+
}
395+
307396
func TestDatagramConnServe_RegisterTwice(t *testing.T) {
308397
log := zerolog.Nop()
309398
quic := newMockQuicConn()

0 commit comments

Comments
 (0)