Skip to content

Commit 0016bd5

Browse files
committed
🚲 ss2022: improve salt pool
- De-genericify SaltPool. - Add a singly linked list to make evictions cheaper (O(n) -> O(1)). - Improve the API.
1 parent c8c34ec commit 0016bd5

File tree

3 files changed

+148
-66
lines changed

3 files changed

+148
-66
lines changed

ss2022/saltpool.go

Lines changed: 76 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -5,67 +5,104 @@ import (
55
"time"
66
)
77

8-
// SaltPool stores salts for [retention, 2*retention) to protect against replay attacks
8+
// SaltPool stores salts for [ReplayWindowDuration] to protect against replay attacks
99
// during the replay window.
10-
type SaltPool[T comparable] struct {
11-
mu sync.RWMutex
12-
pool map[T]time.Time
13-
retention time.Duration
14-
lastClean time.Time
10+
type SaltPool struct {
11+
mu sync.RWMutex
12+
nodeBySalt map[[32]byte]*saltNode
13+
14+
// head is the oldest node.
15+
head *saltNode
16+
// tail is the newest node.
17+
tail *saltNode
1518
}
1619

17-
// clean removes expired salts from the pool,
18-
// if the amount of time since the last cleanup exceeds retention.
19-
func (p *SaltPool[T]) clean(now time.Time) {
20-
if now.Sub(p.lastClean) > p.retention {
21-
for salt, added := range p.pool {
22-
if now.Sub(added) > p.retention {
23-
delete(p.pool, salt)
24-
}
25-
}
26-
p.lastClean = now
27-
}
20+
type saltNode struct {
21+
next *saltNode
22+
salt [32]byte
23+
expiresAt time.Time
2824
}
2925

30-
// Check returns whether the given salt is valid (not in the pool).
31-
func (p *SaltPool[T]) Check(salt T) bool {
26+
// Contains returns whether the pool contains the given salt.
27+
func (p *SaltPool) Contains(salt [32]byte) bool {
3228
p.mu.RLock()
33-
_, ok := p.pool[salt]
29+
_, ok := p.nodeBySalt[salt]
3430
p.mu.RUnlock()
35-
return !ok
31+
return ok
3632
}
3733

38-
// TryCheck is like Check, but it immediately returns true if the pool is contended.
39-
func (p *SaltPool[T]) TryCheck(salt T) bool {
34+
// TryContains is like Contains, but it immediately returns false if the pool is contended.
35+
func (p *SaltPool) TryContains(salt [32]byte) bool {
4036
if p.mu.TryRLock() {
41-
_, ok := p.pool[salt]
37+
_, ok := p.nodeBySalt[salt]
4238
p.mu.RUnlock()
43-
return !ok
39+
return ok
4440
}
45-
return true
41+
return false
4642
}
4743

48-
// Add cleans the pool, checks if the salt already exists in the pool,
49-
// and adds the salt to the pool if the salt is not already in the pool.
44+
// Add adds the salt to the pool if it is not already in the pool.
5045
// It returns true if the salt was added, false if it already exists.
51-
func (p *SaltPool[T]) Add(now time.Time, salt T) bool {
46+
func (p *SaltPool) Add(now time.Time, salt [32]byte) bool {
5247
p.mu.Lock()
5348
defer p.mu.Unlock()
5449

55-
p.clean(now)
56-
if _, ok := p.pool[salt]; ok {
50+
p.pruneExpired(now)
51+
if _, ok := p.nodeBySalt[salt]; ok {
5752
return false
5853
}
59-
p.pool[salt] = now
54+
p.insert(now, salt)
6055
return true
6156
}
6257

63-
// NewSaltPool returns a new salt pool with the given retention as the minimum amount of time
64-
// for which an added salt is guaranteed to be kept in the pool.
65-
func NewSaltPool[T comparable](retention time.Duration) *SaltPool[T] {
66-
return &SaltPool[T]{
67-
pool: make(map[T]time.Time),
68-
retention: retention,
69-
lastClean: time.Now(),
58+
// Clear removes all salts from the pool.
59+
func (p *SaltPool) Clear() {
60+
p.mu.Lock()
61+
clear(p.nodeBySalt)
62+
p.head = nil
63+
p.tail = nil
64+
p.mu.Unlock()
65+
}
66+
67+
// pruneExpired removes all expired salts from the pool.
68+
func (p *SaltPool) pruneExpired(now time.Time) {
69+
node := p.head
70+
if node == nil || node.expiresAt.After(now) {
71+
return
72+
}
73+
for {
74+
delete(p.nodeBySalt, node.salt)
75+
node = node.next
76+
if node == nil {
77+
p.head = nil
78+
p.tail = nil
79+
return
80+
}
81+
if node.expiresAt.After(now) {
82+
p.head = node
83+
return
84+
}
85+
}
86+
}
87+
88+
// insert adds the new salt to the pool.
89+
func (p *SaltPool) insert(now time.Time, salt [32]byte) {
90+
node := &saltNode{
91+
salt: salt,
92+
expiresAt: now.Add(ReplayWindowDuration),
93+
}
94+
p.nodeBySalt[salt] = node
95+
if p.tail != nil {
96+
p.tail.next = node
97+
} else {
98+
p.head = node
99+
}
100+
p.tail = node
101+
}
102+
103+
// NewSaltPool returns a new salt pool.
104+
func NewSaltPool() *SaltPool {
105+
return &SaltPool{
106+
nodeBySalt: make(map[[32]byte]*saltNode),
70107
}
71108
}

ss2022/saltpool_test.go

Lines changed: 69 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,89 @@
1-
package ss2022
1+
package ss2022_test
22

33
import (
44
"crypto/rand"
55
"testing"
66
"time"
7-
)
87

9-
func TestSaltPoolAddDuplicateSalts(t *testing.T) {
10-
const retention = 100 * time.Millisecond
11-
var salt [32]byte
12-
rand.Read(salt[:])
8+
"github.com/database64128/shadowsocks-go/ss2022"
9+
)
1310

14-
pool := NewSaltPool[[32]byte](retention)
11+
func TestSaltPool(t *testing.T) {
12+
pool := ss2022.NewSaltPool()
1513
now := time.Now()
14+
b := make([]byte, 64)
15+
rand.Read(b)
16+
salt0 := [32]byte(b)
17+
salt1 := [32]byte(b[32:])
18+
19+
// Clear empty pool.
20+
pool.Clear()
1621

17-
// Check fresh salt.
18-
if !pool.Check(salt) {
19-
t.Fatal("Denied fresh salt.")
22+
// Check salt0 and salt1.
23+
if pool.Contains(salt0) {
24+
t.Fatal("pool.Contains(salt0) = true, want false")
25+
}
26+
if pool.TryContains(salt1) {
27+
t.Fatal("pool.TryContains(salt1) = true, want false")
2028
}
2129

22-
// Add fresh salt.
23-
if !pool.Add(now, salt) {
24-
t.Fatal("Denied fresh salt.")
30+
// Add salt0.
31+
if !pool.Add(now, salt0) {
32+
t.Fatal("pool.Add(now, salt0) = false, want true")
33+
}
34+
if pool.Add(now, salt0) {
35+
t.Fatal("pool.Add(now, salt0) = true, want false")
2536
}
2637

27-
// Check the same salt again.
28-
if pool.Check(salt) {
29-
t.Fatal("Accepted duplicate salt.")
38+
// Advance some time.
39+
now = now.Add(ss2022.ReplayWindowDuration / 2)
40+
41+
// Add salt1.
42+
if !pool.Add(now, salt1) {
43+
t.Fatal("pool.Add(now, salt1) = false, want true")
44+
}
45+
if pool.Add(now, salt1) {
46+
t.Fatal("pool.Add(now, salt1) = true, want false")
3047
}
3148

32-
// Add the same salt again.
33-
if pool.Add(now, salt) {
34-
t.Fatal("Accepted duplicate salt.")
49+
// Check salt0 and salt1.
50+
if !pool.Contains(salt0) {
51+
t.Fatal("pool.Contains(salt0) = false, want true")
52+
}
53+
if !pool.Contains(salt1) {
54+
t.Fatal("pool.Contains(salt1) = false, want true")
3555
}
3656

37-
// Advance time to let the salt expire.
38-
now = now.Add(2 * retention)
57+
// Advance some time to let salt0 expire.
58+
now = now.Add(ss2022.ReplayWindowDuration / 2)
3959

40-
// Add the expired salt.
41-
if !pool.Add(now, salt) {
42-
t.Fatal("Denied expired salt.")
60+
// Add salt0 and salt1.
61+
if !pool.Add(now, salt0) {
62+
t.Fatal("pool.Add(now, salt0) = false, want true")
63+
}
64+
if pool.Add(now, salt1) {
65+
t.Fatal("pool.Add(now, salt1) = true, want false")
66+
}
67+
68+
// Advance some time to let both expire.
69+
now = now.Add(ss2022.ReplayWindowDuration)
70+
71+
// Add salt0 and salt1.
72+
if !pool.Add(now, salt0) {
73+
t.Fatal("pool.Add(now, salt0) = false, want true")
74+
}
75+
if !pool.Add(now, salt1) {
76+
t.Fatal("pool.Add(now, salt1) = false, want true")
77+
}
78+
79+
// Clear the pool.
80+
pool.Clear()
81+
82+
// Check salt0 and salt1 again.
83+
if pool.TryContains(salt0) {
84+
t.Fatal("pool.TryContains(salt0) = true, want false")
85+
}
86+
if pool.TryContains(salt1) {
87+
t.Fatal("pool.TryContains(salt1) = true, want false")
4388
}
4489
}

ss2022/tcp.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ type StreamServerConfig struct {
230230
// NewStreamServer returns a new Shadowsocks 2022 stream server.
231231
func (c *StreamServerConfig) NewStreamServer() *StreamServer {
232232
return &StreamServer{
233-
saltPool: *NewSaltPool[[32]byte](ReplayWindowDuration),
233+
saltPool: *NewSaltPool(),
234234
readOnceOrFull: readOnceOrFullFunc(c.AllowSegmentedFixedLengthHeader),
235235
userCipherConfig: c.UserCipherConfig,
236236
identityCipherConfig: c.IdentityCipherConfig,
@@ -246,7 +246,7 @@ func (c *StreamServerConfig) NewStreamServer() *StreamServer {
246246
// StreamServer implements [netio.StreamServer].
247247
type StreamServer struct {
248248
CredStore
249-
saltPool SaltPool[[32]byte]
249+
saltPool SaltPool
250250
readOnceOrFull func(io.Reader, []byte) (int, error)
251251
userCipherConfig UserCipherConfig
252252
identityCipherConfig ServerIdentityCipherConfig
@@ -329,7 +329,7 @@ func (s *StreamServer) HandleStream(rawRW netio.Conn, logger *zap.Logger) (req n
329329
extendedSalt := lengthExtendSalt(salt)
330330

331331
// Check but not add request salt to pool.
332-
if !s.saltPool.TryCheck(extendedSalt) {
332+
if s.saltPool.TryContains(extendedSalt) {
333333
err = ErrRepeatedSalt
334334
return
335335
}

0 commit comments

Comments
 (0)