Skip to content

Commit 6a63fb9

Browse files
authored
fix(lbcache): check the existence before new Balancer to prevent leakage (cloudwego#1825)
1 parent a990796 commit 6a63fb9

File tree

2 files changed

+51
-0
lines changed

2 files changed

+51
-0
lines changed

pkg/loadbalance/lbcache/cache.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,10 @@ func (b *BalancerFactory) Get(ctx context.Context, target rpcinfo.EndpointInfo)
163163
return val.(*Balancer), nil
164164
}
165165
val, err, _ := b.sfg.Do(desc, func() (interface{}, error) {
166+
if v, ok := b.cache.Load(desc); ok {
167+
// cache may be set already
168+
return v.(*Balancer), nil
169+
}
166170
res, err := b.resolver.Resolve(ctx, desc)
167171
if err != nil {
168172
return nil, err

pkg/loadbalance/lbcache/cache_test.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"context"
2121
"encoding/json"
2222
"fmt"
23+
"sync"
2324
"testing"
2425
"time"
2526

@@ -126,3 +127,49 @@ func (m *mockRebalancer) Delete(ch discovery.Change) {
126127
m.deleteFunc(ch)
127128
}
128129
}
130+
131+
type mockResolver struct{}
132+
133+
func (m *mockResolver) Target(ctx context.Context, target rpcinfo.EndpointInfo) (description string) {
134+
return "target"
135+
}
136+
137+
func (m *mockResolver) Resolve(ctx context.Context, desc string) (discovery.Result, error) {
138+
return discovery.Result{}, nil
139+
}
140+
141+
func (m *mockResolver) Diff(cacheKey string, prev, next discovery.Result) (discovery.Change, bool) {
142+
return discovery.Change{}, false
143+
}
144+
145+
func (m *mockResolver) Name() string {
146+
return "name"
147+
}
148+
149+
var _ discovery.Resolver = &mockResolver{}
150+
151+
func TestConcurrentGet(t *testing.T) {
152+
cacheOpts := Options{Cacheable: false, RefreshInterval: time.Second, ExpireInterval: 5 * time.Second}
153+
bf := newBalancerFactory(&mockResolver{}, loadbalance.NewWeightedBalancer(), cacheOpts)
154+
m := sync.Map{}
155+
wg := sync.WaitGroup{}
156+
157+
// concurrent get
158+
for i := 0; i < 100; i++ {
159+
wg.Add(1)
160+
go func() {
161+
defer wg.Done()
162+
b, _ := bf.Get(context.Background(), nil)
163+
m.Store(b, b)
164+
}()
165+
}
166+
wg.Wait()
167+
168+
// check if length == 1, (target -> balancer, 1:1)
169+
cnt := 0
170+
m.Range(func(key, value any) bool {
171+
cnt++
172+
return true
173+
})
174+
test.Assert(t, cnt == 1, cnt)
175+
}

0 commit comments

Comments
 (0)