Skip to content

Commit 1dda761

Browse files
committed
ping: Add needFilter
1 parent 6e4e045 commit 1dda761

File tree

2 files changed

+60
-41
lines changed

2 files changed

+60
-41
lines changed

ping/destination.go

Lines changed: 42 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -101,20 +101,22 @@ func (d *Destination) loopRead() {
101101
continue
102102
}
103103
icmpHdr := header.ICMPv4(ipHdr.Payload())
104-
if icmpHdr.Type() != header.ICMPv4EchoReply {
105-
continue
106-
}
107-
var requestExists bool
108-
request := pingRequest{Source: ipHdr.DestinationAddr(), Destination: ipHdr.SourceAddr(), Identifier: icmpHdr.Ident(), Sequence: icmpHdr.Sequence()}
109-
d.requestAccess.Lock()
110-
_, loaded := d.requests[request]
111-
if loaded {
112-
requestExists = true
113-
delete(d.requests, request)
114-
}
115-
d.requestAccess.Unlock()
116-
if !requestExists {
117-
continue
104+
if d.needFilter() {
105+
if icmpHdr.Type() != header.ICMPv4EchoReply {
106+
continue
107+
}
108+
var requestExists bool
109+
request := pingRequest{Source: ipHdr.DestinationAddr(), Destination: ipHdr.SourceAddr(), Identifier: icmpHdr.Ident(), Sequence: icmpHdr.Sequence()}
110+
d.requestAccess.Lock()
111+
_, loaded := d.requests[request]
112+
if loaded {
113+
requestExists = true
114+
delete(d.requests, request)
115+
}
116+
d.requestAccess.Unlock()
117+
if !requestExists {
118+
continue
119+
}
118120
}
119121
d.logger.TraceContext(d.ctx, "read ICMPv4 echo reply from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr(), " id ", icmpHdr.Ident(), " seq ", icmpHdr.Sequence())
120122
} else {
@@ -128,20 +130,22 @@ func (d *Destination) loopRead() {
128130
continue
129131
}
130132
icmpHdr := header.ICMPv6(ipHdr.Payload())
131-
if icmpHdr.Type() != header.ICMPv6EchoReply {
132-
continue
133-
}
134-
var requestExists bool
135-
request := pingRequest{Source: ipHdr.DestinationAddr(), Destination: ipHdr.SourceAddr(), Identifier: icmpHdr.Ident(), Sequence: icmpHdr.Sequence()}
136-
d.requestAccess.Lock()
137-
_, loaded := d.requests[request]
138-
if loaded {
139-
requestExists = true
140-
delete(d.requests, request)
141-
}
142-
d.requestAccess.Unlock()
143-
if !requestExists {
144-
continue
133+
if d.needFilter() {
134+
if icmpHdr.Type() != header.ICMPv6EchoReply {
135+
continue
136+
}
137+
var requestExists bool
138+
request := pingRequest{Source: ipHdr.DestinationAddr(), Destination: ipHdr.SourceAddr(), Identifier: icmpHdr.Ident(), Sequence: icmpHdr.Sequence()}
139+
d.requestAccess.Lock()
140+
_, loaded := d.requests[request]
141+
if loaded {
142+
requestExists = true
143+
delete(d.requests, request)
144+
}
145+
d.requestAccess.Unlock()
146+
if !requestExists {
147+
continue
148+
}
145149
}
146150
d.logger.TraceContext(d.ctx, "read ICMPv6 echo reply from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr(), " id ", icmpHdr.Ident(), " seq ", icmpHdr.Sequence())
147151
}
@@ -163,7 +167,9 @@ func (d *Destination) WritePacket(packet *buf.Buffer) error {
163167
return E.New("invalid ICMPv4 header")
164168
}
165169
icmpHdr := header.ICMPv4(ipHdr.Payload())
166-
d.registerRequest(pingRequest{Source: ipHdr.SourceAddr(), Destination: ipHdr.DestinationAddr(), Identifier: icmpHdr.Ident(), Sequence: icmpHdr.Sequence()})
170+
if d.needFilter() {
171+
d.registerRequest(pingRequest{Source: ipHdr.SourceAddr(), Destination: ipHdr.DestinationAddr(), Identifier: icmpHdr.Ident(), Sequence: icmpHdr.Sequence()})
172+
}
167173
d.logger.TraceContext(d.ctx, "write ICMPv4 echo request from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr(), " id ", icmpHdr.Ident(), " seq ", icmpHdr.Sequence())
168174
} else {
169175
ipHdr := header.IPv6(packet.Bytes())
@@ -174,12 +180,18 @@ func (d *Destination) WritePacket(packet *buf.Buffer) error {
174180
return E.New("invalid ICMPv6 header")
175181
}
176182
icmpHdr := header.ICMPv6(ipHdr.Payload())
177-
d.registerRequest(pingRequest{Source: ipHdr.SourceAddr(), Destination: ipHdr.DestinationAddr(), Identifier: icmpHdr.Ident(), Sequence: icmpHdr.Sequence()})
183+
if d.needFilter() {
184+
d.registerRequest(pingRequest{Source: ipHdr.SourceAddr(), Destination: ipHdr.DestinationAddr(), Identifier: icmpHdr.Ident(), Sequence: icmpHdr.Sequence()})
185+
}
178186
d.logger.TraceContext(d.ctx, "write ICMPv6 echo request from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr(), " id ", icmpHdr.Ident(), " seq ", icmpHdr.Sequence())
179187
}
180188
return d.conn.WriteIP(packet)
181189
}
182190

191+
func (d *Destination) needFilter() bool {
192+
return !d.conn.needChangeIdent()
193+
}
194+
183195
func (d *Destination) registerRequest(request pingRequest) {
184196
const requestsLimit = 1024
185197
d.requestAccess.Lock()

ping/ping.go

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,20 +44,24 @@ func Connect(ctx context.Context, privileged bool, controlFunc control.Func, des
4444
}
4545

4646
func (c *Conn) connect(controlFunc control.Func) (err error) {
47-
if c.IsLinuxUnprivileged() {
47+
if c.isLinuxUnprivileged() {
4848
c.conn, err = newUnprivilegedConn(c.ctx, controlFunc, c.destination)
4949
} else {
5050
c.conn, err = connect(c.privileged, controlFunc, c.destination)
5151
}
5252
return
5353
}
5454

55-
func (c *Conn) IsLinuxUnprivileged() bool {
55+
func (c *Conn) isLinuxUnprivileged() bool {
5656
return (runtime.GOOS == "linux" || runtime.GOOS == "android") && !c.privileged
5757
}
5858

59+
func (c *Conn) needChangeIdent() bool {
60+
return runtime.GOOS != "windows" && !c.isLinuxUnprivileged()
61+
}
62+
5963
func (c *Conn) ReadIP(buffer *buf.Buffer) error {
60-
if c.destination.Is6() || c.IsLinuxUnprivileged() {
64+
if c.destination.Is6() || c.isLinuxUnprivileged() {
6165
var readMsg func(b, oob []byte) (n, oobn int, addr netip.Addr, err error)
6266
switch conn := c.conn.(type) {
6367
case *net.IPConn:
@@ -104,7 +108,7 @@ func (c *Conn) ReadIP(buffer *buf.Buffer) error {
104108
}
105109
ttl = controlMessage.TTL
106110
}
107-
if !c.IsLinuxUnprivileged() {
111+
if c.needChangeIdent() {
108112
icmpHdr := header.ICMPv4(buffer.Bytes())
109113
icmpHdr.SetIdent(^icmpHdr.Ident())
110114
icmpHdr.SetChecksum(0)
@@ -142,7 +146,7 @@ func (c *Conn) ReadIP(buffer *buf.Buffer) error {
142146
trafficClass = controlMessage.TrafficClass
143147
}
144148
icmpHdr := header.ICMPv6(buffer.Bytes())
145-
if !c.IsLinuxUnprivileged() {
149+
if c.needChangeIdent() {
146150
icmpHdr.SetIdent(^icmpHdr.Ident())
147151
}
148152
icmpHdr.SetChecksum(0)
@@ -182,7 +186,7 @@ func (c *Conn) ReadIP(buffer *buf.Buffer) error {
182186
ipHdr.SetChecksum(0)
183187
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
184188
icmpHdr := header.ICMPv4(ipHdr.Payload())
185-
if !c.IsLinuxUnprivileged() {
189+
if c.needChangeIdent() {
186190
icmpHdr.SetIdent(^icmpHdr.Ident())
187191
}
188192
icmpHdr.SetChecksum(0)
@@ -194,7 +198,7 @@ func (c *Conn) ReadIP(buffer *buf.Buffer) error {
194198
}
195199
ipHdr.SetDestinationAddr(c.source.Load())
196200
icmpHdr := header.ICMPv6(ipHdr.Payload())
197-
if !c.IsLinuxUnprivileged() {
201+
if c.needChangeIdent() {
198202
icmpHdr.SetIdent(^icmpHdr.Ident())
199203
}
200204
icmpHdr.SetChecksum(0)
@@ -213,7 +217,7 @@ func (c *Conn) ReadICMP(buffer *buf.Buffer) error {
213217
if err != nil {
214218
return err
215219
}
216-
if !c.IsLinuxUnprivileged() {
220+
if c.needChangeIdent() {
217221
if !c.destination.Is6() {
218222
ipHdr := header.IPv4(buffer.Bytes())
219223
buffer.Advance(int(ipHdr.HeaderLength()))
@@ -232,6 +236,9 @@ func (c *Conn) ReadICMP(buffer *buf.Buffer) error {
232236
Dst: c.source.Load().AsSlice(),
233237
}))
234238
}
239+
} else if !c.isLinuxUnprivileged() && !c.destination.Is6() {
240+
ipHdr := header.IPv4(buffer.Bytes())
241+
buffer.Advance(int(ipHdr.HeaderLength()))
235242
}
236243
return nil
237244
}
@@ -240,7 +247,7 @@ func (c *Conn) WriteIP(buffer *buf.Buffer) error {
240247
defer buffer.Release()
241248
if !c.destination.Is6() {
242249
ipHdr := header.IPv4(buffer.Bytes())
243-
if !c.IsLinuxUnprivileged() {
250+
if c.needChangeIdent() {
244251
icmpHdr := header.ICMPv4(ipHdr.Payload())
245252
icmpHdr.SetIdent(^icmpHdr.Ident())
246253
icmpHdr.SetChecksum(0)
@@ -250,7 +257,7 @@ func (c *Conn) WriteIP(buffer *buf.Buffer) error {
250257
return common.Error(c.conn.Write(ipHdr.Payload()))
251258
} else {
252259
ipHdr := header.IPv6(buffer.Bytes())
253-
if !c.IsLinuxUnprivileged() {
260+
if c.needChangeIdent() {
254261
icmpHdr := header.ICMPv6(ipHdr.Payload())
255262
icmpHdr.SetIdent(^icmpHdr.Ident())
256263
icmpHdr.SetChecksum(0)
@@ -267,7 +274,7 @@ func (c *Conn) WriteIP(buffer *buf.Buffer) error {
267274

268275
func (c *Conn) WriteICMP(buffer *buf.Buffer) error {
269276
defer buffer.Release()
270-
if !c.IsLinuxUnprivileged() {
277+
if c.needChangeIdent() {
271278
if !c.destination.Is6() {
272279
icmpHdr := header.ICMPv4(buffer.Bytes())
273280
icmpHdr.SetIdent(^icmpHdr.Ident())

0 commit comments

Comments
 (0)