Skip to content

Commit 548e259

Browse files
committed
Improve custom dialer support
1 parent 46ef877 commit 548e259

20 files changed

+1424
-80
lines changed

bidirectional_conn.go

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ type BidirectionalConn struct {
2929
// Buffer safety: when Read/Write return due to close/done, Cronet may
3030
// still hold the buffer. These channels are closed by callbacks to signal
3131
// it's safe to return. sync.Once ensures close happens exactly once.
32-
readDone chan struct{}
33-
writeDone chan struct{}
32+
readDone chan struct{}
33+
writeDone chan struct{}
3434
readDoneOnce sync.Once
3535
writeDoneOnce sync.Once
3636
}
@@ -263,6 +263,27 @@ func (c *BidirectionalConn) WaitForHeaders() (map[string]string, error) {
263263
}
264264
}
265265

266+
func (c *BidirectionalConn) WaitForHeadersContext(ctx context.Context) (map[string]string, error) {
267+
select {
268+
case <-c.close:
269+
return nil, net.ErrClosed
270+
case <-c.done:
271+
return nil, net.ErrClosed
272+
default:
273+
}
274+
275+
select {
276+
case <-ctx.Done():
277+
return nil, ctx.Err()
278+
case <-c.handshake:
279+
return c.headers, nil
280+
case <-c.done:
281+
return nil, c.err
282+
case <-c.close:
283+
return nil, net.ErrClosed
284+
}
285+
}
286+
266287
type bidirectionalHandler struct {
267288
*BidirectionalConn
268289
readyOnce sync.Once

dialer_test.go

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
package cronet
2+
3+
import (
4+
"sync"
5+
"testing"
6+
)
7+
8+
func TestDialerMapCleanup(t *testing.T) {
9+
// Verify Engine.Destroy() properly cleans up dialerMap
10+
engine := NewEngine()
11+
12+
engine.SetDialer(func(address string, port uint16) int {
13+
return -104 // ERR_CONNECTION_FAILED
14+
})
15+
16+
dialerAccess.RLock()
17+
_, exists := dialerMap[engine.ptr]
18+
dialerAccess.RUnlock()
19+
20+
if !exists {
21+
t.Error("dialer not registered in dialerMap")
22+
}
23+
24+
engine.Destroy()
25+
26+
dialerAccess.RLock()
27+
_, exists = dialerMap[engine.ptr]
28+
dialerAccess.RUnlock()
29+
30+
if exists {
31+
t.Error("dialer not cleaned up after Engine.Destroy()")
32+
}
33+
}
34+
35+
func TestSetDialerNil(t *testing.T) {
36+
engine := NewEngine()
37+
defer engine.Destroy()
38+
39+
// First set a dialer
40+
engine.SetDialer(func(address string, port uint16) int {
41+
return -104
42+
})
43+
44+
dialerAccess.RLock()
45+
_, exists := dialerMap[engine.ptr]
46+
dialerAccess.RUnlock()
47+
48+
if !exists {
49+
t.Error("dialer not registered")
50+
}
51+
52+
// Then set it to nil
53+
engine.SetDialer(nil)
54+
55+
dialerAccess.RLock()
56+
_, exists = dialerMap[engine.ptr]
57+
dialerAccess.RUnlock()
58+
59+
if exists {
60+
t.Error("dialer not removed after SetDialer(nil)")
61+
}
62+
}
63+
64+
func TestSetDialerOverwrite(t *testing.T) {
65+
engine := NewEngine()
66+
defer engine.Destroy()
67+
68+
callCount1 := 0
69+
callCount2 := 0
70+
71+
// Set first dialer
72+
engine.SetDialer(func(address string, port uint16) int {
73+
callCount1++
74+
return -104
75+
})
76+
77+
// Overwrite with second dialer
78+
engine.SetDialer(func(address string, port uint16) int {
79+
callCount2++
80+
return -102
81+
})
82+
83+
// Verify only one entry in map
84+
dialerAccess.RLock()
85+
count := 0
86+
for k := range dialerMap {
87+
if k == engine.ptr {
88+
count++
89+
}
90+
}
91+
dialerAccess.RUnlock()
92+
93+
if count != 1 {
94+
t.Errorf("expected 1 entry in dialerMap, got %d", count)
95+
}
96+
}
97+
98+
func TestDialerConcurrentAccess(t *testing.T) {
99+
engine := NewEngine()
100+
defer engine.Destroy()
101+
102+
var wg sync.WaitGroup
103+
iterations := 100
104+
105+
// Concurrent SetDialer calls (writers)
106+
for i := 0; i < iterations; i++ {
107+
wg.Add(1)
108+
go func(n int) {
109+
defer wg.Done()
110+
if n%2 == 0 {
111+
engine.SetDialer(func(address string, port uint16) int {
112+
return -104
113+
})
114+
} else {
115+
engine.SetDialer(nil)
116+
}
117+
}(i)
118+
}
119+
120+
// Concurrent dialerMap reads (simulating callback access)
121+
for i := 0; i < iterations; i++ {
122+
wg.Add(1)
123+
go func() {
124+
defer wg.Done()
125+
dialerAccess.RLock()
126+
_ = dialerMap[engine.ptr]
127+
dialerAccess.RUnlock()
128+
}()
129+
}
130+
131+
wg.Wait()
132+
133+
// Verify final state consistency: at most 1 entry for this engine
134+
dialerAccess.RLock()
135+
count := 0
136+
for k := range dialerMap {
137+
if k == engine.ptr {
138+
count++
139+
}
140+
}
141+
dialerAccess.RUnlock()
142+
143+
if count > 1 {
144+
t.Errorf("dialerMap has duplicate entries for engine: %d", count)
145+
}
146+
}
147+
148+
func TestMultipleEnginesDialers(t *testing.T) {
149+
engine1 := NewEngine()
150+
engine2 := NewEngine()
151+
152+
engine1.SetDialer(func(address string, port uint16) int {
153+
return -104
154+
})
155+
156+
engine2.SetDialer(func(address string, port uint16) int {
157+
return -102
158+
})
159+
160+
// Verify both dialers are registered
161+
dialerAccess.RLock()
162+
_, exists1 := dialerMap[engine1.ptr]
163+
_, exists2 := dialerMap[engine2.ptr]
164+
dialerAccess.RUnlock()
165+
166+
if !exists1 || !exists2 {
167+
t.Error("both dialers should be registered")
168+
}
169+
170+
// Destroy engine1, verify engine2's dialer still exists
171+
engine1.Destroy()
172+
173+
dialerAccess.RLock()
174+
_, exists1 = dialerMap[engine1.ptr]
175+
_, exists2 = dialerMap[engine2.ptr]
176+
dialerAccess.RUnlock()
177+
178+
if exists1 {
179+
t.Error("engine1's dialer should be removed")
180+
}
181+
if !exists2 {
182+
t.Error("engine2's dialer should still exist")
183+
}
184+
185+
engine2.Destroy()
186+
}

engine_cgo.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,39 @@ package cronet
55
// #include <stdlib.h>
66
// #include <stdbool.h>
77
// #include <cronet_c.h>
8+
//
9+
// extern CRONET_EXPORT int cronetDialerCallback(void* context, char* address, uint16_t port);
810
import "C"
911

1012
import (
13+
"sync"
1114
"unsafe"
1215
)
1316

17+
var (
18+
dialerAccess sync.RWMutex
19+
dialerMap = make(map[uintptr]Dialer)
20+
)
21+
22+
//export cronetDialerCallback
23+
func cronetDialerCallback(context unsafe.Pointer, address *C.char, port C.uint16_t) C.int {
24+
dialerAccess.RLock()
25+
dialer, ok := dialerMap[uintptr(context)]
26+
dialerAccess.RUnlock()
27+
if !ok {
28+
return -104 // ERR_CONNECTION_FAILED
29+
}
30+
return C.int(dialer(C.GoString(address), uint16(port)))
31+
}
32+
1433
func NewEngine() Engine {
1534
return Engine{uintptr(unsafe.Pointer(C.Cronet_Engine_Create()))}
1635
}
1736

1837
func (e Engine) Destroy() {
38+
dialerAccess.Lock()
39+
delete(dialerMap, e.ptr)
40+
dialerAccess.Unlock()
1941
C.Cronet_Engine_Destroy(C.Cronet_EnginePtr(unsafe.Pointer(e.ptr)))
2042
}
2143

@@ -179,3 +201,26 @@ func (e Engine) SetCertVerifierWithPublicKeySHA256(hashes [][]byte) bool {
179201
C.Cronet_Engine_SetMockCertVerifierForTesting(C.Cronet_EnginePtr(unsafe.Pointer(e.ptr)), certVerifier)
180202
return true
181203
}
204+
205+
// SetDialer sets a custom dialer for TCP connections.
206+
// When set, the engine will use this callback to establish TCP connections
207+
// instead of the default system socket API.
208+
// Must be called before StartWithParams().
209+
// Pass nil to disable custom dialing.
210+
func (e Engine) SetDialer(dialer Dialer) {
211+
if dialer == nil {
212+
C.Cronet_Engine_SetDialer(C.Cronet_EnginePtr(unsafe.Pointer(e.ptr)), nil, nil)
213+
dialerAccess.Lock()
214+
delete(dialerMap, e.ptr)
215+
dialerAccess.Unlock()
216+
return
217+
}
218+
dialerAccess.Lock()
219+
dialerMap[e.ptr] = dialer
220+
dialerAccess.Unlock()
221+
C.Cronet_Engine_SetDialer(
222+
C.Cronet_EnginePtr(unsafe.Pointer(e.ptr)),
223+
(*[0]byte)(C.cronetDialerCallback),
224+
unsafe.Pointer(e.ptr),
225+
)
226+
}

engine_purego.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,40 @@
33
package cronet
44

55
import (
6+
"sync"
67
"unsafe"
78

89
"github.com/sagernet/cronet-go/internal/cronet"
10+
11+
"github.com/ebitengine/purego"
12+
)
13+
14+
var (
15+
dialerAccess sync.RWMutex
16+
dialerMap = make(map[uintptr]Dialer)
17+
dialerCallback uintptr
918
)
1019

20+
func init() {
21+
dialerCallback = purego.NewCallback(func(context uintptr, address uintptr, port uint16) int {
22+
dialerAccess.RLock()
23+
dialer, ok := dialerMap[context]
24+
dialerAccess.RUnlock()
25+
if !ok {
26+
return -104 // ERR_CONNECTION_FAILED
27+
}
28+
return dialer(cronet.GoString(address), port)
29+
})
30+
}
31+
1132
func NewEngine() Engine {
1233
return Engine{cronet.EngineCreate()}
1334
}
1435

1536
func (e Engine) Destroy() {
37+
dialerAccess.Lock()
38+
delete(dialerMap, e.ptr)
39+
dialerAccess.Unlock()
1640
cronet.EngineDestroy(e.ptr)
1741
}
1842

@@ -156,3 +180,22 @@ func (e Engine) SetCertVerifierWithPublicKeySHA256(hashes [][]byte) bool {
156180
cronet.EngineSetMockCertVerifierForTesting(e.ptr, certVerifier)
157181
return true
158182
}
183+
184+
// SetDialer sets a custom dialer for TCP connections.
185+
// When set, the engine will use this callback to establish TCP connections
186+
// instead of the default system socket API.
187+
// Must be called before StartWithParams().
188+
// Pass nil to disable custom dialing.
189+
func (e Engine) SetDialer(dialer Dialer) {
190+
if dialer == nil {
191+
cronet.EngineSetDialer(e.ptr, 0, 0)
192+
dialerAccess.Lock()
193+
delete(dialerMap, e.ptr)
194+
dialerAccess.Unlock()
195+
return
196+
}
197+
dialerAccess.Lock()
198+
dialerMap[e.ptr] = dialer
199+
dialerAccess.Unlock()
200+
cronet.EngineSetDialer(e.ptr, dialerCallback, e.ptr)
201+
}

0 commit comments

Comments
 (0)