Skip to content

Commit 8f6cc9f

Browse files
committed
ping: Fix unprivileged response on linux
1 parent 3faf8cf commit 8f6cc9f

File tree

3 files changed

+194
-7
lines changed

3 files changed

+194
-7
lines changed

ping/ping.go

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ type Conn struct {
3232
}
3333

3434
func Connect(ctx context.Context, logger logger.ContextLogger, privileged bool, controlFunc control.Func, destination netip.Addr) (*Conn, error) {
35-
conn, err := connect(privileged, controlFunc, destination)
35+
conn, err := connect0(ctx, privileged, controlFunc, destination)
3636
if err != nil {
3737
return nil, err
3838
}
@@ -45,6 +45,14 @@ func Connect(ctx context.Context, logger logger.ContextLogger, privileged bool,
4545
}, nil
4646
}
4747

48+
func connect0(ctx context.Context, privileged bool, controlFunc control.Func, destination netip.Addr) (net.Conn, error) {
49+
if (runtime.GOOS == "linux" || runtime.GOOS == "android") && !privileged {
50+
return newUnprivilegedConn(ctx, controlFunc, destination)
51+
} else {
52+
return connect(privileged, controlFunc, destination)
53+
}
54+
}
55+
4856
func (c *Conn) ReadIP(buffer *buf.Buffer) error {
4957
if c.destination.Is6() || (runtime.GOOS == "linux" || runtime.GOOS == "android") && !c.privileged {
5058
var readMsg func(b, oob []byte) (n, oobn int, addr netip.Addr, err error)
@@ -53,20 +61,22 @@ func (c *Conn) ReadIP(buffer *buf.Buffer) error {
5361
readMsg = func(b, oob []byte) (n, oobn int, addr netip.Addr, err error) {
5462
var ipAddr *net.IPAddr
5563
n, oobn, _, ipAddr, err = conn.ReadMsgIP(b, oob)
56-
if ipAddr != nil {
64+
if err == nil {
5765
addr = M.AddrFromNet(ipAddr)
5866
}
5967
return
6068
}
6169
case *net.UDPConn:
6270
readMsg = func(b, oob []byte) (n, oobn int, addr netip.Addr, err error) {
63-
var udpAddr *net.UDPAddr
64-
n, oobn, _, udpAddr, err = conn.ReadMsgUDP(b, oob)
65-
if udpAddr != nil {
66-
addr = M.AddrFromNet(udpAddr)
71+
var addrPort netip.AddrPort
72+
n, oobn, _, addrPort, err = conn.ReadMsgUDPAddrPort(b, oob)
73+
if err == nil {
74+
addr = addrPort.Addr()
6775
}
6876
return
6977
}
78+
case *UnprivilegedConn:
79+
readMsg = conn.ReadMsg
7080
default:
7181
return E.New("unsupported conn type: ", reflect.TypeOf(c.conn))
7282
}
@@ -124,6 +134,7 @@ func (c *Conn) ReadIP(buffer *buf.Buffer) error {
124134
trafficClass = controlMessage.TrafficClass
125135
}
126136
icmpHdr := header.ICMPv6(buffer.Bytes())
137+
icmpHdr.SetChecksum(0)
127138
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
128139
Header: icmpHdr[:header.ICMPv6DstUnreachableMinimumSize],
129140
Src: addr.AsSlice(),
@@ -151,12 +162,14 @@ func (c *Conn) ReadIP(buffer *buf.Buffer) error {
151162
ipHdr.SetChecksum(0)
152163
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
153164
icmpHdr := header.ICMPv4(ipHdr.Payload())
165+
icmpHdr.SetChecksum(0)
154166
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0)))
155167
c.logger.TraceContext(c.ctx, "read icmpv4 echo reply from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr())
156168
} else {
157169
ipHdr := header.IPv6(buffer.Bytes())
158170
ipHdr.SetDestinationAddr(c.source.Load())
159171
icmpHdr := header.ICMPv6(ipHdr.Payload())
172+
icmpHdr.SetChecksum(0)
160173
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
161174
Header: icmpHdr,
162175
Src: ipHdr.SourceAddressSlice(),

ping/socket_linux_unprivileged.go

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
package ping
2+
3+
import (
4+
"context"
5+
"net"
6+
"net/netip"
7+
"os"
8+
"time"
9+
10+
"github.com/sagernet/sing-tun/internal/gtcpip/checksum"
11+
"github.com/sagernet/sing-tun/internal/gtcpip/header"
12+
"github.com/sagernet/sing/common/atomic"
13+
"github.com/sagernet/sing/common/buf"
14+
"github.com/sagernet/sing/common/control"
15+
M "github.com/sagernet/sing/common/metadata"
16+
)
17+
18+
type UnprivilegedConn struct {
19+
ctx context.Context
20+
cancel context.CancelFunc
21+
controlFunc control.Func
22+
destination netip.Addr
23+
receiveChan chan *unprivilegedResponse
24+
readDeadline atomic.TypedValue[time.Time]
25+
writeDeadline atomic.TypedValue[time.Time]
26+
}
27+
28+
type unprivilegedResponse struct {
29+
Buffer *buf.Buffer
30+
Cmsg *buf.Buffer
31+
Addr netip.Addr
32+
}
33+
34+
func newUnprivilegedConn(ctx context.Context, controlFunc control.Func, destination netip.Addr) (net.Conn, error) {
35+
conn, err := connect(false, controlFunc, destination)
36+
if err != nil {
37+
return nil, err
38+
}
39+
conn.Close()
40+
ctx, cancel := context.WithCancel(ctx)
41+
return &UnprivilegedConn{
42+
ctx: ctx,
43+
cancel: cancel,
44+
controlFunc: controlFunc,
45+
destination: destination,
46+
receiveChan: make(chan *unprivilegedResponse),
47+
}, nil
48+
}
49+
50+
func (c *UnprivilegedConn) Read(b []byte) (n int, err error) {
51+
select {
52+
case packet := <-c.receiveChan:
53+
n = copy(b, packet.Buffer.Bytes())
54+
packet.Buffer.Release()
55+
packet.Cmsg.Release()
56+
return
57+
case <-c.ctx.Done():
58+
return 0, os.ErrClosed
59+
}
60+
}
61+
62+
func (c *UnprivilegedConn) ReadMsg(b []byte, oob []byte) (n, oobn int, addr netip.Addr, err error) {
63+
select {
64+
case packet := <-c.receiveChan:
65+
n = copy(b, packet.Buffer.Bytes())
66+
oobn = copy(oob, packet.Cmsg.Bytes())
67+
addr = packet.Addr
68+
packet.Buffer.Release()
69+
packet.Cmsg.Release()
70+
return
71+
case <-c.ctx.Done():
72+
return 0, 0, netip.Addr{}, os.ErrClosed
73+
}
74+
}
75+
76+
func (c *UnprivilegedConn) Write(b []byte) (n int, err error) {
77+
conn, err := connect(false, c.controlFunc, c.destination)
78+
if err != nil {
79+
return
80+
}
81+
var identifier uint16
82+
if !c.destination.Is6() {
83+
icmpHdr := header.ICMPv4(b)
84+
identifier = icmpHdr.Ident()
85+
} else {
86+
icmpHdr := header.ICMPv6(b)
87+
identifier = icmpHdr.Ident()
88+
}
89+
if readDeadline := c.readDeadline.Load(); !readDeadline.IsZero() {
90+
conn.SetReadDeadline(readDeadline)
91+
}
92+
if writeDeadline := c.writeDeadline.Load(); !writeDeadline.IsZero() {
93+
conn.SetWriteDeadline(writeDeadline)
94+
}
95+
n, err = conn.Write(b)
96+
if err != nil {
97+
conn.Close()
98+
return
99+
}
100+
go c.fetchResponse(conn, identifier)
101+
return
102+
}
103+
104+
func (c *UnprivilegedConn) fetchResponse(conn net.Conn, identifier uint16) {
105+
done := make(chan struct{})
106+
defer close(done)
107+
go func() {
108+
select {
109+
case <-c.ctx.Done():
110+
case <-done:
111+
}
112+
conn.Close()
113+
}()
114+
buffer := buf.NewPacket()
115+
cmsgBuffer := buf.NewSize(1024)
116+
n, oobN, _, addr, err := conn.(*net.UDPConn).ReadMsgUDPAddrPort(buffer.FreeBytes(), cmsgBuffer.FreeBytes())
117+
if err != nil {
118+
buffer.Release()
119+
cmsgBuffer.Release()
120+
return
121+
}
122+
buffer.Truncate(n)
123+
cmsgBuffer.Truncate(oobN)
124+
if !c.destination.Is6() {
125+
icmpHdr := header.ICMPv4(buffer.Bytes())
126+
icmpHdr.SetIdent(identifier)
127+
icmpHdr.SetChecksum(0)
128+
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0)))
129+
} else {
130+
icmpHdr := header.ICMPv6(buffer.Bytes())
131+
icmpHdr.SetIdent(identifier)
132+
// offload checksum here since we don't have source address here
133+
}
134+
select {
135+
case c.receiveChan <- &unprivilegedResponse{
136+
Buffer: buffer,
137+
Cmsg: cmsgBuffer,
138+
Addr: addr.Addr(),
139+
}:
140+
case <-c.ctx.Done():
141+
buffer.Release()
142+
cmsgBuffer.Release()
143+
}
144+
}
145+
146+
func (c *UnprivilegedConn) Close() error {
147+
c.cancel()
148+
return nil
149+
}
150+
151+
func (c *UnprivilegedConn) LocalAddr() net.Addr {
152+
return M.Socksaddr{}
153+
}
154+
155+
func (c *UnprivilegedConn) RemoteAddr() net.Addr {
156+
return M.SocksaddrFrom(c.destination, 0).UDPAddr()
157+
}
158+
159+
func (c *UnprivilegedConn) SetDeadline(t time.Time) error {
160+
c.readDeadline.Store(t)
161+
c.writeDeadline.Store(t)
162+
return nil
163+
}
164+
165+
func (c *UnprivilegedConn) SetReadDeadline(t time.Time) error {
166+
c.readDeadline.Store(t)
167+
return nil
168+
}
169+
170+
func (c *UnprivilegedConn) SetWriteDeadline(t time.Time) error {
171+
c.writeDeadline.Store(t)
172+
return nil
173+
}

ping/socket_unix.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"github.com/sagernet/sing/common/control"
1313
E "github.com/sagernet/sing/common/exceptions"
1414
M "github.com/sagernet/sing/common/metadata"
15+
1516
"golang.org/x/sys/unix"
1617
)
1718

@@ -77,7 +78,7 @@ func connect(privileged bool, controlFunc control.Func, destination netip.Addr)
7778
if err != nil {
7879
return nil, E.Cause(err, "connect()")
7980
}
80-
81+
8182
conn, err := net.FileConn(file)
8283
if err != nil {
8384
return nil, err

0 commit comments

Comments
 (0)