Skip to content

Commit fd698fe

Browse files
committed
add tests, fix nil pointer issue
1 parent 7d6faf0 commit fd698fe

File tree

2 files changed

+187
-26
lines changed

2 files changed

+187
-26
lines changed

adapter/groups/manager.go

Lines changed: 40 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package adapter
1+
package groups
22

33
import (
44
"context"
@@ -62,7 +62,7 @@ func (m *MutableGroupManager) Close() {
6262
if m.closed.Swap(true) {
6363
return
6464
}
65-
m.removalQueue.stop()
65+
m.removalQueue.close()
6666
}
6767

6868
func (m *MutableGroupManager) Groups() []adapter.MutableOutboundGroup {
@@ -126,10 +126,12 @@ func (m *MutableGroupManager) createForGroup(
126126
}
127127
return fmt.Errorf("failed to add %s to %s: %w", tag, group, err)
128128
}
129+
// remove from removal queue in case it was scheduled for removal
130+
m.removalQueue.dequeue(tag)
129131
return nil
130132
}
131133

132-
// CreateOutboundForGroup creates an outbound for the specified group.
134+
// RemoveFromGroup removes an outbound/endpoint from the specified group.
133135
func (m *MutableGroupManager) RemoveFromGroup(group, tag string) error {
134136
m.mu.Lock()
135137
defer m.mu.Unlock()
@@ -150,7 +152,7 @@ func (m *MutableGroupManager) RemoveFromGroup(group, tag string) error {
150152
}
151153

152154
_, isEndpoint := m.endpointMgr.Get(tag)
153-
m.removalQueue.add(tag, isEndpoint)
155+
m.removalQueue.enqueue(tag, isEndpoint)
154156
return nil
155157
}
156158

@@ -161,12 +163,12 @@ type removalQueue struct {
161163
epMgr A.EndpointManager
162164
connMgr ConnectionManager
163165
pending map[string]item
164-
ticker *time.Ticker
165166
pollInterval time.Duration
166167
forceAfter time.Duration
167168
mu sync.RWMutex
169+
running atomic.Bool
168170
done chan struct{}
169-
closed atomic.Bool
171+
once sync.Once
170172
}
171173

172174
type item struct {
@@ -194,10 +196,13 @@ func newRemovalQueue(
194196
}
195197
}
196198

197-
func (rq *removalQueue) add(tag string, isEndpoint bool) {
198-
if rq.closed.Load() {
199+
func (rq *removalQueue) enqueue(tag string, isEndpoint bool) {
200+
select {
201+
case <-rq.done:
199202
return
203+
default:
200204
}
205+
201206
rq.mu.Lock()
202207
defer rq.mu.Unlock()
203208
if _, exists := rq.pending[tag]; exists {
@@ -208,19 +213,37 @@ func (rq *removalQueue) add(tag string, isEndpoint bool) {
208213
isEndpoint: isEndpoint,
209214
addedAt: time.Now(),
210215
}
211-
if rq.ticker == nil {
212-
rq.ticker = time.NewTicker(rq.pollInterval)
216+
if !rq.running.Load() {
213217
go rq.checkLoop()
214218
}
215219
}
216220

221+
func (rq *removalQueue) dequeue(tag string) {
222+
rq.mu.Lock()
223+
delete(rq.pending, tag)
224+
rq.mu.Unlock()
225+
}
226+
217227
func (rq *removalQueue) checkLoop() {
228+
if !rq.running.CompareAndSwap(false, true) {
229+
return
230+
}
231+
defer rq.running.Store(false)
232+
233+
rq.checkPending()
234+
ticker := time.NewTicker(rq.pollInterval)
235+
defer ticker.Stop()
218236
for {
237+
rq.mu.Lock()
238+
if len(rq.pending) == 0 {
239+
rq.mu.Unlock()
240+
return
241+
}
242+
rq.mu.Unlock()
219243
select {
220-
case <-rq.ticker.C:
244+
case <-ticker.C:
221245
rq.checkPending()
222246
case <-rq.done:
223-
rq.ticker.Stop()
224247
return
225248
}
226249
}
@@ -231,15 +254,13 @@ func (rq *removalQueue) checkLoop() {
231254
func (rq *removalQueue) checkPending() {
232255
rq.mu.RLock()
233256
pending := make(map[string]item, len(rq.pending))
234-
for tag, item := range rq.pending {
235-
pending[tag] = item
236-
}
257+
maps.Copy(pending, rq.pending)
237258
rq.mu.RUnlock()
238259

239260
hasConns := make(map[string]bool, len(rq.pending))
240261
for _, conn := range rq.connMgr.Connections() {
241262
if _, exists := pending[conn.Outbound]; exists {
242-
hasConns[conn.Outbound] = hasConns[conn.Outbound] || !conn.ClosedAt.IsZero()
263+
hasConns[conn.Outbound] = hasConns[conn.Outbound] || conn.ClosedAt.IsZero()
243264
}
244265
}
245266

@@ -266,18 +287,11 @@ func (rq *removalQueue) checkPending() {
266287
}
267288
delete(rq.pending, tag)
268289
}
269-
// Stop ticker if no pending items remain
270-
if len(rq.pending) == 0 && rq.ticker != nil {
271-
rq.ticker.Stop()
272-
rq.ticker = nil
273-
}
274290
}
275291
}
276292

277-
func (rq *removalQueue) stop() {
278-
select {
279-
case <-rq.done:
280-
default:
293+
func (rq *removalQueue) close() {
294+
rq.once.Do(func() {
281295
close(rq.done)
282-
}
296+
})
283297
}

adapter/groups/manager_test.go

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
package groups
2+
3+
import (
4+
"slices"
5+
"sync"
6+
"testing"
7+
"time"
8+
9+
"github.com/sagernet/sing-box/adapter"
10+
"github.com/sagernet/sing-box/experimental/clashapi/trafficontrol"
11+
"github.com/sagernet/sing-box/log"
12+
"github.com/stretchr/testify/assert"
13+
"github.com/stretchr/testify/require"
14+
)
15+
16+
func TestRemovalQueue(t *testing.T) {
17+
logger := log.NewNOPFactory().Logger()
18+
tag := "outbound"
19+
tests := []struct {
20+
name string
21+
outMgr *mockOutboundManager
22+
epMgr *mockEndpointManager
23+
connMgr *mockConnectionManager
24+
pending map[string]item
25+
forceAfter time.Duration
26+
assertFn func(t *testing.T, rq *removalQueue)
27+
}{
28+
{
29+
name: "remove outbound",
30+
outMgr: &mockOutboundManager{tags: []string{tag}},
31+
connMgr: &mockConnectionManager{},
32+
pending: map[string]item{tag: {tag, false, time.Now()}},
33+
forceAfter: time.Minute,
34+
assertFn: func(t *testing.T, rq *removalQueue) {
35+
assert.NotContains(t, rq.outMgr.(*mockOutboundManager).tags, tag, "tag should be removed")
36+
},
37+
},
38+
{
39+
name: "remove endpoint",
40+
epMgr: &mockEndpointManager{tags: []string{tag}},
41+
connMgr: &mockConnectionManager{},
42+
pending: map[string]item{tag: {tag, true, time.Now()}},
43+
forceAfter: time.Minute,
44+
assertFn: func(t *testing.T, rq *removalQueue) {
45+
assert.NotContains(t, rq.epMgr.(*mockEndpointManager).tags, tag, "tag should be removed")
46+
},
47+
},
48+
{
49+
name: "force removal after duration",
50+
outMgr: &mockOutboundManager{tags: []string{tag}},
51+
connMgr: &mockConnectionManager{
52+
conns: []trafficontrol.TrackerMetadata{{Outbound: tag, ClosedAt: time.Time{}}},
53+
},
54+
pending: map[string]item{
55+
tag: {tag, false, time.Now().Add(-time.Second * 10)},
56+
},
57+
forceAfter: time.Second,
58+
assertFn: func(t *testing.T, rq *removalQueue) {
59+
assert.NotContains(t, rq.outMgr.(*mockOutboundManager).tags, tag, "tag should be removed")
60+
},
61+
},
62+
{
63+
name: "don't remove if still in use",
64+
outMgr: &mockOutboundManager{tags: []string{tag}},
65+
connMgr: &mockConnectionManager{
66+
conns: []trafficontrol.TrackerMetadata{{Outbound: tag, ClosedAt: time.Time{}}},
67+
},
68+
pending: map[string]item{tag: {tag, false, time.Now()}},
69+
forceAfter: time.Minute,
70+
assertFn: func(t *testing.T, rq *removalQueue) {
71+
assert.Contains(t, rq.outMgr.(*mockOutboundManager).tags, tag, "tag should still be present")
72+
},
73+
},
74+
{
75+
name: "don't remove if re-added",
76+
outMgr: &mockOutboundManager{tags: []string{tag}},
77+
connMgr: &mockConnectionManager{
78+
conns: []trafficontrol.TrackerMetadata{{Outbound: tag, ClosedAt: time.Time{}}},
79+
},
80+
pending: map[string]item{tag: {tag, false, time.Now()}},
81+
forceAfter: time.Minute,
82+
assertFn: func(t *testing.T, rq *removalQueue) {
83+
require.Contains(t, rq.outMgr.(*mockOutboundManager).tags, tag, "tag should still be present before re-adding")
84+
85+
rq.dequeue(tag)
86+
rq.connMgr.Connections()[0].ClosedAt = time.Now() // simulate connection closed
87+
rq.checkPending()
88+
assert.Contains(t, rq.outMgr.(*mockOutboundManager).tags, tag, "tag should still be present after re-adding")
89+
},
90+
},
91+
}
92+
for _, tt := range tests {
93+
t.Run(tt.name, func(t *testing.T) {
94+
rq := &removalQueue{
95+
logger: logger,
96+
outMgr: tt.outMgr,
97+
epMgr: tt.epMgr,
98+
connMgr: tt.connMgr,
99+
pending: tt.pending,
100+
forceAfter: tt.forceAfter,
101+
}
102+
rq.checkPending()
103+
tt.assertFn(t, rq)
104+
})
105+
}
106+
}
107+
108+
type mockOutboundManager struct {
109+
adapter.OutboundManager
110+
tags []string
111+
mu sync.Mutex
112+
}
113+
114+
func (m *mockOutboundManager) Remove(tag string) error {
115+
m.mu.Lock()
116+
defer m.mu.Unlock()
117+
if idx := slices.Index(m.tags, tag); idx != -1 {
118+
m.tags = append(m.tags[:idx], m.tags[idx+1:]...)
119+
}
120+
return nil
121+
}
122+
123+
type mockEndpointManager struct {
124+
adapter.EndpointManager
125+
tags []string
126+
mu sync.Mutex
127+
}
128+
129+
func (m *mockEndpointManager) Remove(tag string) error {
130+
m.mu.Lock()
131+
defer m.mu.Unlock()
132+
if idx := slices.Index(m.tags, tag); idx != -1 {
133+
m.tags = append(m.tags[:idx], m.tags[idx+1:]...)
134+
}
135+
return nil
136+
}
137+
138+
type mockConnectionManager struct {
139+
conns []trafficontrol.TrackerMetadata
140+
}
141+
142+
func (m *mockConnectionManager) Connections() []trafficontrol.TrackerMetadata {
143+
if m == nil {
144+
return nil
145+
}
146+
return m.conns
147+
}

0 commit comments

Comments
 (0)