Skip to content

Commit 144683d

Browse files
committed
ping: Add filter to destination
1 parent d0ff7b6 commit 144683d

File tree

3 files changed

+108
-31
lines changed

3 files changed

+108
-31
lines changed

ping/destination.go

Lines changed: 103 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@ import (
66
"net/netip"
77
"os"
88
"runtime"
9+
"sync"
910
"time"
1011

1112
"github.com/sagernet/sing-tun"
13+
"github.com/sagernet/sing-tun/internal/gtcpip/header"
1214
"github.com/sagernet/sing/common/buf"
1315
"github.com/sagernet/sing/common/control"
1416
E "github.com/sagernet/sing/common/exceptions"
@@ -18,18 +20,28 @@ import (
1820
var _ tun.DirectRouteDestination = (*Destination)(nil)
1921

2022
type Destination struct {
21-
conn *Conn
22-
ctx context.Context
23-
logger logger.ContextLogger
24-
routeContext tun.DirectRouteContext
25-
timeout time.Duration
23+
conn *Conn
24+
ctx context.Context
25+
logger logger.ContextLogger
26+
destination netip.Addr
27+
routeContext tun.DirectRouteContext
28+
timeout time.Duration
29+
requestAccess sync.Mutex
30+
requests map[pingRequest]bool
31+
}
32+
33+
type pingRequest struct {
34+
Source netip.Addr
35+
Destination netip.Addr
36+
Identifier uint16
37+
Sequence uint16
2638
}
2739

2840
func ConnectDestination(
2941
ctx context.Context,
3042
logger logger.ContextLogger,
3143
controlFunc control.Func,
32-
address netip.Addr,
44+
destination netip.Addr,
3345
routeContext tun.DirectRouteContext,
3446
timeout time.Duration,
3547
) (tun.DirectRouteDestination, error) {
@@ -39,11 +51,11 @@ func ConnectDestination(
3951
)
4052
switch runtime.GOOS {
4153
case "darwin", "ios", "windows":
42-
conn, err = Connect(ctx, logger, false, controlFunc, address)
54+
conn, err = Connect(ctx, false, controlFunc, destination)
4355
default:
44-
conn, err = Connect(ctx, logger, true, controlFunc, address)
56+
conn, err = Connect(ctx, true, controlFunc, destination)
4557
if errors.Is(err, os.ErrPermission) {
46-
conn, err = Connect(ctx, logger, false, controlFunc, address)
58+
conn, err = Connect(ctx, false, controlFunc, destination)
4759
}
4860
}
4961
if err != nil {
@@ -53,8 +65,10 @@ func ConnectDestination(
5365
conn: conn,
5466
ctx: ctx,
5567
logger: logger,
68+
destination: destination,
5669
routeContext: routeContext,
5770
timeout: timeout,
71+
requests: make(map[pingRequest]bool),
5872
}
5973
go d.loopRead()
6074
return d, nil
@@ -76,6 +90,59 @@ func (d *Destination) loopRead() {
7690
}
7791
return
7892
}
93+
if !d.destination.Is6() {
94+
ipHdr := header.IPv4(buffer.Bytes())
95+
if !ipHdr.IsValid(buffer.Len()) {
96+
d.logger.ErrorContext(d.ctx, E.New("invalid IPv4 header received"))
97+
continue
98+
}
99+
if ipHdr.PayloadLength() < header.ICMPv4MinimumSize {
100+
d.logger.ErrorContext(d.ctx, E.New("invalid ICMPv4 header received"))
101+
continue
102+
}
103+
icmpHdr := header.ICMPv4(ipHdr.Payload())
104+
if icmpHdr.Type() != header.ICMPv4EchoReply {
105+
continue
106+
}
107+
var requestExists bool
108+
request := pingRequest{Source: ipHdr.DestinationAddr(), Destination: ipHdr.SourceAddr(), Identifier: icmpHdr.Ident(), Sequence: icmpHdr.Sequence()}
109+
d.requestAccess.Lock()
110+
if d.requests[request] {
111+
requestExists = true
112+
delete(d.requests, request)
113+
}
114+
d.requestAccess.Unlock()
115+
if !requestExists {
116+
continue
117+
}
118+
d.logger.TraceContext(d.ctx, "read ICMPv4 echo reply from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr(), " id ", icmpHdr.Ident(), " seq ", icmpHdr.Sequence())
119+
} else {
120+
ipHdr := header.IPv6(buffer.Bytes())
121+
if !ipHdr.IsValid(buffer.Len()) {
122+
d.logger.ErrorContext(d.ctx, E.New("invalid IPv6 header received"))
123+
continue
124+
}
125+
if ipHdr.PayloadLength() < header.ICMPv6MinimumSize {
126+
d.logger.ErrorContext(d.ctx, E.New("invalid ICMPv6 header received"))
127+
continue
128+
}
129+
icmpHdr := header.ICMPv6(ipHdr.Payload())
130+
if icmpHdr.Type() != header.ICMPv6EchoReply {
131+
continue
132+
}
133+
var requestExists bool
134+
request := pingRequest{Source: ipHdr.DestinationAddr(), Destination: ipHdr.SourceAddr(), Identifier: icmpHdr.Ident(), Sequence: icmpHdr.Sequence()}
135+
d.requestAccess.Lock()
136+
if d.requests[request] {
137+
requestExists = true
138+
delete(d.requests, request)
139+
}
140+
d.requestAccess.Unlock()
141+
if !requestExists {
142+
continue
143+
}
144+
d.logger.TraceContext(d.ctx, "read ICMPv6 echo reply from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr(), " id ", icmpHdr.Ident(), " seq ", icmpHdr.Sequence())
145+
}
79146
err = d.routeContext.WritePacket(buffer.Bytes())
80147
if err != nil {
81148
d.logger.ErrorContext(d.ctx, E.Cause(err, "write ICMP echo reply"))
@@ -85,6 +152,33 @@ func (d *Destination) loopRead() {
85152
}
86153

87154
func (d *Destination) WritePacket(packet *buf.Buffer) error {
155+
if !d.destination.Is6() {
156+
ipHdr := header.IPv4(packet.Bytes())
157+
if !ipHdr.IsValid(packet.Len()) {
158+
return E.New("invalid IPv4 header")
159+
}
160+
if ipHdr.PayloadLength() < header.ICMPv4MinimumSize {
161+
return E.New("invalid ICMPv4 header")
162+
}
163+
icmpHdr := header.ICMPv4(ipHdr.Payload())
164+
d.requestAccess.Lock()
165+
d.requests[pingRequest{Source: ipHdr.SourceAddr(), Destination: ipHdr.DestinationAddr(), Identifier: icmpHdr.Ident(), Sequence: icmpHdr.Sequence()}] = true
166+
d.requestAccess.Unlock()
167+
d.logger.TraceContext(d.ctx, "write ICMPv4 echo request from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr(), " id ", icmpHdr.Ident(), " seq ", icmpHdr.Sequence())
168+
} else {
169+
ipHdr := header.IPv6(packet.Bytes())
170+
if !ipHdr.IsValid(packet.Len()) {
171+
return E.New("invalid IPv6 header")
172+
}
173+
if ipHdr.PayloadLength() < header.ICMPv6MinimumSize {
174+
return E.New("invalid ICMPv6 header")
175+
}
176+
icmpHdr := header.ICMPv6(ipHdr.Payload())
177+
d.requestAccess.Lock()
178+
d.requests[pingRequest{Source: ipHdr.SourceAddr(), Destination: ipHdr.DestinationAddr(), Identifier: icmpHdr.Ident(), Sequence: icmpHdr.Sequence()}] = true
179+
d.requestAccess.Unlock()
180+
d.logger.TraceContext(d.ctx, "write ICMPv6 echo request from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr(), " id ", icmpHdr.Ident(), " seq ", icmpHdr.Sequence())
181+
}
88182
return d.conn.WriteIP(packet)
89183
}
90184

ping/ping.go

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ import (
1515
"github.com/sagernet/sing/common/buf"
1616
"github.com/sagernet/sing/common/control"
1717
E "github.com/sagernet/sing/common/exceptions"
18-
"github.com/sagernet/sing/common/logger"
1918
M "github.com/sagernet/sing/common/metadata"
2019

2120
"golang.org/x/net/ipv4"
@@ -24,18 +23,16 @@ import (
2423

2524
type Conn struct {
2625
ctx context.Context
27-
logger logger.ContextLogger
2826
privileged bool
2927
conn net.Conn
3028
destination netip.Addr
3129
source common.TypedValue[netip.Addr]
3230
closed atomic.Bool
3331
}
3432

35-
func Connect(ctx context.Context, logger logger.ContextLogger, privileged bool, controlFunc control.Func, destination netip.Addr) (*Conn, error) {
33+
func Connect(ctx context.Context, privileged bool, controlFunc control.Func, destination netip.Addr) (*Conn, error) {
3634
c := &Conn{
3735
ctx: ctx,
38-
logger: logger,
3936
privileged: privileged,
4037
destination: destination,
4138
}
@@ -123,7 +120,6 @@ func (c *Conn) ReadIP(buffer *buf.Buffer) error {
123120
TotalLength: uint16(buffer.Len()),
124121
})
125122
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
126-
c.logger.TraceContext(c.ctx, "read icmpv4 echo reply from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr())
127123
} else {
128124
oob := make([]byte, 1024)
129125
buffer.Advance(header.IPv6MinimumSize)
@@ -164,7 +160,6 @@ func (c *Conn) ReadIP(buffer *buf.Buffer) error {
164160
SrcAddr: addr,
165161
DstAddr: c.source.Load(),
166162
})
167-
c.logger.TraceContext(c.ctx, "read icmpv6 echo reply from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr())
168163
}
169164
} else {
170165
_, err := buffer.ReadOnceFrom(c.conn)
@@ -192,7 +187,6 @@ func (c *Conn) ReadIP(buffer *buf.Buffer) error {
192187
}
193188
icmpHdr.SetChecksum(0)
194189
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0)))
195-
c.logger.TraceContext(c.ctx, "read icmpv4 echo reply from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr())
196190
} else {
197191
ipHdr := header.IPv6(buffer.Bytes())
198192
if !ipHdr.IsValid(buffer.Len()) {
@@ -209,7 +203,6 @@ func (c *Conn) ReadIP(buffer *buf.Buffer) error {
209203
Src: ipHdr.SourceAddressSlice(),
210204
Dst: ipHdr.DestinationAddressSlice(),
211205
}))
212-
c.logger.TraceContext(c.ctx, "read icmpv6 echo reply from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr())
213206
}
214207
}
215208
return nil
@@ -254,7 +247,6 @@ func (c *Conn) WriteIP(buffer *buf.Buffer) error {
254247
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0)))
255248
}
256249
c.source.Store(M.AddrFromIP(ipHdr.SourceAddressSlice()))
257-
c.logger.TraceContext(c.ctx, "write icmpv4 echo request from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr())
258250
return common.Error(c.conn.Write(ipHdr.Payload()))
259251
} else {
260252
ipHdr := header.IPv6(buffer.Bytes())
@@ -269,7 +261,6 @@ func (c *Conn) WriteIP(buffer *buf.Buffer) error {
269261
}))
270262
}
271263
c.source.Store(M.AddrFromIP(ipHdr.SourceAddressSlice()))
272-
c.logger.TraceContext(c.ctx, "write icmpv6 echo request from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr())
273264
return common.Error(c.conn.Write(ipHdr.Payload()))
274265
}
275266
}
@@ -282,7 +273,6 @@ func (c *Conn) WriteICMP(buffer *buf.Buffer) error {
282273
icmpHdr.SetIdent(^icmpHdr.Ident())
283274
icmpHdr.SetChecksum(0)
284275
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0)))
285-
c.logger.TraceContext(c.ctx, "write icmpv4 echo request to ", c.destination)
286276
} else {
287277
icmpHdr := header.ICMPv6(buffer.Bytes())
288278
icmpHdr.SetIdent(^icmpHdr.Ident())
@@ -294,11 +284,6 @@ func (c *Conn) WriteICMP(buffer *buf.Buffer) error {
294284
}))
295285
}
296286
}
297-
if !c.destination.Is6() {
298-
c.logger.TraceContext(c.ctx, "write icmpv4 echo request to ", c.destination)
299-
} else {
300-
c.logger.TraceContext(c.ctx, "write icmpv6 echo request to ", c.destination)
301-
}
302287
return common.Error(c.conn.Write(buffer.Bytes()))
303288
}
304289

ping/ping_test.go

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@ import (
1212
"github.com/sagernet/sing-tun/internal/gtcpip/header"
1313
"github.com/sagernet/sing-tun/ping"
1414
"github.com/sagernet/sing/common/buf"
15-
"github.com/sagernet/sing/common/logger"
16-
1715
"github.com/stretchr/testify/require"
1816
)
1917

@@ -73,7 +71,7 @@ func TestPing(t *testing.T) {
7371
}
7472

7573
func testPingIPv4ReadIP(t *testing.T, privileged bool, addr string) {
76-
conn, err := ping.Connect(context.Background(), logger.NOP(), privileged, nil, netip.MustParseAddr(addr))
74+
conn, err := ping.Connect(context.Background(), privileged, nil, netip.MustParseAddr(addr))
7775
if runtime.GOOS == "linux" && err != nil && err.Error() == "socket(): permission denied" {
7876
t.SkipNow()
7977
}
@@ -106,7 +104,7 @@ func testPingIPv4ReadIP(t *testing.T, privileged bool, addr string) {
106104
}
107105

108106
func testPingIPv4ReadICMP(t *testing.T, privileged bool, addr string) {
109-
conn, err := ping.Connect(context.Background(), logger.NOP(), privileged, nil, netip.MustParseAddr(addr))
107+
conn, err := ping.Connect(context.Background(), privileged, nil, netip.MustParseAddr(addr))
110108
if runtime.GOOS == "linux" && err != nil && err.Error() == "socket(): permission denied" {
111109
t.SkipNow()
112110
}
@@ -138,7 +136,7 @@ func testPingIPv4ReadICMP(t *testing.T, privileged bool, addr string) {
138136
}
139137

140138
func testPingIPv6ReadIP(t *testing.T, privileged bool, addr string) {
141-
conn, err := ping.Connect(context.Background(), logger.NOP(), privileged, nil, netip.MustParseAddr(addr))
139+
conn, err := ping.Connect(context.Background(), privileged, nil, netip.MustParseAddr(addr))
142140
if runtime.GOOS == "linux" && err != nil && err.Error() == "socket(): permission denied" {
143141
t.SkipNow()
144142
}
@@ -170,7 +168,7 @@ func testPingIPv6ReadIP(t *testing.T, privileged bool, addr string) {
170168
}
171169

172170
func testPingIPv6ReadICMP(t *testing.T, privileged bool, addr string) {
173-
conn, err := ping.Connect(context.Background(), logger.NOP(), privileged, nil, netip.MustParseAddr(addr))
171+
conn, err := ping.Connect(context.Background(), privileged, nil, netip.MustParseAddr(addr))
174172
if runtime.GOOS == "linux" && err != nil && err.Error() == "socket(): permission denied" {
175173
t.SkipNow()
176174
}

0 commit comments

Comments
 (0)