Skip to content

Commit 71bbc58

Browse files
committed
Add UDP dialer support
Add Engine.SetUDPDialer() API that allows custom UDP socket creation. The callback receives address/port and returns a socket fd along with local address information. Supported socket types: - AF_INET/AF_INET6 SOCK_DGRAM: Standard UDP socket - AF_UNIX SOCK_DGRAM: Unix domain datagram socket (Unix/macOS/Linux) - AF_UNIX SOCK_STREAM: Unix domain stream socket (Windows, with framing)
1 parent 341e56b commit 71bbc58

File tree

8 files changed

+136
-6
lines changed

8 files changed

+136
-6
lines changed

engine_cgo.go

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@ package cronet
44

55
// #include <stdlib.h>
66
// #include <stdbool.h>
7+
// #include <string.h>
78
// #include <cronet_c.h>
89
//
910
// extern CRONET_EXPORT int cronetDialerCallback(void* context, char* address, uint16_t port);
11+
// extern CRONET_EXPORT int cronetUdpDialerCallback(void* context, char* address, uint16_t port, char* out_local_address, uint16_t* out_local_port);
1012
import "C"
1113

1214
import (
@@ -15,8 +17,10 @@ import (
1517
)
1618

1719
var (
18-
dialerAccess sync.RWMutex
19-
dialerMap = make(map[uintptr]Dialer)
20+
dialerAccess sync.RWMutex
21+
dialerMap = make(map[uintptr]Dialer)
22+
udpDialerAccess sync.RWMutex
23+
udpDialerMap = make(map[uintptr]UDPDialer)
2024
)
2125

2226
//export cronetDialerCallback
@@ -30,6 +34,29 @@ func cronetDialerCallback(context unsafe.Pointer, address *C.char, port C.uint16
3034
return C.int(dialer(C.GoString(address), uint16(port)))
3135
}
3236

37+
//export cronetUdpDialerCallback
38+
func cronetUdpDialerCallback(context unsafe.Pointer, address *C.char, port C.uint16_t, outLocalAddress *C.char, outLocalPort *C.uint16_t) C.int {
39+
udpDialerAccess.RLock()
40+
dialer, ok := udpDialerMap[uintptr(context)]
41+
udpDialerAccess.RUnlock()
42+
if !ok {
43+
return -104 // ERR_CONNECTION_FAILED
44+
}
45+
fd, localAddress, localPort := dialer(C.GoString(address), uint16(port))
46+
47+
// Write output parameters
48+
if outLocalAddress != nil && localAddress != "" {
49+
localAddressC := C.CString(localAddress)
50+
C.strcpy(outLocalAddress, localAddressC)
51+
C.free(unsafe.Pointer(localAddressC))
52+
}
53+
if outLocalPort != nil {
54+
*outLocalPort = C.uint16_t(localPort)
55+
}
56+
57+
return C.int(fd)
58+
}
59+
3360
func NewEngine() Engine {
3461
return Engine{uintptr(unsafe.Pointer(C.Cronet_Engine_Create()))}
3562
}
@@ -38,6 +65,9 @@ func (e Engine) Destroy() {
3865
dialerAccess.Lock()
3966
delete(dialerMap, e.ptr)
4067
dialerAccess.Unlock()
68+
udpDialerAccess.Lock()
69+
delete(udpDialerMap, e.ptr)
70+
udpDialerAccess.Unlock()
4171
C.Cronet_Engine_Destroy(C.Cronet_EnginePtr(unsafe.Pointer(e.ptr)))
4272
}
4373

@@ -224,3 +254,26 @@ func (e Engine) SetDialer(dialer Dialer) {
224254
unsafe.Pointer(e.ptr),
225255
)
226256
}
257+
258+
// SetUDPDialer sets a custom dialer for UDP sockets.
259+
// When set, the engine will use this callback to create UDP sockets instead of
260+
// the default system socket API.
261+
// Must be called before StartWithParams().
262+
// Pass nil to disable custom dialing.
263+
func (e Engine) SetUDPDialer(dialer UDPDialer) {
264+
if dialer == nil {
265+
C.Cronet_Engine_SetUdpDialer(C.Cronet_EnginePtr(unsafe.Pointer(e.ptr)), nil, nil)
266+
udpDialerAccess.Lock()
267+
delete(udpDialerMap, e.ptr)
268+
udpDialerAccess.Unlock()
269+
return
270+
}
271+
udpDialerAccess.Lock()
272+
udpDialerMap[e.ptr] = dialer
273+
udpDialerAccess.Unlock()
274+
C.Cronet_Engine_SetUdpDialer(
275+
C.Cronet_EnginePtr(unsafe.Pointer(e.ptr)),
276+
(*[0]byte)(C.cronetUdpDialerCallback),
277+
unsafe.Pointer(e.ptr),
278+
)
279+
}

engine_purego.go

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,12 @@ import (
1212
)
1313

1414
var (
15-
dialerAccess sync.RWMutex
16-
dialerMap = make(map[uintptr]Dialer)
17-
dialerCallback uintptr
15+
dialerAccess sync.RWMutex
16+
dialerMap = make(map[uintptr]Dialer)
17+
udpDialerAccess sync.RWMutex
18+
udpDialerMap = make(map[uintptr]UDPDialer)
19+
dialerCallback uintptr
20+
udpDialerCallback uintptr
1821
)
1922

2023
func init() {
@@ -27,6 +30,31 @@ func init() {
2730
}
2831
return dialer(cronet.GoString(address), port)
2932
})
33+
34+
udpDialerCallback = purego.NewCallback(func(context uintptr, address uintptr, port uint16, outLocalAddress uintptr, outLocalPort uintptr) int {
35+
udpDialerAccess.RLock()
36+
dialer, ok := udpDialerMap[context]
37+
udpDialerAccess.RUnlock()
38+
if !ok {
39+
return -104 // ERR_CONNECTION_FAILED
40+
}
41+
fd, localAddress, localPort := dialer(cronet.GoString(address), port)
42+
43+
// Write output parameters using unsafe
44+
if outLocalAddress != 0 && localAddress != "" {
45+
localAddressBytes := []byte(localAddress)
46+
for i, b := range localAddressBytes {
47+
*(*byte)(unsafe.Add(unsafe.Pointer(outLocalAddress), i)) = b
48+
}
49+
// Null terminator
50+
*(*byte)(unsafe.Add(unsafe.Pointer(outLocalAddress), len(localAddressBytes))) = 0
51+
}
52+
if outLocalPort != 0 {
53+
*(*uint16)(unsafe.Pointer(outLocalPort)) = localPort
54+
}
55+
56+
return fd
57+
})
3058
}
3159

3260
func NewEngine() Engine {
@@ -37,6 +65,9 @@ func (e Engine) Destroy() {
3765
dialerAccess.Lock()
3866
delete(dialerMap, e.ptr)
3967
dialerAccess.Unlock()
68+
udpDialerAccess.Lock()
69+
delete(udpDialerMap, e.ptr)
70+
udpDialerAccess.Unlock()
4071
cronet.EngineDestroy(e.ptr)
4172
}
4273

@@ -199,3 +230,22 @@ func (e Engine) SetDialer(dialer Dialer) {
199230
dialerAccess.Unlock()
200231
cronet.EngineSetDialer(e.ptr, dialerCallback, e.ptr)
201232
}
233+
234+
// SetUDPDialer sets a custom dialer for UDP sockets.
235+
// When set, the engine will use this callback to create UDP sockets instead of
236+
// the default system socket API.
237+
// Must be called before StartWithParams().
238+
// Pass nil to disable custom dialing.
239+
func (e Engine) SetUDPDialer(dialer UDPDialer) {
240+
if dialer == nil {
241+
cronet.EngineSetUdpDialer(e.ptr, 0, 0)
242+
udpDialerAccess.Lock()
243+
delete(udpDialerMap, e.ptr)
244+
udpDialerAccess.Unlock()
245+
return
246+
}
247+
udpDialerAccess.Lock()
248+
udpDialerMap[e.ptr] = dialer
249+
udpDialerAccess.Unlock()
250+
cronet.EngineSetUdpDialer(e.ptr, udpDialerCallback, e.ptr)
251+
}

internal/cronet/api_purego.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,10 @@ func EngineSetDialer(engine, dialer, context uintptr) {
7373
cronetEngineSetDialer(engine, dialer, context)
7474
}
7575

76+
func EngineSetUdpDialer(engine, dialer, context uintptr) {
77+
cronetEngineSetUdpDialer(engine, dialer, context)
78+
}
79+
7680
func EngineGetStreamEngine(engine uintptr) uintptr {
7781
return cronetEngineGetStreamEngine(engine)
7882
}

internal/cronet/loader_unix.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,9 @@ func registerSymbols() error {
236236
if err := registerFunc(&cronetEngineSetDialer, "Cronet_Engine_SetDialer"); err != nil {
237237
return err
238238
}
239+
if err := registerFunc(&cronetEngineSetUdpDialer, "Cronet_Engine_SetUdpDialer"); err != nil {
240+
return err
241+
}
239242

240243
// EngineParams
241244
if err := registerFunc(&cronetEngineParamsCreate, "Cronet_EngineParams_Create"); err != nil {

internal/cronet/loader_windows.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,9 @@ func registerSymbols() error {
267267
if err := registerFunc(&cronetEngineSetDialer, "Cronet_Engine_SetDialer"); err != nil {
268268
return err
269269
}
270+
if err := registerFunc(&cronetEngineSetUdpDialer, "Cronet_Engine_SetUdpDialer"); err != nil {
271+
return err
272+
}
270273

271274
// EngineParams
272275
if err := registerFunc(&cronetEngineParamsCreate, "Cronet_EngineParams_Create"); err != nil {

internal/cronet/symbols_purego.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ var (
5454
cronetEngineGetStreamEngine func(uintptr) uintptr
5555
cronetEngineSetMockCertVerifierForTesting func(uintptr, uintptr)
5656
cronetEngineSetDialer func(uintptr, uintptr, uintptr)
57+
cronetEngineSetUdpDialer func(uintptr, uintptr, uintptr)
5758

5859
// EngineParams functions
5960
cronetEngineParamsCreate func() uintptr

types.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,3 +291,19 @@ const (
291291
// ERR_ADDRESS_UNREACHABLE (-109)
292292
// ERR_CONNECTION_TIMED_OUT (-118)
293293
type Dialer func(address string, port uint16) int
294+
295+
// UDPDialer is a callback function for custom UDP socket creation.
296+
// address: IP address string (e.g. "1.2.3.4" or "::1")
297+
// port: Port number
298+
// Returns:
299+
// - fd: socket fd on success, negative net error code on failure
300+
// - localAddress: local IP address string (may be empty)
301+
// - localPort: local port number
302+
//
303+
// The returned socket can be:
304+
// - AF_INET/AF_INET6 SOCK_DGRAM: Standard UDP socket (may be connected)
305+
// - AF_UNIX SOCK_DGRAM: Unix domain datagram socket (Unix/macOS/Linux)
306+
// - AF_UNIX SOCK_STREAM: Unix domain stream socket (Windows, with framing)
307+
//
308+
// Cronet will NOT call connect() on the returned socket.
309+
type UDPDialer func(address string, port uint16) (fd int, localAddress string, localPort uint16)

0 commit comments

Comments
 (0)