Skip to content

Commit 69d1af6

Browse files
committed
ping: Fix unprivileged response on linux
1 parent 3faf8cf commit 69d1af6

File tree

3 files changed

+191
-8
lines changed

3 files changed

+191
-8
lines changed

ping/ping.go

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ type Conn struct {
3232
}
3333

3434
func Connect(ctx context.Context, logger logger.ContextLogger, privileged bool, controlFunc control.Func, destination netip.Addr) (*Conn, error) {
35-
conn, err := connect(privileged, controlFunc, destination)
35+
conn, err := connect(ctx, privileged, controlFunc, destination)
3636
if err != nil {
3737
return nil, err
3838
}
@@ -53,20 +53,22 @@ func (c *Conn) ReadIP(buffer *buf.Buffer) error {
5353
readMsg = func(b, oob []byte) (n, oobn int, addr netip.Addr, err error) {
5454
var ipAddr *net.IPAddr
5555
n, oobn, _, ipAddr, err = conn.ReadMsgIP(b, oob)
56-
if ipAddr != nil {
56+
if err == nil {
5757
addr = M.AddrFromNet(ipAddr)
5858
}
5959
return
6060
}
6161
case *net.UDPConn:
6262
readMsg = func(b, oob []byte) (n, oobn int, addr netip.Addr, err error) {
63-
var udpAddr *net.UDPAddr
64-
n, oobn, _, udpAddr, err = conn.ReadMsgUDP(b, oob)
65-
if udpAddr != nil {
66-
addr = M.AddrFromNet(udpAddr)
63+
var addrPort netip.AddrPort
64+
n, oobn, _, addrPort, err = conn.ReadMsgUDPAddrPort(b, oob)
65+
if err == nil {
66+
addr = addrPort.Addr()
6767
}
6868
return
6969
}
70+
case *UnprivilegedConn:
71+
readMsg = conn.ReadMsg
7072
default:
7173
return E.New("unsupported conn type: ", reflect.TypeOf(c.conn))
7274
}
@@ -124,6 +126,7 @@ func (c *Conn) ReadIP(buffer *buf.Buffer) error {
124126
trafficClass = controlMessage.TrafficClass
125127
}
126128
icmpHdr := header.ICMPv6(buffer.Bytes())
129+
icmpHdr.SetChecksum(0)
127130
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
128131
Header: icmpHdr[:header.ICMPv6DstUnreachableMinimumSize],
129132
Src: addr.AsSlice(),
@@ -151,12 +154,14 @@ func (c *Conn) ReadIP(buffer *buf.Buffer) error {
151154
ipHdr.SetChecksum(0)
152155
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
153156
icmpHdr := header.ICMPv4(ipHdr.Payload())
157+
icmpHdr.SetChecksum(0)
154158
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0)))
155159
c.logger.TraceContext(c.ctx, "read icmpv4 echo reply from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr())
156160
} else {
157161
ipHdr := header.IPv6(buffer.Bytes())
158162
ipHdr.SetDestinationAddr(c.source.Load())
159163
icmpHdr := header.ICMPv6(ipHdr.Payload())
164+
icmpHdr.SetChecksum(0)
160165
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
161166
Header: icmpHdr,
162167
Src: ipHdr.SourceAddressSlice(),

ping/socket_linux_unprivileged.go

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
package ping
2+
3+
import (
4+
"context"
5+
"net"
6+
"net/netip"
7+
"os"
8+
"time"
9+
10+
"github.com/sagernet/sing-tun/internal/gtcpip/checksum"
11+
"github.com/sagernet/sing-tun/internal/gtcpip/header"
12+
"github.com/sagernet/sing/common/atomic"
13+
"github.com/sagernet/sing/common/buf"
14+
"github.com/sagernet/sing/common/control"
15+
M "github.com/sagernet/sing/common/metadata"
16+
)
17+
18+
type UnprivilegedConn struct {
19+
ctx context.Context
20+
cancel context.CancelFunc
21+
controlFunc control.Func
22+
destination netip.Addr
23+
receiveChan chan *unprivilegedResponse
24+
readDeadline atomic.TypedValue[time.Time]
25+
writeDeadline atomic.TypedValue[time.Time]
26+
}
27+
28+
type unprivilegedResponse struct {
29+
Buffer *buf.Buffer
30+
Cmsg *buf.Buffer
31+
Addr netip.Addr
32+
}
33+
34+
func newUnprivilegedConn(ctx context.Context, controlFunc control.Func, destination netip.Addr) net.Conn {
35+
ctx, cancel := context.WithCancel(ctx)
36+
return &UnprivilegedConn{
37+
ctx: ctx,
38+
cancel: cancel,
39+
controlFunc: controlFunc,
40+
destination: destination,
41+
receiveChan: make(chan *unprivilegedResponse),
42+
}
43+
}
44+
45+
func (c *UnprivilegedConn) Read(b []byte) (n int, err error) {
46+
select {
47+
case packet := <-c.receiveChan:
48+
n = copy(b, packet.Buffer.Bytes())
49+
packet.Buffer.Release()
50+
packet.Cmsg.Release()
51+
return
52+
case <-c.ctx.Done():
53+
return 0, os.ErrClosed
54+
}
55+
}
56+
57+
func (c *UnprivilegedConn) ReadMsg(b []byte, oob []byte) (n, oobn int, addr netip.Addr, err error) {
58+
select {
59+
case packet := <-c.receiveChan:
60+
n = copy(b, packet.Buffer.Bytes())
61+
oobn = copy(oob, packet.Cmsg.Bytes())
62+
addr = packet.Addr
63+
packet.Buffer.Release()
64+
packet.Cmsg.Release()
65+
return
66+
case <-c.ctx.Done():
67+
return 0, 0, netip.Addr{}, os.ErrClosed
68+
}
69+
}
70+
71+
func (c *UnprivilegedConn) Write(b []byte) (n int, err error) {
72+
conn, err := connect0(false, c.controlFunc, c.destination)
73+
if err != nil {
74+
return
75+
}
76+
var identifier uint16
77+
if !c.destination.Is6() {
78+
icmpHdr := header.ICMPv4(b)
79+
identifier = icmpHdr.Ident()
80+
} else {
81+
icmpHdr := header.ICMPv6(b)
82+
identifier = icmpHdr.Ident()
83+
}
84+
if readDeadline := c.readDeadline.Load(); !readDeadline.IsZero() {
85+
conn.SetReadDeadline(readDeadline)
86+
}
87+
if writeDeadline := c.writeDeadline.Load(); !writeDeadline.IsZero() {
88+
conn.SetWriteDeadline(writeDeadline)
89+
}
90+
n, err = conn.Write(b)
91+
if err != nil {
92+
conn.Close()
93+
return
94+
}
95+
go c.fetchResponse(conn, identifier)
96+
return
97+
}
98+
99+
func (c *UnprivilegedConn) fetchResponse(conn net.Conn, identifier uint16) {
100+
done := make(chan struct{})
101+
defer close(done)
102+
go func() {
103+
select {
104+
case <-c.ctx.Done():
105+
case <-done:
106+
}
107+
conn.Close()
108+
}()
109+
buffer := buf.NewPacket()
110+
cmsgBuffer := buf.NewSize(1024)
111+
n, oobN, _, addr, err := conn.(*net.UDPConn).ReadMsgUDPAddrPort(buffer.FreeBytes(), cmsgBuffer.FreeBytes())
112+
if err != nil {
113+
buffer.Release()
114+
cmsgBuffer.Release()
115+
return
116+
}
117+
buffer.Truncate(n)
118+
cmsgBuffer.Truncate(oobN)
119+
if !c.destination.Is6() {
120+
icmpHdr := header.ICMPv4(buffer.Bytes())
121+
icmpHdr.SetIdent(identifier)
122+
icmpHdr.SetChecksum(0)
123+
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0)))
124+
} else {
125+
icmpHdr := header.ICMPv6(buffer.Bytes())
126+
icmpHdr.SetIdent(identifier)
127+
// offload checksum here since we don't have source address here
128+
}
129+
select {
130+
case c.receiveChan <- &unprivilegedResponse{
131+
Buffer: buffer,
132+
Cmsg: cmsgBuffer,
133+
Addr: addr.Addr(),
134+
}:
135+
case <-c.ctx.Done():
136+
buffer.Release()
137+
cmsgBuffer.Release()
138+
}
139+
}
140+
141+
func (c *UnprivilegedConn) Close() error {
142+
c.cancel()
143+
return nil
144+
}
145+
146+
func (c *UnprivilegedConn) LocalAddr() net.Addr {
147+
return M.Socksaddr{}
148+
}
149+
150+
func (c *UnprivilegedConn) RemoteAddr() net.Addr {
151+
return M.SocksaddrFrom(c.destination, 0).UDPAddr()
152+
}
153+
154+
func (c *UnprivilegedConn) SetDeadline(t time.Time) error {
155+
c.readDeadline.Store(t)
156+
c.writeDeadline.Store(t)
157+
return nil
158+
}
159+
160+
func (c *UnprivilegedConn) SetReadDeadline(t time.Time) error {
161+
c.readDeadline.Store(t)
162+
return nil
163+
}
164+
165+
func (c *UnprivilegedConn) SetWriteDeadline(t time.Time) error {
166+
c.writeDeadline.Store(t)
167+
return nil
168+
}

ping/socket_unix.go

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
package ping
44

55
import (
6+
"context"
67
"net"
78
"net/netip"
89
"os"
@@ -12,10 +13,19 @@ import (
1213
"github.com/sagernet/sing/common/control"
1314
E "github.com/sagernet/sing/common/exceptions"
1415
M "github.com/sagernet/sing/common/metadata"
16+
1517
"golang.org/x/sys/unix"
1618
)
1719

18-
func connect(privileged bool, controlFunc control.Func, destination netip.Addr) (net.Conn, error) {
20+
func connect(ctx context.Context, privileged bool, controlFunc control.Func, destination netip.Addr) (net.Conn, error) {
21+
if (runtime.GOOS == "linux" || runtime.GOOS == "android") && !privileged {
22+
return newUnprivilegedConn(ctx, controlFunc, destination), nil
23+
} else {
24+
return connect0(privileged, controlFunc, destination)
25+
}
26+
}
27+
28+
func connect0(privileged bool, controlFunc control.Func, destination netip.Addr) (net.Conn, error) {
1929
var (
2030
network string
2131
fd int
@@ -77,7 +87,7 @@ func connect(privileged bool, controlFunc control.Func, destination netip.Addr)
7787
if err != nil {
7888
return nil, E.Cause(err, "connect()")
7989
}
80-
90+
8191
conn, err := net.FileConn(file)
8292
if err != nil {
8393
return nil, err

0 commit comments

Comments
 (0)