Skip to content

Commit 3449ea3

Browse files
committed
TUN-6791: Calculate ICMPv6 checksum
1 parent 7f487c2 commit 3449ea3

File tree

5 files changed

+134
-7
lines changed

5 files changed

+134
-7
lines changed

ingress/origin_icmp_proxy_test.go

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -299,12 +299,7 @@ func (efr *echoFlowResponder) validate(t *testing.T, echoReq *packet.ICMP) {
299299
require.Equal(t, ipv6.ICMPTypeEchoReply, decoded.Type)
300300
}
301301
require.Equal(t, 0, decoded.Code)
302-
if echoReq.Type == ipv4.ICMPTypeEcho {
303-
require.NotZero(t, decoded.Checksum)
304-
} else {
305-
// For ICMPv6, the kernel will compute the checksum during transmission unless pseudo header is not nil
306-
require.Zero(t, decoded.Checksum)
307-
}
302+
require.NotZero(t, decoded.Checksum)
308303

309304
require.Equal(t, echoReq.Body, decoded.Body)
310305
}

packet/decoder_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ func TestDecodeICMP(t *testing.T) {
152152

153153
require.Equal(t, test.packet.Type, icmpPacket.Type)
154154
require.Equal(t, test.packet.Code, icmpPacket.Code)
155+
assertICMPChecksum(t, icmpPacket)
155156
require.Equal(t, test.packet.Body, icmpPacket.Body)
156157
expectedBody, err := test.packet.Body.Marshal(test.packet.Type.Protocol())
157158
require.NoError(t, err)

packet/packet.go

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package packet
22

33
import (
4+
"encoding/binary"
45
"fmt"
56
"net/netip"
67

@@ -21,6 +22,7 @@ const (
2122
// 0 = ttl exceed in transit, 1 = fragment reassembly time exceeded
2223
icmpTTLExceedInTransitCode = 0
2324
DefaultTTL uint8 = 255
25+
pseudoHeaderLen = 40
2426
)
2527

2628
// Packet represents an IP packet or a packet that is encapsulated by IP
@@ -117,14 +119,48 @@ func (i *ICMP) EncodeLayers() ([]gopacket.SerializableLayer, error) {
117119
return nil, err
118120
}
119121

120-
msg, err := i.Marshal(nil)
122+
var serializedPsh []byte = nil
123+
if i.Protocol == layers.IPProtocolICMPv6 {
124+
psh := &PseudoHeader{
125+
SrcIP: i.Src.As16(),
126+
DstIP: i.Dst.As16(),
127+
// i.Marshal re-calculates the UpperLayerPacketLength
128+
UpperLayerPacketLength: 0,
129+
NextHeader: uint8(i.Protocol),
130+
}
131+
serializedPsh = psh.Marshal()
132+
}
133+
msg, err := i.Marshal(serializedPsh)
121134
if err != nil {
122135
return nil, err
123136
}
124137
icmpLayer := gopacket.Payload(msg)
125138
return append(ipLayers, icmpLayer), nil
126139
}
127140

141+
// https://www.rfc-editor.org/rfc/rfc2460#section-8.1
142+
type PseudoHeader struct {
143+
SrcIP [16]byte
144+
DstIP [16]byte
145+
UpperLayerPacketLength uint32
146+
zero [3]byte
147+
NextHeader uint8
148+
}
149+
150+
func (ph *PseudoHeader) Marshal() []byte {
151+
buf := make([]byte, pseudoHeaderLen)
152+
index := 0
153+
copy(buf, ph.SrcIP[:])
154+
index += 16
155+
copy(buf[index:], ph.DstIP[:])
156+
index += 16
157+
binary.BigEndian.PutUint32(buf[index:], ph.UpperLayerPacketLength)
158+
index += 4
159+
copy(buf[index:], ph.zero[:])
160+
buf[pseudoHeaderLen-1] = ph.NextHeader
161+
return buf
162+
}
163+
128164
func NewICMPTTLExceedPacket(originalIP *IP, originalPacket RawPacket, routerIP netip.Addr) *ICMP {
129165
var (
130166
protocol layers.IPProtocol
@@ -137,6 +173,7 @@ func NewICMPTTLExceedPacket(originalIP *IP, originalPacket RawPacket, routerIP n
137173
protocol = layers.IPProtocolICMPv6
138174
icmpType = ipv6.ICMPTypeTimeExceeded
139175
}
176+
140177
return &ICMP{
141178
IP: &IP{
142179
Src: routerIP,

packet/packet_test.go

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"net/netip"
66
"testing"
77

8+
"github.com/google/gopacket"
89
"github.com/google/gopacket/layers"
910
"github.com/stretchr/testify/require"
1011
"golang.org/x/net/icmp"
@@ -101,4 +102,96 @@ func assertTTLExceedPacket(t *testing.T, pk *ICMP) {
101102
require.Len(t, rawTTLExceedPacket.Data, headerLen+icmpHeaderLen+len(rawPacket.Data))
102103
require.True(t, bytes.Equal(rawPacket.Data, rawTTLExceedPacket.Data[headerLen+icmpHeaderLen:]))
103104
}
105+
106+
decoder := NewICMPDecoder()
107+
decodedPacket, err := decoder.Decode(rawTTLExceedPacket)
108+
require.NoError(t, err)
109+
assertICMPChecksum(t, decodedPacket)
110+
}
111+
112+
func assertICMPChecksum(t *testing.T, icmpPacket *ICMP) {
113+
buf := gopacket.NewSerializeBuffer()
114+
if icmpPacket.Protocol == layers.IPProtocolICMPv4 {
115+
icmpv4 := layers.ICMPv4{
116+
TypeCode: layers.CreateICMPv4TypeCode(uint8(icmpPacket.Type.(ipv4.ICMPType)), uint8(icmpPacket.Code)),
117+
}
118+
switch body := icmpPacket.Body.(type) {
119+
case *icmp.Echo:
120+
icmpv4.Id = uint16(body.ID)
121+
icmpv4.Seq = uint16(body.Seq)
122+
payload := gopacket.Payload(body.Data)
123+
require.NoError(t, payload.SerializeTo(buf, serializeOpts))
124+
default:
125+
require.NoError(t, serializeICMPAsPayload(icmpPacket.Message, buf))
126+
}
127+
// SerializeTo sets the checksum in icmpv4
128+
require.NoError(t, icmpv4.SerializeTo(buf, serializeOpts))
129+
require.Equal(t, icmpv4.Checksum, uint16(icmpPacket.Checksum))
130+
} else {
131+
switch body := icmpPacket.Body.(type) {
132+
case *icmp.Echo:
133+
payload := gopacket.Payload(body.Data)
134+
require.NoError(t, payload.SerializeTo(buf, serializeOpts))
135+
echo := layers.ICMPv6Echo{
136+
Identifier: uint16(body.ID),
137+
SeqNumber: uint16(body.Seq),
138+
}
139+
require.NoError(t, echo.SerializeTo(buf, serializeOpts))
140+
default:
141+
require.NoError(t, serializeICMPAsPayload(icmpPacket.Message, buf))
142+
}
143+
144+
icmpv6 := layers.ICMPv6{
145+
TypeCode: layers.CreateICMPv6TypeCode(uint8(icmpPacket.Type.(ipv6.ICMPType)), uint8(icmpPacket.Code)),
146+
}
147+
ipLayer := layers.IPv6{
148+
Version: 6,
149+
SrcIP: icmpPacket.Src.AsSlice(),
150+
DstIP: icmpPacket.Dst.AsSlice(),
151+
NextHeader: icmpPacket.Protocol,
152+
HopLimit: icmpPacket.TTL,
153+
}
154+
require.NoError(t, icmpv6.SetNetworkLayerForChecksum(&ipLayer))
155+
156+
// SerializeTo sets the checksum in icmpv4
157+
require.NoError(t, icmpv6.SerializeTo(buf, serializeOpts))
158+
require.Equal(t, icmpv6.Checksum, uint16(icmpPacket.Checksum))
159+
}
160+
}
161+
162+
func serializeICMPAsPayload(message *icmp.Message, buf gopacket.SerializeBuffer) error {
163+
serializedBody, err := message.Body.Marshal(message.Type.Protocol())
164+
if err != nil {
165+
return err
166+
}
167+
payload := gopacket.Payload(serializedBody)
168+
return payload.SerializeTo(buf, serializeOpts)
169+
}
170+
171+
func TestChecksum(t *testing.T) {
172+
data := []byte{0x63, 0x2c, 0x49, 0xd6, 0x00, 0x0d, 0xc1, 0xda}
173+
pk := ICMP{
174+
IP: &IP{
175+
Src: netip.MustParseAddr("2606:4700:110:89c1:c63a:861:e08c:b049"),
176+
Dst: netip.MustParseAddr("fde8:b693:d420:109b::2"),
177+
Protocol: layers.IPProtocolICMPv6,
178+
TTL: 3,
179+
},
180+
Message: &icmp.Message{
181+
Type: ipv6.ICMPTypeEchoRequest,
182+
Code: 0,
183+
Body: &icmp.Echo{
184+
ID: 0x20a7,
185+
Seq: 8,
186+
Data: data,
187+
},
188+
},
189+
}
190+
encoder := NewEncoder()
191+
encoded, err := encoder.Encode(&pk)
192+
require.NoError(t, err)
193+
194+
decoder := NewICMPDecoder()
195+
decoded, err := decoder.Decode(encoded)
196+
require.Equal(t, 0xff96, decoded.Checksum)
104197
}

packet/router_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ func assertTTLExceed(t *testing.T, originalPacket *ICMP, expectedSrc netip.Addr,
9696
require.Equal(t, ipv6.ICMPTypeTimeExceeded, decoded.Type)
9797
}
9898
require.Equal(t, 0, decoded.Code)
99+
assertICMPChecksum(t, decoded)
99100
timeExceed, ok := decoded.Body.(*icmp.TimeExceeded)
100101
require.True(t, ok)
101102
require.True(t, bytes.Equal(rawPacket.Data, timeExceed.Data))

0 commit comments

Comments
 (0)