Skip to content

Commit 8c640a0

Browse files
committed
Add custom dialer support for NaiveClient
1 parent 3f4fe26 commit 8c640a0

21 files changed

+1140
-48
lines changed

.github/workflows/naive-build.yml

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,96 @@ jobs:
417417
include/
418418
include_cgo.go
419419
420+
integration-test:
421+
if: github.event_name == 'push'
422+
needs: [linux, darwin, windows]
423+
strategy:
424+
fail-fast: false
425+
matrix:
426+
include:
427+
- os: ubuntu-22.04
428+
name: linux-amd64
429+
artifact: cronet-linux-amd64
430+
- os: macos-15
431+
name: macos-arm64
432+
artifact: cronet-darwin-arm64
433+
- os: windows-2025
434+
name: windows-amd64
435+
artifact: cronet-windows-amd64
436+
runs-on: ${{ matrix.os }}
437+
steps:
438+
- uses: actions/checkout@v4
439+
- uses: actions/setup-go@v5
440+
with:
441+
go-version: ^1.22
442+
443+
- name: Download artifact
444+
uses: actions/download-artifact@v4
445+
with:
446+
name: ${{ matrix.artifact }}
447+
path: .
448+
449+
# Linux setup
450+
- name: Install iperf3 (Linux)
451+
if: runner.os == 'Linux'
452+
run: sudo apt-get update && sudo apt-get install -y iperf3
453+
454+
# macOS setup
455+
- name: Install iperf3 (macOS)
456+
if: runner.os == 'macOS'
457+
run: |
458+
brew install iperf3
459+
ln -s $(which iperf3) test/iperf3-darwin
460+
461+
# Windows setup (WSL2 + Docker)
462+
- name: Setup WSL2 and Docker (Windows)
463+
if: runner.os == 'Windows'
464+
shell: pwsh
465+
run: |
466+
# Install Ubuntu on WSL2
467+
wsl --install Ubuntu --no-launch
468+
wsl --set-default Ubuntu
469+
470+
# Install Docker in WSL2
471+
wsl -d Ubuntu -e bash -c "curl -fsSL https://get.docker.com | sh"
472+
wsl -d Ubuntu -e bash -c "sudo usermod -aG docker \$USER"
473+
474+
# Start Docker daemon with sudo
475+
Start-Process -NoNewWindow wsl -ArgumentList "-d", "Ubuntu", "-e", "sudo", "dockerd"
476+
477+
# Wait for Docker daemon to be ready
478+
$maxRetries = 30
479+
for ($i = 0; $i -lt $maxRetries; $i++) {
480+
$result = wsl -d Ubuntu -e bash -c "sudo docker info" 2>&1
481+
if ($LASTEXITCODE -eq 0) { break }
482+
Start-Sleep -Seconds 2
483+
}
484+
if ($i -eq $maxRetries) {
485+
Write-Error "Docker daemon failed to start after $maxRetries retries"
486+
exit 1
487+
}
488+
489+
- name: Install iperf3 (Windows)
490+
if: runner.os == 'Windows'
491+
shell: pwsh
492+
run: |
493+
choco install iperf3 -y
494+
wsl -d Ubuntu -e bash -c "sudo apt-get update && sudo apt-get install -y iperf3"
495+
496+
# Run tests
497+
- name: Run integration tests (Linux/macOS)
498+
if: runner.os != 'Windows'
499+
run: |
500+
cd test
501+
go test -v -timeout 30m .
502+
503+
- name: Run integration tests (Windows)
504+
if: runner.os == 'Windows'
505+
shell: pwsh
506+
run: |
507+
$winPath = (Get-Location).Path
508+
wsl -d Ubuntu -e bash -c "export DOCKER_HOST=unix:///var/run/docker.sock && cd \$(wslpath '$winPath')/test && go test -v -timeout 30m ."
509+
420510
publish:
421511
if: github.event_name == 'push'
422512
needs: [linux, linux-musl, darwin, windows, android]

dialer_cgo.c

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
// +build !with_purego
2+
3+
#include <stdint.h>
4+
5+
// Go callback function (defined in engine_cgo.go)
6+
extern int goDialerCallbackBridge(uintptr_t context, const char* address, uint16_t port);
7+
8+
// C wrapper function with correct signature for Cronet_DialerFunc
9+
int cronetDialerCallbackWrapper(void* context, const char* address, uint16_t port) {
10+
return goDialerCallbackBridge((uintptr_t)context, address, port);
11+
}

dialer_test.go

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
package cronet
2+
3+
import (
4+
"sync"
5+
"testing"
6+
)
7+
8+
func TestDialerMapCleanup(t *testing.T) {
9+
// Verify Engine.Destroy() properly cleans up dialerMap
10+
engine := NewEngine()
11+
12+
engine.SetDialer(func(address string, port uint16) int {
13+
return -104 // ERR_CONNECTION_FAILED
14+
})
15+
16+
dialerAccess.RLock()
17+
_, exists := dialerMap[engine.ptr]
18+
dialerAccess.RUnlock()
19+
20+
if !exists {
21+
t.Error("dialer not registered in dialerMap")
22+
}
23+
24+
engine.Destroy()
25+
26+
dialerAccess.RLock()
27+
_, exists = dialerMap[engine.ptr]
28+
dialerAccess.RUnlock()
29+
30+
if exists {
31+
t.Error("dialer not cleaned up after Engine.Destroy()")
32+
}
33+
}
34+
35+
func TestSetDialerNil(t *testing.T) {
36+
engine := NewEngine()
37+
defer engine.Destroy()
38+
39+
// First set a dialer
40+
engine.SetDialer(func(address string, port uint16) int {
41+
return -104
42+
})
43+
44+
dialerAccess.RLock()
45+
_, exists := dialerMap[engine.ptr]
46+
dialerAccess.RUnlock()
47+
48+
if !exists {
49+
t.Error("dialer not registered")
50+
}
51+
52+
// Then set it to nil
53+
engine.SetDialer(nil)
54+
55+
dialerAccess.RLock()
56+
_, exists = dialerMap[engine.ptr]
57+
dialerAccess.RUnlock()
58+
59+
if exists {
60+
t.Error("dialer not removed after SetDialer(nil)")
61+
}
62+
}
63+
64+
func TestSetDialerOverwrite(t *testing.T) {
65+
engine := NewEngine()
66+
defer engine.Destroy()
67+
68+
callCount1 := 0
69+
callCount2 := 0
70+
71+
// Set first dialer
72+
engine.SetDialer(func(address string, port uint16) int {
73+
callCount1++
74+
return -104
75+
})
76+
77+
// Overwrite with second dialer
78+
engine.SetDialer(func(address string, port uint16) int {
79+
callCount2++
80+
return -102
81+
})
82+
83+
// Verify only one entry in map
84+
dialerAccess.RLock()
85+
count := 0
86+
for k := range dialerMap {
87+
if k == engine.ptr {
88+
count++
89+
}
90+
}
91+
dialerAccess.RUnlock()
92+
93+
if count != 1 {
94+
t.Errorf("expected 1 entry in dialerMap, got %d", count)
95+
}
96+
}
97+
98+
func TestDialerConcurrentAccess(t *testing.T) {
99+
engine := NewEngine()
100+
defer engine.Destroy()
101+
102+
var wg sync.WaitGroup
103+
iterations := 100
104+
105+
// Concurrent SetDialer calls (writers)
106+
for i := 0; i < iterations; i++ {
107+
wg.Add(1)
108+
go func(n int) {
109+
defer wg.Done()
110+
if n%2 == 0 {
111+
engine.SetDialer(func(address string, port uint16) int {
112+
return -104
113+
})
114+
} else {
115+
engine.SetDialer(nil)
116+
}
117+
}(i)
118+
}
119+
120+
// Concurrent dialerMap reads (simulating callback access)
121+
for i := 0; i < iterations; i++ {
122+
wg.Add(1)
123+
go func() {
124+
defer wg.Done()
125+
dialerAccess.RLock()
126+
_ = dialerMap[engine.ptr]
127+
dialerAccess.RUnlock()
128+
}()
129+
}
130+
131+
wg.Wait()
132+
133+
// Verify final state consistency: at most 1 entry for this engine
134+
dialerAccess.RLock()
135+
count := 0
136+
for k := range dialerMap {
137+
if k == engine.ptr {
138+
count++
139+
}
140+
}
141+
dialerAccess.RUnlock()
142+
143+
if count > 1 {
144+
t.Errorf("dialerMap has duplicate entries for engine: %d", count)
145+
}
146+
}
147+
148+
func TestMultipleEnginesDialers(t *testing.T) {
149+
engine1 := NewEngine()
150+
engine2 := NewEngine()
151+
152+
engine1.SetDialer(func(address string, port uint16) int {
153+
return -104
154+
})
155+
156+
engine2.SetDialer(func(address string, port uint16) int {
157+
return -102
158+
})
159+
160+
// Verify both dialers are registered
161+
dialerAccess.RLock()
162+
_, exists1 := dialerMap[engine1.ptr]
163+
_, exists2 := dialerMap[engine2.ptr]
164+
dialerAccess.RUnlock()
165+
166+
if !exists1 || !exists2 {
167+
t.Error("both dialers should be registered")
168+
}
169+
170+
// Destroy engine1, verify engine2's dialer still exists
171+
engine1.Destroy()
172+
173+
dialerAccess.RLock()
174+
_, exists1 = dialerMap[engine1.ptr]
175+
_, exists2 = dialerMap[engine2.ptr]
176+
dialerAccess.RUnlock()
177+
178+
if exists1 {
179+
t.Error("engine1's dialer should be removed")
180+
}
181+
if !exists2 {
182+
t.Error("engine2's dialer should still exist")
183+
}
184+
185+
engine2.Destroy()
186+
}

engine_cgo.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,42 @@ package cronet
55
// #include <stdlib.h>
66
// #include <stdbool.h>
77
// #include <cronet_c.h>
8+
//
9+
// extern int cronetDialerCallbackWrapper(void* context, const char* address, uint16_t port);
810
import "C"
911

1012
import (
13+
"sync"
1114
"unsafe"
1215
)
1316

17+
var (
18+
dialerAccess sync.RWMutex
19+
dialerMap = make(map[uintptr]Dialer)
20+
)
21+
22+
//export goDialerCallbackBridge
23+
func goDialerCallbackBridge(context C.uintptr_t, address *C.char, port C.uint16_t) C.int {
24+
dialerAccess.RLock()
25+
dialer, ok := dialerMap[uintptr(context)]
26+
dialerAccess.RUnlock()
27+
if !ok {
28+
return -104 // ERR_CONNECTION_FAILED
29+
}
30+
return C.int(dialer(C.GoString(address), uint16(port)))
31+
}
32+
1433
func NewEngine() Engine {
1534
return Engine{uintptr(unsafe.Pointer(C.Cronet_Engine_Create()))}
1635
}
1736

1837
func (e Engine) Destroy() {
38+
// Shutdown first to wait for all network activity to complete,
39+
// preventing race conditions with active dialer callbacks
40+
e.Shutdown()
41+
dialerAccess.Lock()
42+
delete(dialerMap, e.ptr)
43+
dialerAccess.Unlock()
1944
C.Cronet_Engine_Destroy(C.Cronet_EnginePtr(unsafe.Pointer(e.ptr)))
2045
}
2146

@@ -179,3 +204,26 @@ func (e Engine) SetCertVerifierWithPublicKeySHA256(hashes [][]byte) bool {
179204
C.Cronet_Engine_SetMockCertVerifierForTesting(C.Cronet_EnginePtr(unsafe.Pointer(e.ptr)), certVerifier)
180205
return true
181206
}
207+
208+
// SetDialer sets a custom dialer for TCP connections.
209+
// When set, the engine will use this callback to establish TCP connections
210+
// instead of the default system socket API.
211+
// Must be called before StartWithParams().
212+
// Pass nil to disable custom dialing.
213+
func (e Engine) SetDialer(dialer Dialer) {
214+
if dialer == nil {
215+
C.Cronet_Engine_SetDialer(C.Cronet_EnginePtr(unsafe.Pointer(e.ptr)), nil, nil)
216+
dialerAccess.Lock()
217+
delete(dialerMap, e.ptr)
218+
dialerAccess.Unlock()
219+
return
220+
}
221+
dialerAccess.Lock()
222+
dialerMap[e.ptr] = dialer
223+
dialerAccess.Unlock()
224+
C.Cronet_Engine_SetDialer(
225+
C.Cronet_EnginePtr(unsafe.Pointer(e.ptr)),
226+
C.Cronet_DialerFunc(C.cronetDialerCallbackWrapper),
227+
unsafe.Pointer(e.ptr),
228+
)
229+
}

0 commit comments

Comments
 (0)