Skip to content

Commit 341e56b

Browse files
committed
windows: prefer AF_UNIX socketpair for fallback
1 parent fe7ab10 commit 341e56b

File tree

2 files changed

+265
-4
lines changed

2 files changed

+265
-4
lines changed

naive_client_fd_windows.go

Lines changed: 262 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,30 @@
33
package cronet
44

55
import (
6+
"crypto/rand"
7+
"encoding/hex"
68
"net"
9+
"os"
10+
"path/filepath"
11+
"strconv"
12+
"sync"
713
"syscall"
14+
"time"
15+
"unsafe"
816

917
E "github.com/sagernet/sing/common/exceptions"
1018

1119
"golang.org/x/sys/windows"
1220
)
1321

22+
var winsockSystemLibrary = windows.NewLazySystemDLL("ws2_32.dll")
23+
24+
var (
25+
winsockProcAccept = winsockSystemLibrary.NewProc("accept")
26+
winsockProcRecv = winsockSystemLibrary.NewProc("recv")
27+
winsockProcSend = winsockSystemLibrary.NewProc("send")
28+
)
29+
1430
// dupSocketFD extracts and duplicates the socket handle from a syscall.Conn.
1531
// Returns a new independent handle; the caller should close both the returned handle (after use)
1632
// and the original connection (immediately after this call).
@@ -41,7 +57,7 @@ func dupSocketFD(syscallConn syscall.Conn) (int, error) {
4157
if err != nil {
4258
// Close duplicated handle if Control itself fails after DuplicateHandle succeeded
4359
if socket != 0 {
44-
windows.CloseHandle(socket)
60+
_ = windows.Closesocket(socket)
4561
}
4662
return -1, E.Cause(err, "control raw conn")
4763
}
@@ -51,12 +67,257 @@ func dupSocketFD(syscallConn syscall.Conn) (int, error) {
5167
return int(socket), nil
5268
}
5369

70+
type socketTimeoutError struct{}
71+
72+
func (socketTimeoutError) Error() string { return "i/o timeout" }
73+
func (socketTimeoutError) Timeout() bool { return true }
74+
func (socketTimeoutError) Temporary() bool { return true }
75+
76+
type unixSocketAddress struct {
77+
name string
78+
}
79+
80+
func (a unixSocketAddress) Network() string { return "unix" }
81+
func (a unixSocketAddress) String() string { return a.name }
82+
83+
type winsockStreamConn struct {
84+
socketHandle windows.Handle
85+
closeOnce sync.Once
86+
}
87+
88+
func newWinsockStreamConn(socketHandle windows.Handle) *winsockStreamConn {
89+
return &winsockStreamConn{socketHandle: socketHandle}
90+
}
91+
92+
func (c *winsockStreamConn) Read(buffer []byte) (int, error) {
93+
if len(buffer) == 0 {
94+
return 0, nil
95+
}
96+
r1, _, err := winsockProcRecv.Call(
97+
uintptr(c.socketHandle),
98+
uintptr(unsafe.Pointer(&buffer[0])),
99+
uintptr(len(buffer)),
100+
0,
101+
)
102+
n := int(r1)
103+
if n != -1 {
104+
return n, nil
105+
}
106+
if isWinsockTimeout(err) {
107+
return 0, socketTimeoutError{}
108+
}
109+
return 0, err
110+
}
111+
112+
func (c *winsockStreamConn) Write(buffer []byte) (int, error) {
113+
if len(buffer) == 0 {
114+
return 0, nil
115+
}
116+
r1, _, err := winsockProcSend.Call(
117+
uintptr(c.socketHandle),
118+
uintptr(unsafe.Pointer(&buffer[0])),
119+
uintptr(len(buffer)),
120+
0,
121+
)
122+
n := int(r1)
123+
if n != -1 {
124+
return n, nil
125+
}
126+
if isWinsockTimeout(err) {
127+
return 0, socketTimeoutError{}
128+
}
129+
return 0, err
130+
}
131+
132+
func (c *winsockStreamConn) Close() error {
133+
var closeError error
134+
c.closeOnce.Do(func() {
135+
closeError = windows.Closesocket(c.socketHandle)
136+
})
137+
return closeError
138+
}
139+
140+
func (c *winsockStreamConn) LocalAddr() net.Addr {
141+
return unixSocketAddress{name: "winsock-unix-local"}
142+
}
143+
144+
func (c *winsockStreamConn) RemoteAddr() net.Addr {
145+
return unixSocketAddress{name: "winsock-unix-remote"}
146+
}
147+
148+
func (c *winsockStreamConn) SetDeadline(deadline time.Time) error {
149+
readError := c.SetReadDeadline(deadline)
150+
writeError := c.SetWriteDeadline(deadline)
151+
if readError != nil {
152+
return readError
153+
}
154+
return writeError
155+
}
156+
157+
func (c *winsockStreamConn) SetReadDeadline(deadline time.Time) error {
158+
return setSocketTimeout(c.socketHandle, windows.SO_RCVTIMEO, deadline)
159+
}
160+
161+
func (c *winsockStreamConn) SetWriteDeadline(deadline time.Time) error {
162+
return setSocketTimeout(c.socketHandle, winsockSO_SNDTIMEO, deadline)
163+
}
164+
165+
func (c *winsockStreamConn) CloseWrite() error {
166+
return windows.Shutdown(c.socketHandle, windows.SHUT_WR)
167+
}
168+
169+
const winsockSO_SNDTIMEO = 0x1005
170+
171+
func setSocketTimeout(socketHandle windows.Handle, option int, deadline time.Time) error {
172+
timeoutMilliseconds := 0
173+
if !deadline.IsZero() {
174+
timeout := time.Until(deadline)
175+
if timeout <= 0 {
176+
timeoutMilliseconds = 1
177+
} else {
178+
timeoutMilliseconds = int(timeout / time.Millisecond)
179+
if timeoutMilliseconds <= 0 {
180+
timeoutMilliseconds = 1
181+
}
182+
}
183+
}
184+
return windows.SetsockoptInt(socketHandle, windows.SOL_SOCKET, option, timeoutMilliseconds)
185+
}
186+
187+
func isWinsockTimeout(err error) bool {
188+
winsockError, ok := err.(syscall.Errno)
189+
if !ok {
190+
return false
191+
}
192+
return winsockError == windows.WSAETIMEDOUT || winsockError == windows.WSAEWOULDBLOCK
193+
}
194+
195+
func createUnixSocketPair() (cronetSocket windows.Handle, proxySocket windows.Handle, err error) {
196+
socketSuffix, err := randomHexString(8)
197+
if err != nil {
198+
return 0, 0, err
199+
}
200+
socketBaseName := "cronet-go-" + strconv.Itoa(os.Getpid()) + "-" + socketSuffix + ".sock"
201+
202+
candidates := []string{
203+
"@" + socketBaseName,
204+
}
205+
206+
temporaryPathCandidate := filepath.Join(os.TempDir(), socketBaseName)
207+
if len(temporaryPathCandidate) < windows.UNIX_PATH_MAX {
208+
candidates = append(candidates, temporaryPathCandidate)
209+
}
210+
211+
var lastError error
212+
for _, name := range candidates {
213+
cronetSocket, proxySocket, lastError = createUnixSocketPairWithName(name)
214+
if lastError == nil {
215+
return cronetSocket, proxySocket, nil
216+
}
217+
}
218+
return 0, 0, lastError
219+
}
220+
221+
func createUnixSocketPairWithName(name string) (cronetSocket windows.Handle, proxySocket windows.Handle, err error) {
222+
if name != "" && name[0] != '@' {
223+
_ = os.Remove(name)
224+
}
225+
226+
listenerSocket, err := windows.Socket(windows.AF_UNIX, windows.SOCK_STREAM, 0)
227+
if err != nil {
228+
return 0, 0, err
229+
}
230+
listenerClosed := false
231+
closeListenerSocket := func() {
232+
if listenerClosed {
233+
return
234+
}
235+
listenerClosed = true
236+
_ = windows.Closesocket(listenerSocket)
237+
}
238+
defer closeListenerSocket()
239+
240+
listenerAddress := &windows.SockaddrUnix{Name: name}
241+
err = windows.Bind(listenerSocket, listenerAddress)
242+
if err != nil {
243+
return 0, 0, err
244+
}
245+
err = windows.Listen(listenerSocket, 1)
246+
if err != nil {
247+
return 0, 0, err
248+
}
249+
250+
clientSocket, err := windows.Socket(windows.AF_UNIX, windows.SOCK_STREAM, 0)
251+
if err != nil {
252+
return 0, 0, err
253+
}
254+
clientClosed := false
255+
closeClientSocket := func() {
256+
if clientClosed {
257+
return
258+
}
259+
clientClosed = true
260+
_ = windows.Closesocket(clientSocket)
261+
}
262+
263+
acceptedDone := make(chan struct{})
264+
var acceptedSocket windows.Handle
265+
var acceptError error
266+
go func() {
267+
defer close(acceptedDone)
268+
r1, _, callError := winsockProcAccept.Call(uintptr(listenerSocket), 0, 0)
269+
if uintptr(r1) == uintptr(^uintptr(0)) {
270+
acceptError = callError
271+
return
272+
}
273+
acceptedSocket = windows.Handle(r1)
274+
}()
275+
276+
connectError := windows.Connect(clientSocket, listenerAddress)
277+
if connectError != nil {
278+
closeListenerSocket()
279+
closeClientSocket()
280+
<-acceptedDone
281+
if acceptedSocket != 0 {
282+
_ = windows.Closesocket(acceptedSocket)
283+
}
284+
return 0, 0, connectError
285+
}
286+
287+
<-acceptedDone
288+
if acceptError != nil {
289+
closeClientSocket()
290+
return 0, 0, acceptError
291+
}
292+
293+
closeListenerSocket()
294+
if name != "" && name[0] != '@' {
295+
_ = os.Remove(name)
296+
}
297+
298+
return acceptedSocket, clientSocket, nil
299+
}
300+
301+
func randomHexString(byteCount int) (string, error) {
302+
randomBytes := make([]byte, byteCount)
303+
_, err := rand.Read(randomBytes)
304+
if err != nil {
305+
return "", err
306+
}
307+
return hex.EncodeToString(randomBytes), nil
308+
}
309+
54310
// createSocketPair creates a bidirectional socket pair using TCP loopback.
55311
// This approach is necessary on Windows because Named Pipes are not compatible
56312
// with Cronet's socket layer (ioctlsocket fails on pipe handles).
57313
// Returns the cronet-side socket handle and a net.Conn for the proxy side.
58314
// The caller is responsible for closing both the handle and the connection.
59315
func createSocketPair() (cronetFD int, proxyConn net.Conn, err error) {
316+
cronetSocket, proxySocket, err := createUnixSocketPair()
317+
if err == nil {
318+
return int(cronetSocket), newWinsockStreamConn(proxySocket), nil
319+
}
320+
60321
// Listen on random port on loopback interface
61322
listener, err := net.Listen("tcp", "127.0.0.1:0")
62323
if err != nil {

socket_fd_windows_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,15 +89,15 @@ func TestDupSocketFD(t *testing.T) {
8989
}
9090

9191
// Clean up
92-
windows.CloseHandle(windows.Handle(fd))
92+
_ = windows.Closesocket(windows.Handle(fd))
9393
}
9494

9595
func TestCreateSocketPair(t *testing.T) {
9696
fd, conn, err := createSocketPair()
9797
if err != nil {
9898
t.Fatal(err)
9999
}
100-
defer windows.CloseHandle(windows.Handle(fd))
100+
defer windows.Closesocket(windows.Handle(fd))
101101
defer conn.Close()
102102

103103
if fd < 0 {
@@ -160,7 +160,7 @@ func TestCreateSocketPair_MultipleCreation(t *testing.T) {
160160

161161
// Clean up
162162
for _, pair := range pairs {
163-
windows.CloseHandle(windows.Handle(pair.fd))
163+
_ = windows.Closesocket(windows.Handle(pair.fd))
164164
pair.conn.Close()
165165
}
166166
}

0 commit comments

Comments
 (0)