Skip to content

Commit ab8caea

Browse files
committed
🥞 socks5: use string interning for domain cache
1 parent fbb75f7 commit ab8caea

File tree

6 files changed

+137
-92
lines changed

6 files changed

+137
-92
lines changed

direct/packet.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -249,8 +249,8 @@ func (ShadowsocksNonePacketServerPacker) PackInPlace(b []byte, sourceAddrPort ne
249249

250250
// ShadowsocksNonePacketServerUnpacker implements the zerocopy Unpacker interface.
251251
type ShadowsocksNonePacketServerUnpacker struct {
252-
// cachedDomain caches the last used domain target to avoid allocating new strings.
253-
cachedDomain string
252+
// domainCache caches the last used domain target to avoid allocating new strings.
253+
domainCache socks5.DomainCache
254254
}
255255

256256
// ServerUnpackerInfo implements the zerocopy.ServerUnpacker ServerUnpackerInfo method.
@@ -263,7 +263,7 @@ func (ShadowsocksNonePacketServerUnpacker) ServerUnpackerInfo() zerocopy.ServerU
263263
// UnpackInPlace implements the zerocopy.ServerUnpacker UnpackInPlace method.
264264
func (p *ShadowsocksNonePacketServerUnpacker) UnpackInPlace(b []byte, sourceAddrPort netip.AddrPort, packetStart, packetLen int) (targetAddr conn.Addr, payloadStart, payloadLen int, err error) {
265265
var targetAddrLen int
266-
targetAddr, targetAddrLen, p.cachedDomain, err = socks5.ConnAddrFromSliceWithDomainCache(b[packetStart:packetStart+packetLen], p.cachedDomain)
266+
targetAddr, targetAddrLen, err = p.domainCache.ConnAddrFromSlice(b[packetStart : packetStart+packetLen])
267267
payloadStart = packetStart + targetAddrLen
268268
payloadLen = packetLen - targetAddrLen
269269
return
@@ -395,8 +395,8 @@ func (Socks5PacketServerPacker) PackInPlace(b []byte, sourceAddrPort netip.AddrP
395395

396396
// Socks5PacketServerUnpacker implements the zerocopy Unpacker interface.
397397
type Socks5PacketServerUnpacker struct {
398-
// cachedDomain caches the last used domain target to avoid allocating new strings.
399-
cachedDomain string
398+
// domainCache caches the last used domain target to avoid allocating new strings.
399+
domainCache socks5.DomainCache
400400
}
401401

402402
// ServerUnpackerInfo implements the zerocopy.ServerUnpacker ServerUnpackerInfo method.
@@ -420,7 +420,7 @@ func (p *Socks5PacketServerUnpacker) UnpackInPlace(b []byte, sourceAddrPort neti
420420
}
421421

422422
var targetAddrLen int
423-
targetAddr, targetAddrLen, p.cachedDomain, err = socks5.ConnAddrFromSliceWithDomainCache(pkt[3:], p.cachedDomain)
423+
targetAddr, targetAddrLen, err = p.domainCache.ConnAddrFromSlice(pkt[3:])
424424
payloadStart = packetStart + targetAddrLen + 3
425425
payloadLen = packetLen - targetAddrLen - 3
426426
return

socks5/addr.go

Lines changed: 41 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"io"
88
"net/netip"
99
"slices"
10+
"unique"
1011
"unsafe"
1112

1213
"github.com/database64128/shadowsocks-go/conn"
@@ -284,72 +285,87 @@ func ConnAddrFromSlice(b []byte) (conn.Addr, int, error) {
284285

285286
switch b[0] {
286287
case AtypDomainName:
287-
if len(b) < 1+1+int(b[1])+2 {
288-
return conn.Addr{}, 0, fmt.Errorf("addr length %d is too short for ATYP %d", len(b), b[0])
288+
domainLen := int(b[1])
289+
domainEnd := 1 + 1 + domainLen
290+
portEnd := domainEnd + 2
291+
if len(b) < portEnd {
292+
return conn.Addr{}, 0, fmt.Errorf("addr length %d is too short for ATYP %#x", len(b), b[0])
289293
}
290-
domain := string(b[2 : 2+int(b[1])])
291-
port := binary.BigEndian.Uint16(b[2+int(b[1]):])
294+
domain := string(b[2:domainEnd])
295+
port := binary.BigEndian.Uint16(b[domainEnd:])
292296
addr, err := conn.AddrFromDomainPort(domain, port)
293-
return addr, 2 + int(b[1]) + 2, err
297+
return addr, portEnd, err
294298

295299
case AtypIPv4:
296300
if len(b) < 1+4+2 {
297-
return conn.Addr{}, 0, fmt.Errorf("addr length %d is too short for ATYP %d", len(b), b[0])
301+
return conn.Addr{}, 0, fmt.Errorf("addr length %d is too short for ATYP %#x", len(b), b[0])
298302
}
299303
ip := netip.AddrFrom4(*(*[4]byte)(b[1:]))
300304
port := binary.BigEndian.Uint16(b[1+4:])
301305
return conn.AddrFromIPAndPort(ip, port), 1 + 4 + 2, nil
302306

303307
case AtypIPv6:
304308
if len(b) < 1+16+2 {
305-
return conn.Addr{}, 0, fmt.Errorf("addr length %d is too short for ATYP %d", len(b), b[0])
309+
return conn.Addr{}, 0, fmt.Errorf("addr length %d is too short for ATYP %#x", len(b), b[0])
306310
}
307311
ip := netip.AddrFrom16(*(*[16]byte)(b[1:]))
308312
port := binary.BigEndian.Uint16(b[1+16:])
309313
return conn.AddrFromIPAndPort(ip, port), 1 + 16 + 2, nil
310314

311315
default:
312-
return conn.Addr{}, 0, fmt.Errorf("invalid ATYP: %d", b[0])
316+
return conn.Addr{}, 0, fmt.Errorf("invalid ATYP: %#x", b[0])
313317
}
314318
}
315319

316-
// ConnAddrFromSliceWithDomainCache is like [ConnAddrFromSlice] but uses a domain cache to minimize string allocations.
317-
// The returned string is the updated domain cache.
318-
func ConnAddrFromSliceWithDomainCache(b []byte, cachedDomain string) (conn.Addr, int, string, error) {
320+
// DomainCache uses string interning to avoid unnecessary allocations when parsing domain name SOCKS5 addresses.
321+
//
322+
// The zero value is ready for use.
323+
type DomainCache struct {
324+
domain string
325+
handle unique.Handle[string]
326+
}
327+
328+
// ConnAddrFromSlice is like [ConnAddrFromSlice] but uses the domain cache to minimize string allocations.
329+
func (c *DomainCache) ConnAddrFromSlice(b []byte) (conn.Addr, int, error) {
319330
if len(b) < 2 {
320-
return conn.Addr{}, 0, cachedDomain, fmt.Errorf("addr length too short: %d", len(b))
331+
return conn.Addr{}, 0, fmt.Errorf("addr length too short: %d", len(b))
321332
}
322333

323334
switch b[0] {
324335
case AtypDomainName:
325-
if len(b) < 1+1+int(b[1])+2 {
326-
return conn.Addr{}, 0, cachedDomain, fmt.Errorf("addr length %d is too short for ATYP %d", len(b), b[0])
336+
domainLen := int(b[1])
337+
domainEnd := 1 + 1 + domainLen
338+
portEnd := domainEnd + 2
339+
if len(b) < portEnd {
340+
return conn.Addr{}, 0, fmt.Errorf("addr length %d is too short for ATYP %#x", len(b), b[0])
327341
}
328-
domain := b[2 : 2+int(b[1])]
329-
if cachedDomain != string(domain) { // Hopefully the compiler will optimize the string allocation away.
330-
cachedDomain = string(domain)
342+
if domainBytes := b[2:domainEnd]; string(domainBytes) != c.domain {
343+
// Unsafe is required for Go 1.24 and earlier to avoid allocating on lookup.
344+
// Drop unsafe when we upgrade to Go 1.25.
345+
c.handle = unique.Make(unsafe.String(unsafe.SliceData(domainBytes), len(domainBytes)))
346+
c.domain = c.handle.Value()
331347
}
332-
port := binary.BigEndian.Uint16(b[2+int(b[1]):])
333-
addr, err := conn.AddrFromDomainPort(cachedDomain, port)
334-
return addr, 2 + int(b[1]) + 2, cachedDomain, err
348+
port := binary.BigEndian.Uint16(b[domainEnd:])
349+
addr, err := conn.AddrFromDomainPort(c.domain, port)
350+
return addr, portEnd, err
335351

336352
case AtypIPv4:
337353
if len(b) < 1+4+2 {
338-
return conn.Addr{}, 0, cachedDomain, fmt.Errorf("addr length %d is too short for ATYP %d", len(b), b[0])
354+
return conn.Addr{}, 0, fmt.Errorf("addr length %d is too short for ATYP %#x", len(b), b[0])
339355
}
340356
ip := netip.AddrFrom4(*(*[4]byte)(b[1 : 1+4]))
341357
port := binary.BigEndian.Uint16(b[1+4:])
342-
return conn.AddrFromIPAndPort(ip, port), 1 + 4 + 2, cachedDomain, nil
358+
return conn.AddrFromIPAndPort(ip, port), 1 + 4 + 2, nil
343359

344360
case AtypIPv6:
345361
if len(b) < 1+16+2 {
346-
return conn.Addr{}, 0, cachedDomain, fmt.Errorf("addr length %d is too short for ATYP %d", len(b), b[0])
362+
return conn.Addr{}, 0, fmt.Errorf("addr length %d is too short for ATYP %#x", len(b), b[0])
347363
}
348364
ip := netip.AddrFrom16(*(*[16]byte)(b[1 : 1+16]))
349365
port := binary.BigEndian.Uint16(b[1+16:])
350-
return conn.AddrFromIPAndPort(ip, port), 1 + 16 + 2, cachedDomain, nil
366+
return conn.AddrFromIPAndPort(ip, port), 1 + 16 + 2, nil
351367

352368
default:
353-
return conn.Addr{}, 0, cachedDomain, fmt.Errorf("invalid ATYP: %d", b[0])
369+
return conn.Addr{}, 0, fmt.Errorf("invalid ATYP: %#x", b[0])
354370
}
355371
}

socks5/addr_test.go

Lines changed: 76 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -11,44 +11,72 @@ import (
1111
)
1212

1313
// Test zero value address.
14+
1415
var (
1516
addrZero = IPv4UnspecifiedAddr
1617
addrZeroConnAddr conn.Addr
1718
)
1819

1920
// Test IPv4 address.
21+
22+
const addr4port uint16 = 1080
23+
2024
var (
21-
addr4 = [IPv4AddrLen]byte{AtypIPv4, 127, 0, 0, 1, 4, 56}
22-
addr4addr = netip.AddrFrom4([4]byte{127, 0, 0, 1})
23-
addr4port uint16 = 1080
24-
addr4addrport = netip.AddrPortFrom(addr4addr, addr4port)
25-
addr4connaddr = conn.AddrFromIPPort(addr4addrport)
25+
addr4 = [IPv4AddrLen]byte{
26+
AtypIPv4,
27+
127, 0, 0, 1,
28+
byte(addr4port >> 8), byte(addr4port & 0xff),
29+
}
30+
addr4addr = netip.AddrFrom4([4]byte{127, 0, 0, 1})
31+
addr4addrport = netip.AddrPortFrom(addr4addr, addr4port)
32+
addr4connaddr = conn.AddrFromIPPort(addr4addrport)
2633
)
2734

2835
// Test IPv4-mapped IPv6 address.
36+
37+
const addr4in6port uint16 = 1080
38+
2939
var (
30-
addr4in6 = [IPv4AddrLen]byte{AtypIPv4, 127, 0, 0, 1, 4, 56}
31-
addr4in6addr = netip.AddrFrom16([16]byte{10: 0xff, 11: 0xff, 127, 0, 0, 1})
32-
addr4in6port uint16 = 1080
33-
addr4in6addrport = netip.AddrPortFrom(addr4in6addr, addr4in6port)
34-
addr4in6connaddr = conn.AddrFromIPPort(addr4in6addrport)
40+
addr4in6 = [IPv4AddrLen]byte{
41+
AtypIPv4,
42+
127, 0, 0, 1,
43+
byte(addr4in6port >> 8), byte(addr4in6port & 0xff),
44+
}
45+
addr4in6addr = netip.AddrFrom16([16]byte{10: 0xff, 11: 0xff, 127, 0, 0, 1})
46+
addr4in6addrport = netip.AddrPortFrom(addr4in6addr, addr4in6port)
47+
addr4in6connaddr = conn.AddrFromIPPort(addr4in6addrport)
3548
)
3649

3750
// Test IPv6 address.
51+
52+
const addr6port uint16 = 1080
53+
3854
var (
39-
addr6 = [IPv6AddrLen]byte{AtypIPv6, 0x20, 0x01, 0x0d, 0xb8, 0xfa, 0xd6, 0x05, 0x72, 0xac, 0xbe, 0x71, 0x43, 0x14, 0xe5, 0x7a, 0x6e, 4, 56}
40-
addr6addr = netip.AddrFrom16([16]byte{0x20, 0x01, 0x0d, 0xb8, 0xfa, 0xd6, 0x05, 0x72, 0xac, 0xbe, 0x71, 0x43, 0x14, 0xe5, 0x7a, 0x6e})
41-
addr6port uint16 = 1080
42-
addr6addrport = netip.AddrPortFrom(addr6addr, addr6port)
43-
addr6connaddr = conn.AddrFromIPPort(addr6addrport)
55+
addr6 = [IPv6AddrLen]byte{
56+
AtypIPv6,
57+
0x20, 0x01, 0x0d, 0xb8, 0xfa, 0xd6, 0x05, 0x72, 0xac, 0xbe, 0x71, 0x43, 0x14, 0xe5, 0x7a, 0x6e,
58+
byte(addr6port >> 8), byte(addr6port & 0xff),
59+
}
60+
addr6addr = netip.AddrFrom16([16]byte{0x20, 0x01, 0x0d, 0xb8, 0xfa, 0xd6, 0x05, 0x72, 0xac, 0xbe, 0x71, 0x43, 0x14, 0xe5, 0x7a, 0x6e})
61+
addr6addrport = netip.AddrPortFrom(addr6addr, addr6port)
62+
addr6connaddr = conn.AddrFromIPPort(addr6addrport)
4463
)
4564

4665
// Test domain name.
66+
67+
const (
68+
addrDomainHost = "example.com"
69+
addrDomainPort uint16 = 443
70+
)
71+
4772
var (
48-
addrDomain = [1 + 1 + 11 + 2]byte{AtypDomainName, 11, 'e', 'x', 'a', 'm', 'p', 'l', 'e', '.', 'c', 'o', 'm', 1, 187}
49-
addrDomainHost = "example.com"
50-
addrDomainPort uint16 = 443
51-
addrDomainConnAddr = conn.MustAddrFromDomainPort(addrDomainHost, addrDomainPort)
73+
addrDomain = [1 + 1 + len(addrDomainHost) + 2]byte{
74+
AtypDomainName,
75+
byte(len(addrDomainHost)),
76+
'e', 'x', 'a', 'm', 'p', 'l', 'e', '.', 'c', 'o', 'm',
77+
byte(addrDomainPort >> 8), byte(addrDomainPort & 0xff),
78+
}
79+
addrDomainConnAddr = conn.MustAddrFromDomainPort(addrDomainHost, addrDomainPort)
5280
)
5381

5482
func testAddrFromReader(t *testing.T, addr []byte) {
@@ -156,51 +184,54 @@ func TestConnAddrFromSliceAndReader(t *testing.T) {
156184
testConnAddrFromSliceAndReader(t, addrDomain[:], addrDomainConnAddr)
157185
}
158186

159-
func testConnAddrFromSliceWithDomainCache(t *testing.T, sa []byte, cachedDomain string, expectedAddr conn.Addr) string {
160-
b := make([]byte, 512)
187+
func testConnAddrFromSliceWithDomainCache(t *testing.T, b, sa []byte, dc *DomainCache, expectedAddr conn.Addr) {
161188
n := copy(b, sa)
162-
rand.Read(b[n:])
163-
expectedTail := make([]byte, 512-n)
164-
copy(expectedTail, b[n:])
189+
tail := b[n:]
190+
rand.Read(tail)
191+
expectedTail := make([]byte, 0, 512)
192+
expectedTail = append(expectedTail, tail...)
165193

166-
addr, n, cachedDomain, err := ConnAddrFromSliceWithDomainCache(b, cachedDomain)
194+
addr, n, err := dc.ConnAddrFromSlice(b)
167195
if err != nil {
168196
t.Fatal(err)
169197
}
170198
if n != len(sa) {
171-
t.Errorf("ConnAddrFromSlice(b) returned n=%d, expected n=%d.", n, len(sa))
199+
t.Errorf("dc.ConnAddrFromSlice(%x) returned n=%d, expected n=%d", b, n, len(sa))
172200
}
173201
if !addr.Equals(expectedAddr) {
174-
t.Errorf("ConnAddrFromSlice(b) returned %s, expected %s.", addr, expectedAddr)
202+
t.Errorf("dc.ConnAddrFromSlice(%x) returned %s, expected %s", b, addr, expectedAddr)
175203
}
176204
if !bytes.Equal(b[n:], expectedTail) {
177-
t.Error("ConnAddrFromSlice(b) modified non-address bytes.")
205+
t.Errorf("dc.ConnAddrFromSlice(%x) modified non-address bytes", b)
178206
}
179-
return cachedDomain
180207
}
181208

182209
func TestConnAddrFromSliceWithDomainCache(t *testing.T) {
183-
const s = "🌐"
184-
cachedDomain := s
185-
186-
cachedDomain = testConnAddrFromSliceWithDomainCache(t, addr4[:], cachedDomain, addr4connaddr)
187-
if cachedDomain != s {
188-
t.Errorf("ConnAddrFromSliceWithDomainCache(addr4) modified cachedDomain to %s.", cachedDomain)
189-
}
210+
var dc DomainCache
211+
b := make([]byte, 512)
190212

191-
cachedDomain = testConnAddrFromSliceWithDomainCache(t, addr4in6[:], cachedDomain, addr4connaddr)
192-
if cachedDomain != s {
193-
t.Errorf("ConnAddrFromSliceWithDomainCache(addr4in6) modified cachedDomain to %s.", cachedDomain)
213+
if n := testing.AllocsPerRun(10, func() {
214+
testConnAddrFromSliceWithDomainCache(t, b, addr4[:], &dc, addr4connaddr)
215+
testConnAddrFromSliceWithDomainCache(t, b, addr4in6[:], &dc, addr4connaddr)
216+
testConnAddrFromSliceWithDomainCache(t, b, addr6[:], &dc, addr6connaddr)
217+
testConnAddrFromSliceWithDomainCache(t, b, addrDomain[:], &dc, addrDomainConnAddr)
218+
}); n > 0 {
219+
t.Errorf("AllocsPerRun(10, ...) = %f, want 0", n)
194220
}
195221

196-
cachedDomain = testConnAddrFromSliceWithDomainCache(t, addr6[:], cachedDomain, addr6connaddr)
197-
if cachedDomain != s {
198-
t.Errorf("ConnAddrFromSliceWithDomainCache(addr6) modified cachedDomain to %s.", cachedDomain)
222+
const addrDomain2Host = "www.google.com"
223+
addrDomain2 := [1 + 1 + len(addrDomain2Host) + 2]byte{
224+
AtypDomainName,
225+
byte(len(addrDomain2Host)),
226+
'w', 'w', 'w', '.', 'g', 'o', 'o', 'g', 'l', 'e', '.', 'c', 'o', 'm',
227+
byte(addrDomainPort >> 8), byte(addrDomainPort & 0xff),
199228
}
229+
addrDomain2ConnAddr := conn.MustAddrFromDomainPort(addrDomain2Host, addrDomainPort)
200230

201-
cachedDomain = testConnAddrFromSliceWithDomainCache(t, addrDomain[:], cachedDomain, addrDomainConnAddr)
202-
if cachedDomain != addrDomainHost {
203-
t.Errorf("ConnAddrFromSliceWithDomainCache(addrDomain) modified cachedDomain to %s, expected %s.", cachedDomain, addrDomainHost)
231+
if n := testing.AllocsPerRun(10, func() {
232+
testConnAddrFromSliceWithDomainCache(t, b, addrDomain2[:], &dc, addrDomain2ConnAddr)
233+
}); n > 0 {
234+
t.Errorf("AllocsPerRun(10, ...) = %f, want 0", n)
204235
}
205236
}
206237

ss2022/header.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -333,9 +333,7 @@ func PutSessionIDAndPacketID(b []byte, sid, pid uint64) {
333333
// +------+---------------+----------------+----------+------+----------+-------+----------+
334334
// | 1B | 8B unix epoch | u16be | variable | 1B | variable | u16be | variable |
335335
// +------+---------------+----------------+----------+------+----------+-------+----------+
336-
func ParseUDPClientMessageHeader(b []byte, now time.Time, cachedDomain string) (targetAddr conn.Addr, updatedCachedDomain string, payloadStart, payloadLen int, err error) {
337-
updatedCachedDomain = cachedDomain
338-
336+
func ParseUDPClientMessageHeader(b []byte, now time.Time, domainCache *socks5.DomainCache) (targetAddr conn.Addr, payloadStart, payloadLen int, err error) {
339337
// Make sure buffer has type + timestamp + padding length.
340338
if len(b) < UDPClientMessageHeaderFixedLength {
341339
err = ErrPacketIncompleteHeader
@@ -366,7 +364,7 @@ func ParseUDPClientMessageHeader(b []byte, now time.Time, cachedDomain string) (
366364

367365
// SOCKS address
368366
var n int
369-
targetAddr, n, updatedCachedDomain, err = socks5.ConnAddrFromSliceWithDomainCache(b[payloadStart:], cachedDomain)
367+
targetAddr, n, err = domainCache.ConnAddrFromSlice(b[payloadStart:])
370368
if err != nil {
371369
return
372370
}

0 commit comments

Comments
 (0)