Skip to content

Commit 493f5ea

Browse files
committed
Add DNS resolver interception for NaiveClient
Add DNSResolverFunc callback support to NaiveClient that allows custom DNS resolution. When configured, DNS queries are intercepted via socketpair and handled by the user-provided resolver. Features: - DNSResolverFunc type using github.com/miekg/dns - EngineParams helpers: SetAsyncDNS(), SetDNSServerOverride() - Cross-platform socketpair implementation (Unix DGRAM, Windows framed) - Support for both UDP and TCP DNS protocols
1 parent 71bbc58 commit 493f5ea

15 files changed

+1144
-22
lines changed

dialer_test.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,32 @@ func TestDialerMapCleanup(t *testing.T) {
3232
}
3333
}
3434

35+
func TestUDPDialerMapCleanup(t *testing.T) {
36+
engine := NewEngine()
37+
38+
engine.SetUDPDialer(func(address string, port uint16) (int, string, uint16) {
39+
return -104, "", 0 // ERR_CONNECTION_FAILED
40+
})
41+
42+
udpDialerAccess.RLock()
43+
_, exists := udpDialerMap[engine.ptr]
44+
udpDialerAccess.RUnlock()
45+
46+
if !exists {
47+
t.Error("dialer not registered in udpDialerMap")
48+
}
49+
50+
engine.Destroy()
51+
52+
udpDialerAccess.RLock()
53+
_, exists = udpDialerMap[engine.ptr]
54+
udpDialerAccess.RUnlock()
55+
56+
if exists {
57+
t.Error("dialer not cleaned up after Engine.Destroy()")
58+
}
59+
}
60+
3561
func TestSetDialerNil(t *testing.T) {
3662
engine := NewEngine()
3763
defer engine.Destroy()
@@ -61,6 +87,33 @@ func TestSetDialerNil(t *testing.T) {
6187
}
6288
}
6389

90+
func TestSetUDPDialerNil(t *testing.T) {
91+
engine := NewEngine()
92+
defer engine.Destroy()
93+
94+
engine.SetUDPDialer(func(address string, port uint16) (int, string, uint16) {
95+
return -104, "", 0
96+
})
97+
98+
udpDialerAccess.RLock()
99+
_, exists := udpDialerMap[engine.ptr]
100+
udpDialerAccess.RUnlock()
101+
102+
if !exists {
103+
t.Error("dialer not registered")
104+
}
105+
106+
engine.SetUDPDialer(nil)
107+
108+
udpDialerAccess.RLock()
109+
_, exists = udpDialerMap[engine.ptr]
110+
udpDialerAccess.RUnlock()
111+
112+
if exists {
113+
t.Error("dialer not removed after SetUDPDialer(nil)")
114+
}
115+
}
116+
64117
func TestSetDialerOverwrite(t *testing.T) {
65118
engine := NewEngine()
66119
defer engine.Destroy()

dns_intercept.go

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
package cronet
2+
3+
import (
4+
"context"
5+
"encoding/binary"
6+
"errors"
7+
"io"
8+
"net"
9+
"time"
10+
11+
mDNS "github.com/miekg/dns"
12+
)
13+
14+
const chromiumDNSUDPMaxSize = 512
15+
16+
func serveDNSPacketConn(ctx context.Context, conn net.PacketConn, resolver DNSResolverFunc) error {
17+
defer conn.Close()
18+
19+
// For Unix socketpair, use Write() instead of WriteTo()
20+
// because socketpair is connected and WriteTo() with address doesn't work
21+
var unixConn *net.UnixConn
22+
if uc, ok := conn.(*net.UnixConn); ok {
23+
unixConn = uc
24+
}
25+
buffer := make([]byte, 64*1024)
26+
for {
27+
select {
28+
case <-ctx.Done():
29+
return ctx.Err()
30+
default:
31+
}
32+
33+
if deadlineConn, ok := conn.(interface{ SetReadDeadline(time.Time) error }); ok {
34+
_ = deadlineConn.SetReadDeadline(time.Now().Add(time.Second))
35+
}
36+
37+
n, remoteAddress, err := conn.ReadFrom(buffer)
38+
if err != nil {
39+
if errors.Is(err, net.ErrClosed) {
40+
return nil
41+
}
42+
if netError := (*net.OpError)(nil); errors.As(err, &netError) && netError.Timeout() {
43+
continue
44+
}
45+
continue
46+
}
47+
48+
var request mDNS.Msg
49+
if err := request.Unpack(buffer[:n]); err != nil {
50+
continue
51+
}
52+
53+
response := resolver(ctx, &request)
54+
response = normalizeDNSResponse(&request, response)
55+
56+
packed, err := response.Pack()
57+
if err != nil {
58+
continue
59+
}
60+
if len(packed) > chromiumDNSUDPMaxSize {
61+
truncated := truncatedDNSResponse(&request, response.Rcode)
62+
packed, err = truncated.Pack()
63+
if err != nil {
64+
continue
65+
}
66+
}
67+
68+
// For Unix socketpair, use Write(); for regular UDP, use WriteTo()
69+
if unixConn != nil {
70+
_, _ = unixConn.Write(packed)
71+
} else if remoteAddress != nil {
72+
_, _ = conn.WriteTo(packed, remoteAddress)
73+
}
74+
}
75+
}
76+
77+
func serveDNSStreamConn(ctx context.Context, conn net.Conn, resolver DNSResolverFunc) error {
78+
defer conn.Close()
79+
80+
for {
81+
select {
82+
case <-ctx.Done():
83+
return ctx.Err()
84+
default:
85+
}
86+
87+
_ = conn.SetReadDeadline(time.Now().Add(30 * time.Second))
88+
var queryLength uint16
89+
if err := binary.Read(conn, binary.BigEndian, &queryLength); err != nil {
90+
if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) {
91+
return nil
92+
}
93+
return err
94+
}
95+
if queryLength == 0 {
96+
return nil
97+
}
98+
99+
query := make([]byte, int(queryLength))
100+
_, err := io.ReadFull(conn, query)
101+
if err != nil {
102+
if errors.Is(err, net.ErrClosed) {
103+
return nil
104+
}
105+
return err
106+
}
107+
108+
var request mDNS.Msg
109+
if err := request.Unpack(query); err != nil {
110+
continue
111+
}
112+
113+
response := resolver(ctx, &request)
114+
response = normalizeDNSResponse(&request, response)
115+
116+
packed, err := response.Pack()
117+
if err != nil {
118+
continue
119+
}
120+
121+
_ = conn.SetWriteDeadline(time.Now().Add(30 * time.Second))
122+
var lengthPrefix [2]byte
123+
binary.BigEndian.PutUint16(lengthPrefix[:], uint16(len(packed)))
124+
if _, err := conn.Write(lengthPrefix[:]); err != nil {
125+
return err
126+
}
127+
if _, err := conn.Write(packed); err != nil {
128+
return err
129+
}
130+
}
131+
}
132+
133+
func normalizeDNSResponse(request *mDNS.Msg, response *mDNS.Msg) *mDNS.Msg {
134+
if response == nil {
135+
fallback := new(mDNS.Msg)
136+
fallback.SetReply(request)
137+
fallback.Rcode = mDNS.RcodeServerFailure
138+
return fallback
139+
}
140+
141+
response.Id = request.Id
142+
response.Response = true
143+
if len(response.Question) == 0 {
144+
response.Question = request.Question
145+
}
146+
return response
147+
}
148+
149+
func truncatedDNSResponse(request *mDNS.Msg, rcode int) *mDNS.Msg {
150+
response := new(mDNS.Msg)
151+
response.SetReply(request)
152+
response.Truncated = true
153+
response.Rcode = rcode
154+
return response
155+
}

dns_socketpair_unix.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
//go:build unix
2+
3+
package cronet
4+
5+
import (
6+
"net"
7+
"os"
8+
"syscall"
9+
10+
E "github.com/sagernet/sing/common/exceptions"
11+
)
12+
13+
func createPacketSocketPair() (cronetFD int, proxyConn net.PacketConn, err error) {
14+
fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_DGRAM, 0)
15+
if err != nil {
16+
return -1, nil, E.Cause(err, "create dgram socketpair")
17+
}
18+
19+
syscall.CloseOnExec(fds[0])
20+
21+
file := os.NewFile(uintptr(fds[1]), "cronet-dgram-socketpair")
22+
conn, err := net.FilePacketConn(file)
23+
_ = file.Close()
24+
if err != nil {
25+
syscall.Close(fds[0])
26+
return -1, nil, E.Cause(err, "create packet conn from socketpair")
27+
}
28+
29+
return fds[0], conn, nil
30+
}
31+

dns_socketpair_unix_test.go

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
//go:build unix
2+
3+
package cronet
4+
5+
import (
6+
"syscall"
7+
"testing"
8+
)
9+
10+
func TestCreatePacketSocketPair(t *testing.T) {
11+
fd, conn, err := createPacketSocketPair()
12+
if err != nil {
13+
t.Fatalf("createPacketSocketPair failed: %v", err)
14+
}
15+
defer syscall.Close(fd)
16+
defer conn.Close()
17+
18+
if fd <= 0 {
19+
t.Errorf("expected valid fd, got %d", fd)
20+
}
21+
}
22+
23+
func TestCreatePacketSocketPair_BidirectionalCommunication(t *testing.T) {
24+
fd, conn, err := createPacketSocketPair()
25+
if err != nil {
26+
t.Fatalf("createPacketSocketPair failed: %v", err)
27+
}
28+
defer syscall.Close(fd)
29+
defer conn.Close()
30+
31+
// fd → conn (datagram boundary preserved)
32+
testData := []byte("hello from fd")
33+
_, err = syscall.Write(fd, testData)
34+
if err != nil {
35+
t.Fatalf("write to fd failed: %v", err)
36+
}
37+
38+
buf := make([]byte, 1024)
39+
n, _, err := conn.ReadFrom(buf)
40+
if err != nil {
41+
t.Fatalf("read from conn failed: %v", err)
42+
}
43+
if string(buf[:n]) != string(testData) {
44+
t.Errorf("expected %q, got %q", testData, buf[:n])
45+
}
46+
47+
// conn → fd: For Unix socketpairs, use Write() via net.Conn interface
48+
// (same as serveDNSPacketConn does when remoteAddress is nil)
49+
streamConn, ok := conn.(interface{ Write([]byte) (int, error) })
50+
if !ok {
51+
t.Fatal("PacketConn should implement Write for socketpair")
52+
}
53+
testData2 := []byte("hello from conn")
54+
_, err = streamConn.Write(testData2)
55+
if err != nil {
56+
t.Fatalf("write to conn failed: %v", err)
57+
}
58+
59+
n, err = syscall.Read(fd, buf)
60+
if err != nil {
61+
t.Fatalf("read from fd failed: %v", err)
62+
}
63+
if string(buf[:n]) != string(testData2) {
64+
t.Errorf("expected %q, got %q", testData2, buf[:n])
65+
}
66+
}
67+
68+
func TestCreatePacketSocketPair_MessageBoundary(t *testing.T) {
69+
fd, conn, err := createPacketSocketPair()
70+
if err != nil {
71+
t.Fatalf("createPacketSocketPair failed: %v", err)
72+
}
73+
defer syscall.Close(fd)
74+
defer conn.Close()
75+
76+
// Send multiple different-sized messages
77+
messages := [][]byte{
78+
[]byte("short"),
79+
make([]byte, 512),
80+
make([]byte, 1400),
81+
}
82+
// Fill with recognizable patterns
83+
for i := range messages[1] {
84+
messages[1][i] = byte(i % 256)
85+
}
86+
for i := range messages[2] {
87+
messages[2][i] = byte((i * 7) % 256)
88+
}
89+
90+
for _, msg := range messages {
91+
_, err = syscall.Write(fd, msg)
92+
if err != nil {
93+
t.Fatalf("write failed: %v", err)
94+
}
95+
}
96+
97+
// Verify each message boundary is preserved
98+
for i, expected := range messages {
99+
buf := make([]byte, 2048)
100+
n, _, err := conn.ReadFrom(buf)
101+
if err != nil {
102+
t.Fatalf("read %d failed: %v", i, err)
103+
}
104+
if n != len(expected) {
105+
t.Errorf("message %d: expected length %d, got %d", i, len(expected), n)
106+
}
107+
if string(buf[:n]) != string(expected) {
108+
t.Errorf("message %d: content mismatch", i)
109+
}
110+
}
111+
}

0 commit comments

Comments
 (0)