Skip to content

Commit 56824f1

Browse files
committed
ping: Rewrite UnprivilegedConn
1 parent c5f7371 commit 56824f1

File tree

1 file changed

+82
-53
lines changed

1 file changed

+82
-53
lines changed

ping/socket_linux_unprivileged.go

Lines changed: 82 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,16 @@ import (
55
"net"
66
"net/netip"
77
"os"
8+
"sync"
89
"time"
910

1011
"github.com/metacubex/sing-tun/internal/gtcpip/checksum"
1112
"github.com/metacubex/sing-tun/internal/gtcpip/header"
12-
"github.com/metacubex/sing/common/atomic"
13+
"github.com/metacubex/sing/common"
1314
"github.com/metacubex/sing/common/buf"
1415
"github.com/metacubex/sing/common/control"
1516
M "github.com/metacubex/sing/common/metadata"
17+
"github.com/metacubex/sing/common/pipe"
1618
)
1719

1820
type UnprivilegedConn struct {
@@ -21,7 +23,9 @@ type UnprivilegedConn struct {
2123
controlFunc control.Func
2224
destination netip.Addr
2325
receiveChan chan *unprivilegedResponse
24-
readDeadline atomic.TypedValue[time.Time]
26+
readDeadline pipe.Deadline
27+
natMap map[uint16]net.Conn
28+
natMapMutex sync.Mutex
2529
}
2630

2731
type unprivilegedResponse struct {
@@ -38,11 +42,13 @@ func newUnprivilegedConn(ctx context.Context, controlFunc control.Func, destinat
3842
conn.Close()
3943
ctx, cancel := context.WithCancel(ctx)
4044
return &UnprivilegedConn{
41-
ctx: ctx,
42-
cancel: cancel,
43-
controlFunc: controlFunc,
44-
destination: destination,
45-
receiveChan: make(chan *unprivilegedResponse),
45+
ctx: ctx,
46+
cancel: cancel,
47+
controlFunc: controlFunc,
48+
destination: destination,
49+
receiveChan: make(chan *unprivilegedResponse),
50+
readDeadline: pipe.MakeDeadline(),
51+
natMap: make(map[uint16]net.Conn),
4652
}, nil
4753
}
4854

@@ -55,6 +61,8 @@ func (c *UnprivilegedConn) Read(b []byte) (n int, err error) {
5561
return
5662
case <-c.ctx.Done():
5763
return 0, os.ErrClosed
64+
case <-c.readDeadline.Wait():
65+
return 0, os.ErrDeadlineExceeded
5866
}
5967
}
6068

@@ -69,14 +77,12 @@ func (c *UnprivilegedConn) ReadMsg(b []byte, oob []byte) (n, oobn int, addr neti
6977
return
7078
case <-c.ctx.Done():
7179
return 0, 0, netip.Addr{}, os.ErrClosed
80+
case <-c.readDeadline.Wait():
81+
return 0, 0, netip.Addr{}, os.ErrDeadlineExceeded
7282
}
7383
}
7484

7585
func (c *UnprivilegedConn) Write(b []byte) (n int, err error) {
76-
conn, err := connect(false, c.controlFunc, c.destination)
77-
if err != nil {
78-
return
79-
}
8086
var identifier uint16
8187
if !c.destination.Is6() {
8288
icmpHdr := header.ICMPv4(b)
@@ -85,62 +91,85 @@ func (c *UnprivilegedConn) Write(b []byte) (n int, err error) {
8591
icmpHdr := header.ICMPv6(b)
8692
identifier = icmpHdr.Ident()
8793
}
88-
if readDeadline := c.readDeadline.Load(); !readDeadline.IsZero() {
89-
conn.SetReadDeadline(readDeadline)
94+
95+
c.natMapMutex.Lock()
96+
if err = c.ctx.Err(); err != nil {
97+
c.natMapMutex.Unlock()
98+
return 0, err
99+
}
100+
conn, ok := c.natMap[identifier]
101+
if !ok {
102+
conn, err = connect(false, c.controlFunc, c.destination)
103+
if err != nil {
104+
c.natMapMutex.Unlock()
105+
return 0, err
106+
}
107+
go c.fetchResponse(conn.(*net.UDPConn), identifier)
90108
}
109+
c.natMapMutex.Unlock()
110+
91111
n, err = conn.Write(b)
92112
if err != nil {
93-
conn.Close()
113+
c.removeConn(conn.(*net.UDPConn), identifier)
94114
return
95115
}
96-
go c.fetchResponse(conn, identifier)
97116
return
98117
}
99118

100-
func (c *UnprivilegedConn) fetchResponse(conn net.Conn, identifier uint16) {
101-
done := make(chan struct{})
102-
defer close(done)
103-
go func() {
119+
func (c *UnprivilegedConn) fetchResponse(conn *net.UDPConn, identifier uint16) {
120+
defer c.removeConn(conn, identifier)
121+
for {
122+
buffer := buf.NewPacket()
123+
cmsgBuffer := buf.NewSize(1024)
124+
n, oobN, _, addr, err := conn.ReadMsgUDPAddrPort(buffer.FreeBytes(), cmsgBuffer.FreeBytes())
125+
if err != nil {
126+
buffer.Release()
127+
cmsgBuffer.Release()
128+
return
129+
}
130+
buffer.Truncate(n)
131+
cmsgBuffer.Truncate(oobN)
132+
if !c.destination.Is6() {
133+
icmpHdr := header.ICMPv4(buffer.Bytes())
134+
icmpHdr.SetIdent(identifier)
135+
icmpHdr.SetChecksum(0)
136+
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0)))
137+
} else {
138+
icmpHdr := header.ICMPv6(buffer.Bytes())
139+
icmpHdr.SetIdent(identifier)
140+
// offload checksum here since we don't have source address here
141+
}
104142
select {
143+
case c.receiveChan <- &unprivilegedResponse{
144+
Buffer: buffer,
145+
Cmsg: cmsgBuffer,
146+
Addr: addr.Addr(),
147+
}:
105148
case <-c.ctx.Done():
106-
case <-done:
149+
buffer.Release()
150+
cmsgBuffer.Release()
151+
return
107152
}
108-
conn.Close()
109-
}()
110-
buffer := buf.NewPacket()
111-
cmsgBuffer := buf.NewSize(1024)
112-
n, oobN, _, addr, err := conn.(*net.UDPConn).ReadMsgUDPAddrPort(buffer.FreeBytes(), cmsgBuffer.FreeBytes())
113-
if err != nil {
114-
buffer.Release()
115-
cmsgBuffer.Release()
116-
return
117153
}
118-
buffer.Truncate(n)
119-
cmsgBuffer.Truncate(oobN)
120-
if !c.destination.Is6() {
121-
icmpHdr := header.ICMPv4(buffer.Bytes())
122-
icmpHdr.SetIdent(identifier)
123-
icmpHdr.SetChecksum(0)
124-
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0)))
125-
} else {
126-
icmpHdr := header.ICMPv6(buffer.Bytes())
127-
icmpHdr.SetIdent(identifier)
128-
// offload checksum here since we don't have source address here
129-
}
130-
select {
131-
case c.receiveChan <- &unprivilegedResponse{
132-
Buffer: buffer,
133-
Cmsg: cmsgBuffer,
134-
Addr: addr.Addr(),
135-
}:
136-
case <-c.ctx.Done():
137-
buffer.Release()
138-
cmsgBuffer.Release()
154+
}
155+
156+
func (c *UnprivilegedConn) removeConn(conn *net.UDPConn, identifier uint16) {
157+
c.natMapMutex.Lock()
158+
_ = conn.Close()
159+
if c.natMap[identifier] == conn {
160+
delete(c.natMap, identifier)
139161
}
162+
c.natMapMutex.Unlock()
140163
}
141164

142165
func (c *UnprivilegedConn) Close() error {
166+
c.natMapMutex.Lock()
143167
c.cancel()
168+
for _, conn := range c.natMap {
169+
_ = conn.Close()
170+
}
171+
common.ClearMap(c.natMap)
172+
c.natMapMutex.Unlock()
144173
return nil
145174
}
146175

@@ -153,14 +182,14 @@ func (c *UnprivilegedConn) RemoteAddr() net.Addr {
153182
}
154183

155184
func (c *UnprivilegedConn) SetDeadline(t time.Time) error {
156-
return os.ErrInvalid
185+
return c.SetReadDeadline(t)
157186
}
158187

159188
func (c *UnprivilegedConn) SetReadDeadline(t time.Time) error {
160-
c.readDeadline.Store(t)
189+
c.readDeadline.Set(t)
161190
return nil
162191
}
163192

164193
func (c *UnprivilegedConn) SetWriteDeadline(t time.Time) error {
165-
return os.ErrInvalid
194+
return nil
166195
}

0 commit comments

Comments
 (0)