Skip to content

Commit 548f51c

Browse files
committed
ping: Fix test
1 parent ce050ba commit 548f51c

File tree

2 files changed

+52
-27
lines changed

2 files changed

+52
-27
lines changed

ping/ping.go

Lines changed: 48 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ type Conn struct {
2626
ctx context.Context
2727
logger logger.ContextLogger
2828
privileged bool
29-
bitwiseID bool
3029
conn net.Conn
3130
destination netip.Addr
3231
source atomic.TypedValue[netip.Addr]
@@ -38,15 +37,10 @@ func Connect(ctx context.Context, logger logger.ContextLogger, privileged bool,
3837
if err != nil {
3938
return nil, err
4039
}
41-
replaceID := true
42-
if _, ok := conn.(*UnprivilegedConn); ok {
43-
replaceID = false
44-
}
4540
return &Conn{
4641
ctx: ctx,
4742
logger: logger,
4843
privileged: privileged,
49-
bitwiseID: replaceID,
5044
conn: conn,
5145
destination: destination,
5246
}, nil
@@ -108,7 +102,7 @@ func (c *Conn) ReadIP(buffer *buf.Buffer) error {
108102
}
109103
ttl = controlMessage.TTL
110104
}
111-
if c.bitwiseID {
105+
if !((runtime.GOOS == "linux" || runtime.GOOS == "android") && !c.privileged) {
112106
icmpHdr := header.ICMPv4(buffer.Bytes())
113107
icmpHdr.SetIdent(^icmpHdr.Ident())
114108
icmpHdr.SetChecksum(0)
@@ -147,7 +141,7 @@ func (c *Conn) ReadIP(buffer *buf.Buffer) error {
147141
trafficClass = controlMessage.TrafficClass
148142
}
149143
icmpHdr := header.ICMPv6(buffer.Bytes())
150-
if c.bitwiseID {
144+
if !((runtime.GOOS == "linux" || runtime.GOOS == "android") && !c.privileged) {
151145
icmpHdr.SetIdent(^icmpHdr.Ident())
152146
}
153147
icmpHdr.SetChecksum(0)
@@ -188,7 +182,7 @@ func (c *Conn) ReadIP(buffer *buf.Buffer) error {
188182
ipHdr.SetChecksum(0)
189183
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
190184
icmpHdr := header.ICMPv4(ipHdr.Payload())
191-
if c.bitwiseID {
185+
if !((runtime.GOOS == "linux" || runtime.GOOS == "android") && !c.privileged) {
192186
icmpHdr.SetIdent(^icmpHdr.Ident())
193187
}
194188
icmpHdr.SetChecksum(0)
@@ -201,7 +195,7 @@ func (c *Conn) ReadIP(buffer *buf.Buffer) error {
201195
}
202196
ipHdr.SetDestinationAddr(c.source.Load())
203197
icmpHdr := header.ICMPv6(ipHdr.Payload())
204-
if c.bitwiseID {
198+
if !((runtime.GOOS == "linux" || runtime.GOOS == "android") && !c.privileged) {
205199
icmpHdr.SetIdent(^icmpHdr.Ident())
206200
}
207201
icmpHdr.SetChecksum(0)
@@ -221,17 +215,25 @@ func (c *Conn) ReadICMP(buffer *buf.Buffer) error {
221215
if err != nil {
222216
return err
223217
}
224-
if c.destination.Is6() || (runtime.GOOS == "linux" || runtime.GOOS == "android") && !c.privileged {
225-
return nil
226-
}
227-
if !c.destination.Is6() {
228-
ipHdr := header.IPv4(buffer.Bytes())
229-
buffer.Advance(int(ipHdr.HeaderLength()))
230-
c.logger.TraceContext(c.ctx, "read icmpv4 echo reply from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr())
231-
} else {
232-
ipHdr := header.IPv6(buffer.Bytes())
233-
buffer.Advance(buffer.Len() - int(ipHdr.PayloadLength()))
234-
c.logger.TraceContext(c.ctx, "read icmpv6 echo reply from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr())
218+
if !((runtime.GOOS == "linux" || runtime.GOOS == "android") && !c.privileged) {
219+
if !c.destination.Is6() {
220+
ipHdr := header.IPv4(buffer.Bytes())
221+
buffer.Advance(int(ipHdr.HeaderLength()))
222+
223+
icmpHdr := header.ICMPv4(buffer.Bytes())
224+
icmpHdr.SetIdent(^icmpHdr.Ident())
225+
icmpHdr.SetChecksum(0)
226+
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0)))
227+
} else {
228+
icmpHdr := header.ICMPv6(buffer.Bytes())
229+
icmpHdr.SetIdent(^icmpHdr.Ident())
230+
icmpHdr.SetChecksum(0)
231+
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
232+
Header: icmpHdr,
233+
Src: c.destination.AsSlice(),
234+
Dst: c.source.Load().AsSlice(),
235+
}))
236+
}
235237
}
236238
return nil
237239
}
@@ -240,7 +242,7 @@ func (c *Conn) WriteIP(buffer *buf.Buffer) error {
240242
defer buffer.Release()
241243
if !c.destination.Is6() {
242244
ipHdr := header.IPv4(buffer.Bytes())
243-
if c.bitwiseID {
245+
if !((runtime.GOOS == "linux" || runtime.GOOS == "android") && !c.privileged) {
244246
icmpHdr := header.ICMPv4(ipHdr.Payload())
245247
icmpHdr.SetIdent(^icmpHdr.Ident())
246248
icmpHdr.SetChecksum(0)
@@ -251,7 +253,7 @@ func (c *Conn) WriteIP(buffer *buf.Buffer) error {
251253
return common.Error(c.conn.Write(ipHdr.Payload()))
252254
} else {
253255
ipHdr := header.IPv6(buffer.Bytes())
254-
if c.bitwiseID {
256+
if !((runtime.GOOS == "linux" || runtime.GOOS == "android") && !c.privileged) {
255257
icmpHdr := header.ICMPv6(ipHdr.Payload())
256258
icmpHdr.SetIdent(^icmpHdr.Ident())
257259
icmpHdr.SetChecksum(0)
@@ -269,6 +271,29 @@ func (c *Conn) WriteIP(buffer *buf.Buffer) error {
269271

270272
func (c *Conn) WriteICMP(buffer *buf.Buffer) error {
271273
defer buffer.Release()
274+
if !((runtime.GOOS == "linux" || runtime.GOOS == "android") && !c.privileged) {
275+
if !c.destination.Is6() {
276+
icmpHdr := header.ICMPv4(buffer.Bytes())
277+
icmpHdr.SetIdent(^icmpHdr.Ident())
278+
icmpHdr.SetChecksum(0)
279+
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0)))
280+
c.logger.TraceContext(c.ctx, "write icmpv4 echo request to ", c.destination)
281+
} else {
282+
icmpHdr := header.ICMPv6(buffer.Bytes())
283+
icmpHdr.SetIdent(^icmpHdr.Ident())
284+
icmpHdr.SetChecksum(0)
285+
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
286+
Header: icmpHdr,
287+
Src: c.source.Load().AsSlice(),
288+
Dst: c.destination.AsSlice(),
289+
}))
290+
}
291+
}
292+
if !c.destination.Is6() {
293+
c.logger.TraceContext(c.ctx, "write icmpv4 echo request to ", c.destination)
294+
} else {
295+
c.logger.TraceContext(c.ctx, "write icmpv6 echo request to ", c.destination)
296+
}
272297
return common.Error(c.conn.Write(buffer.Bytes()))
273298
}
274299

ping/ping_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ func testPingIPv4ReadIP(t *testing.T, privileged bool, addr string) {
8484
request.SetIdent(uint16(rand.Uint32()))
8585
request.SetChecksum(header.ICMPv4Checksum(request, 0))
8686

87-
err = conn.WriteICMP(buf.As(request))
87+
err = conn.WriteICMP(buf.As(request).ToOwned())
8888
require.NoError(t, err)
8989

9090
conn.SetLocalAddr(netip.MustParseAddr("127.0.0.1"))
@@ -117,7 +117,7 @@ func testPingIPv4ReadICMP(t *testing.T, privileged bool, addr string) {
117117
request.SetIdent(uint16(rand.Uint32()))
118118
request.SetChecksum(header.ICMPv4Checksum(request, 0))
119119

120-
err = conn.WriteICMP(buf.As(request))
120+
err = conn.WriteICMP(buf.As(request).ToOwned())
121121
require.NoError(t, err)
122122

123123
require.NoError(t, conn.SetReadDeadline(time.Now().Add(3*time.Second)))
@@ -148,7 +148,7 @@ func testPingIPv6ReadIP(t *testing.T, privileged bool, addr string) {
148148
request.SetType(header.ICMPv6EchoRequest)
149149
request.SetIdent(uint16(rand.Uint32()))
150150

151-
err = conn.WriteICMP(buf.As(request))
151+
err = conn.WriteICMP(buf.As(request).ToOwned())
152152
require.NoError(t, err)
153153

154154
conn.SetLocalAddr(netip.MustParseAddr("::1"))
@@ -180,7 +180,7 @@ func testPingIPv6ReadICMP(t *testing.T, privileged bool, addr string) {
180180
request.SetType(header.ICMPv6EchoRequest)
181181
request.SetIdent(uint16(rand.Uint32()))
182182

183-
err = conn.WriteICMP(buf.As(request))
183+
err = conn.WriteICMP(buf.As(request).ToOwned())
184184
require.NoError(t, err)
185185

186186
require.NoError(t, conn.SetReadDeadline(time.Now().Add(3*time.Second)))

0 commit comments

Comments
 (0)