@@ -5,14 +5,16 @@ import (
55 "net"
66 "net/netip"
77 "os"
8+ "sync"
89 "time"
910
1011 "github.com/metacubex/sing-tun/internal/gtcpip/checksum"
1112 "github.com/metacubex/sing-tun/internal/gtcpip/header"
12- "github.com/metacubex/sing/common/atomic "
13+ "github.com/metacubex/sing/common"
1314 "github.com/metacubex/sing/common/buf"
1415 "github.com/metacubex/sing/common/control"
1516 M "github.com/metacubex/sing/common/metadata"
17+ "github.com/metacubex/sing/common/pipe"
1618)
1719
1820type UnprivilegedConn struct {
@@ -21,7 +23,9 @@ type UnprivilegedConn struct {
2123 controlFunc control.Func
2224 destination netip.Addr
2325 receiveChan chan * unprivilegedResponse
24- readDeadline atomic.TypedValue [time.Time ]
26+ readDeadline pipe.Deadline
27+ natMap map [uint16 ]net.Conn
28+ natMapMutex sync.Mutex
2529}
2630
2731type unprivilegedResponse struct {
@@ -38,11 +42,13 @@ func newUnprivilegedConn(ctx context.Context, controlFunc control.Func, destinat
3842 conn .Close ()
3943 ctx , cancel := context .WithCancel (ctx )
4044 return & UnprivilegedConn {
41- ctx : ctx ,
42- cancel : cancel ,
43- controlFunc : controlFunc ,
44- destination : destination ,
45- receiveChan : make (chan * unprivilegedResponse ),
45+ ctx : ctx ,
46+ cancel : cancel ,
47+ controlFunc : controlFunc ,
48+ destination : destination ,
49+ receiveChan : make (chan * unprivilegedResponse ),
50+ readDeadline : pipe .MakeDeadline (),
51+ natMap : make (map [uint16 ]net.Conn ),
4652 }, nil
4753}
4854
@@ -55,6 +61,8 @@ func (c *UnprivilegedConn) Read(b []byte) (n int, err error) {
5561 return
5662 case <- c .ctx .Done ():
5763 return 0 , os .ErrClosed
64+ case <- c .readDeadline .Wait ():
65+ return 0 , os .ErrDeadlineExceeded
5866 }
5967}
6068
@@ -69,14 +77,12 @@ func (c *UnprivilegedConn) ReadMsg(b []byte, oob []byte) (n, oobn int, addr neti
6977 return
7078 case <- c .ctx .Done ():
7179 return 0 , 0 , netip.Addr {}, os .ErrClosed
80+ case <- c .readDeadline .Wait ():
81+ return 0 , 0 , netip.Addr {}, os .ErrDeadlineExceeded
7282 }
7383}
7484
7585func (c * UnprivilegedConn ) Write (b []byte ) (n int , err error ) {
76- conn , err := connect (false , c .controlFunc , c .destination )
77- if err != nil {
78- return
79- }
8086 var identifier uint16
8187 if ! c .destination .Is6 () {
8288 icmpHdr := header .ICMPv4 (b )
@@ -85,62 +91,85 @@ func (c *UnprivilegedConn) Write(b []byte) (n int, err error) {
8591 icmpHdr := header .ICMPv6 (b )
8692 identifier = icmpHdr .Ident ()
8793 }
88- if readDeadline := c .readDeadline .Load (); ! readDeadline .IsZero () {
89- conn .SetReadDeadline (readDeadline )
94+
95+ c .natMapMutex .Lock ()
96+ if err = c .ctx .Err (); err != nil {
97+ c .natMapMutex .Unlock ()
98+ return 0 , err
99+ }
100+ conn , ok := c .natMap [identifier ]
101+ if ! ok {
102+ conn , err = connect (false , c .controlFunc , c .destination )
103+ if err != nil {
104+ c .natMapMutex .Unlock ()
105+ return 0 , err
106+ }
107+ go c .fetchResponse (conn .(* net.UDPConn ), identifier )
90108 }
109+ c .natMapMutex .Unlock ()
110+
91111 n , err = conn .Write (b )
92112 if err != nil {
93- conn .Close ( )
113+ c . removeConn ( conn .( * net. UDPConn ), identifier )
94114 return
95115 }
96- go c .fetchResponse (conn , identifier )
97116 return
98117}
99118
100- func (c * UnprivilegedConn ) fetchResponse (conn net.Conn , identifier uint16 ) {
101- done := make (chan struct {})
102- defer close (done )
103- go func () {
119+ func (c * UnprivilegedConn ) fetchResponse (conn * net.UDPConn , identifier uint16 ) {
120+ defer c .removeConn (conn , identifier )
121+ for {
122+ buffer := buf .NewPacket ()
123+ cmsgBuffer := buf .NewSize (1024 )
124+ n , oobN , _ , addr , err := conn .ReadMsgUDPAddrPort (buffer .FreeBytes (), cmsgBuffer .FreeBytes ())
125+ if err != nil {
126+ buffer .Release ()
127+ cmsgBuffer .Release ()
128+ return
129+ }
130+ buffer .Truncate (n )
131+ cmsgBuffer .Truncate (oobN )
132+ if ! c .destination .Is6 () {
133+ icmpHdr := header .ICMPv4 (buffer .Bytes ())
134+ icmpHdr .SetIdent (identifier )
135+ icmpHdr .SetChecksum (0 )
136+ icmpHdr .SetChecksum (header .ICMPv4Checksum (icmpHdr [:header .ICMPv4MinimumSize ], checksum .Checksum (icmpHdr .Payload (), 0 )))
137+ } else {
138+ icmpHdr := header .ICMPv6 (buffer .Bytes ())
139+ icmpHdr .SetIdent (identifier )
140+ // offload checksum here since we don't have source address here
141+ }
104142 select {
143+ case c .receiveChan <- & unprivilegedResponse {
144+ Buffer : buffer ,
145+ Cmsg : cmsgBuffer ,
146+ Addr : addr .Addr (),
147+ }:
105148 case <- c .ctx .Done ():
106- case <- done :
149+ buffer .Release ()
150+ cmsgBuffer .Release ()
151+ return
107152 }
108- conn .Close ()
109- }()
110- buffer := buf .NewPacket ()
111- cmsgBuffer := buf .NewSize (1024 )
112- n , oobN , _ , addr , err := conn .(* net.UDPConn ).ReadMsgUDPAddrPort (buffer .FreeBytes (), cmsgBuffer .FreeBytes ())
113- if err != nil {
114- buffer .Release ()
115- cmsgBuffer .Release ()
116- return
117153 }
118- buffer .Truncate (n )
119- cmsgBuffer .Truncate (oobN )
120- if ! c .destination .Is6 () {
121- icmpHdr := header .ICMPv4 (buffer .Bytes ())
122- icmpHdr .SetIdent (identifier )
123- icmpHdr .SetChecksum (0 )
124- icmpHdr .SetChecksum (header .ICMPv4Checksum (icmpHdr [:header .ICMPv4MinimumSize ], checksum .Checksum (icmpHdr .Payload (), 0 )))
125- } else {
126- icmpHdr := header .ICMPv6 (buffer .Bytes ())
127- icmpHdr .SetIdent (identifier )
128- // offload checksum here since we don't have source address here
129- }
130- select {
131- case c .receiveChan <- & unprivilegedResponse {
132- Buffer : buffer ,
133- Cmsg : cmsgBuffer ,
134- Addr : addr .Addr (),
135- }:
136- case <- c .ctx .Done ():
137- buffer .Release ()
138- cmsgBuffer .Release ()
154+ }
155+
156+ func (c * UnprivilegedConn ) removeConn (conn * net.UDPConn , identifier uint16 ) {
157+ c .natMapMutex .Lock ()
158+ _ = conn .Close ()
159+ if c .natMap [identifier ] == conn {
160+ delete (c .natMap , identifier )
139161 }
162+ c .natMapMutex .Unlock ()
140163}
141164
142165func (c * UnprivilegedConn ) Close () error {
166+ c .natMapMutex .Lock ()
143167 c .cancel ()
168+ for _ , conn := range c .natMap {
169+ _ = conn .Close ()
170+ }
171+ common .ClearMap (c .natMap )
172+ c .natMapMutex .Unlock ()
144173 return nil
145174}
146175
@@ -153,14 +182,14 @@ func (c *UnprivilegedConn) RemoteAddr() net.Addr {
153182}
154183
155184func (c * UnprivilegedConn ) SetDeadline (t time.Time ) error {
156- return os . ErrInvalid
185+ return c . SetReadDeadline ( t )
157186}
158187
159188func (c * UnprivilegedConn ) SetReadDeadline (t time.Time ) error {
160- c .readDeadline .Store (t )
189+ c .readDeadline .Set (t )
161190 return nil
162191}
163192
164193func (c * UnprivilegedConn ) SetWriteDeadline (t time.Time ) error {
165- return os . ErrInvalid
194+ return nil
166195}
0 commit comments