Skip to content

Commit d38c316

Browse files
committed
ping: Fix unprivileged response on linux
1 parent 3faf8cf commit d38c316

File tree

3 files changed

+196
-8
lines changed

3 files changed

+196
-8
lines changed

ping/ping.go

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ type Conn struct {
3232
}
3333

3434
func Connect(ctx context.Context, logger logger.ContextLogger, privileged bool, controlFunc control.Func, destination netip.Addr) (*Conn, error) {
35-
conn, err := connect(privileged, controlFunc, destination)
35+
conn, err := connect(ctx, privileged, controlFunc, destination)
3636
if err != nil {
3737
return nil, err
3838
}
@@ -53,20 +53,22 @@ func (c *Conn) ReadIP(buffer *buf.Buffer) error {
5353
readMsg = func(b, oob []byte) (n, oobn int, addr netip.Addr, err error) {
5454
var ipAddr *net.IPAddr
5555
n, oobn, _, ipAddr, err = conn.ReadMsgIP(b, oob)
56-
if ipAddr != nil {
56+
if err == nil {
5757
addr = M.AddrFromNet(ipAddr)
5858
}
5959
return
6060
}
6161
case *net.UDPConn:
6262
readMsg = func(b, oob []byte) (n, oobn int, addr netip.Addr, err error) {
63-
var udpAddr *net.UDPAddr
64-
n, oobn, _, udpAddr, err = conn.ReadMsgUDP(b, oob)
65-
if udpAddr != nil {
66-
addr = M.AddrFromNet(udpAddr)
63+
var addrPort netip.AddrPort
64+
n, oobn, _, addrPort, err = conn.ReadMsgUDPAddrPort(b, oob)
65+
if err == nil {
66+
addr = addrPort.Addr()
6767
}
6868
return
6969
}
70+
case *UnprivilegedConn:
71+
readMsg = conn.ReadMsg
7072
default:
7173
return E.New("unsupported conn type: ", reflect.TypeOf(c.conn))
7274
}
@@ -124,6 +126,7 @@ func (c *Conn) ReadIP(buffer *buf.Buffer) error {
124126
trafficClass = controlMessage.TrafficClass
125127
}
126128
icmpHdr := header.ICMPv6(buffer.Bytes())
129+
icmpHdr.SetChecksum(0)
127130
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
128131
Header: icmpHdr[:header.ICMPv6DstUnreachableMinimumSize],
129132
Src: addr.AsSlice(),
@@ -151,12 +154,14 @@ func (c *Conn) ReadIP(buffer *buf.Buffer) error {
151154
ipHdr.SetChecksum(0)
152155
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
153156
icmpHdr := header.ICMPv4(ipHdr.Payload())
157+
icmpHdr.SetChecksum(0)
154158
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0)))
155159
c.logger.TraceContext(c.ctx, "read icmpv4 echo reply from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr())
156160
} else {
157161
ipHdr := header.IPv6(buffer.Bytes())
158162
ipHdr.SetDestinationAddr(c.source.Load())
159163
icmpHdr := header.ICMPv6(ipHdr.Payload())
164+
icmpHdr.SetChecksum(0)
160165
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
161166
Header: icmpHdr,
162167
Src: ipHdr.SourceAddressSlice(),

ping/socket_linux_unprivileged.go

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
package ping
2+
3+
import (
4+
"context"
5+
"net"
6+
"net/netip"
7+
"os"
8+
"time"
9+
10+
"github.com/sagernet/sing-tun/internal/gtcpip/checksum"
11+
"github.com/sagernet/sing-tun/internal/gtcpip/header"
12+
"github.com/sagernet/sing/common/atomic"
13+
"github.com/sagernet/sing/common/buf"
14+
"github.com/sagernet/sing/common/control"
15+
M "github.com/sagernet/sing/common/metadata"
16+
)
17+
18+
type UnprivilegedConn struct {
19+
ctx context.Context
20+
cancel context.CancelFunc
21+
controlFunc control.Func
22+
destination netip.Addr
23+
receiveChan chan *unprivilegedResponse
24+
readDeadline atomic.TypedValue[time.Time]
25+
writeDeadline atomic.TypedValue[time.Time]
26+
}
27+
28+
type unprivilegedResponse struct {
29+
Buffer *buf.Buffer
30+
Cmsg *buf.Buffer
31+
Addr netip.Addr
32+
}
33+
34+
func newUnprivilegedConn(ctx context.Context, controlFunc control.Func, destination netip.Addr) (net.Conn, error) {
35+
conn, err := connect0(false, controlFunc, destination)
36+
if err != nil {
37+
return nil, err
38+
}
39+
conn.Close()
40+
ctx, cancel := context.WithCancel(ctx)
41+
return &UnprivilegedConn{
42+
ctx: ctx,
43+
cancel: cancel,
44+
controlFunc: controlFunc,
45+
destination: destination,
46+
receiveChan: make(chan *unprivilegedResponse),
47+
}, nil
48+
}
49+
50+
func (c *UnprivilegedConn) Read(b []byte) (n int, err error) {
51+
select {
52+
case packet := <-c.receiveChan:
53+
n = copy(b, packet.Buffer.Bytes())
54+
packet.Buffer.Release()
55+
packet.Cmsg.Release()
56+
return
57+
case <-c.ctx.Done():
58+
return 0, os.ErrClosed
59+
}
60+
}
61+
62+
func (c *UnprivilegedConn) ReadMsg(b []byte, oob []byte) (n, oobn int, addr netip.Addr, err error) {
63+
select {
64+
case packet := <-c.receiveChan:
65+
n = copy(b, packet.Buffer.Bytes())
66+
oobn = copy(oob, packet.Cmsg.Bytes())
67+
addr = packet.Addr
68+
packet.Buffer.Release()
69+
packet.Cmsg.Release()
70+
return
71+
case <-c.ctx.Done():
72+
return 0, 0, netip.Addr{}, os.ErrClosed
73+
}
74+
}
75+
76+
func (c *UnprivilegedConn) Write(b []byte) (n int, err error) {
77+
conn, err := connect0(false, c.controlFunc, c.destination)
78+
if err != nil {
79+
return
80+
}
81+
var identifier uint16
82+
if !c.destination.Is6() {
83+
icmpHdr := header.ICMPv4(b)
84+
identifier = icmpHdr.Ident()
85+
} else {
86+
icmpHdr := header.ICMPv6(b)
87+
identifier = icmpHdr.Ident()
88+
}
89+
if readDeadline := c.readDeadline.Load(); !readDeadline.IsZero() {
90+
conn.SetReadDeadline(readDeadline)
91+
}
92+
if writeDeadline := c.writeDeadline.Load(); !writeDeadline.IsZero() {
93+
conn.SetWriteDeadline(writeDeadline)
94+
}
95+
n, err = conn.Write(b)
96+
if err != nil {
97+
conn.Close()
98+
return
99+
}
100+
go c.fetchResponse(conn, identifier)
101+
return
102+
}
103+
104+
func (c *UnprivilegedConn) fetchResponse(conn net.Conn, identifier uint16) {
105+
done := make(chan struct{})
106+
defer close(done)
107+
go func() {
108+
select {
109+
case <-c.ctx.Done():
110+
case <-done:
111+
}
112+
conn.Close()
113+
}()
114+
buffer := buf.NewPacket()
115+
cmsgBuffer := buf.NewSize(1024)
116+
n, oobN, _, addr, err := conn.(*net.UDPConn).ReadMsgUDPAddrPort(buffer.FreeBytes(), cmsgBuffer.FreeBytes())
117+
if err != nil {
118+
buffer.Release()
119+
cmsgBuffer.Release()
120+
return
121+
}
122+
buffer.Truncate(n)
123+
cmsgBuffer.Truncate(oobN)
124+
if !c.destination.Is6() {
125+
icmpHdr := header.ICMPv4(buffer.Bytes())
126+
icmpHdr.SetIdent(identifier)
127+
icmpHdr.SetChecksum(0)
128+
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0)))
129+
} else {
130+
icmpHdr := header.ICMPv6(buffer.Bytes())
131+
icmpHdr.SetIdent(identifier)
132+
// offload checksum here since we don't have source address here
133+
}
134+
select {
135+
case c.receiveChan <- &unprivilegedResponse{
136+
Buffer: buffer,
137+
Cmsg: cmsgBuffer,
138+
Addr: addr.Addr(),
139+
}:
140+
case <-c.ctx.Done():
141+
buffer.Release()
142+
cmsgBuffer.Release()
143+
}
144+
}
145+
146+
func (c *UnprivilegedConn) Close() error {
147+
c.cancel()
148+
return nil
149+
}
150+
151+
func (c *UnprivilegedConn) LocalAddr() net.Addr {
152+
return M.Socksaddr{}
153+
}
154+
155+
func (c *UnprivilegedConn) RemoteAddr() net.Addr {
156+
return M.SocksaddrFrom(c.destination, 0).UDPAddr()
157+
}
158+
159+
func (c *UnprivilegedConn) SetDeadline(t time.Time) error {
160+
c.readDeadline.Store(t)
161+
c.writeDeadline.Store(t)
162+
return nil
163+
}
164+
165+
func (c *UnprivilegedConn) SetReadDeadline(t time.Time) error {
166+
c.readDeadline.Store(t)
167+
return nil
168+
}
169+
170+
func (c *UnprivilegedConn) SetWriteDeadline(t time.Time) error {
171+
c.writeDeadline.Store(t)
172+
return nil
173+
}

ping/socket_unix.go

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
package ping
44

55
import (
6+
"context"
67
"net"
78
"net/netip"
89
"os"
@@ -12,10 +13,19 @@ import (
1213
"github.com/sagernet/sing/common/control"
1314
E "github.com/sagernet/sing/common/exceptions"
1415
M "github.com/sagernet/sing/common/metadata"
16+
1517
"golang.org/x/sys/unix"
1618
)
1719

18-
func connect(privileged bool, controlFunc control.Func, destination netip.Addr) (net.Conn, error) {
20+
func connect(ctx context.Context, privileged bool, controlFunc control.Func, destination netip.Addr) (net.Conn, error) {
21+
if (runtime.GOOS == "linux" || runtime.GOOS == "android") && !privileged {
22+
return newUnprivilegedConn(ctx, controlFunc, destination)
23+
} else {
24+
return connect0(privileged, controlFunc, destination)
25+
}
26+
}
27+
28+
func connect0(privileged bool, controlFunc control.Func, destination netip.Addr) (net.Conn, error) {
1929
var (
2030
network string
2131
fd int
@@ -77,7 +87,7 @@ func connect(privileged bool, controlFunc control.Func, destination netip.Addr)
7787
if err != nil {
7888
return nil, E.Cause(err, "connect()")
7989
}
80-
90+
8191
conn, err := net.FileConn(file)
8292
if err != nil {
8393
return nil, err

0 commit comments

Comments
 (0)