Skip to content

Commit f6aef5c

Browse files
committed
feat: wrap singleflight.group to provide CheckAndDo
1 parent 6a63fb9 commit f6aef5c

File tree

7 files changed

+189
-100
lines changed

7 files changed

+189
-100
lines changed

pkg/loadbalance/consist.go

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@ import (
2323
"time"
2424

2525
"github.com/bytedance/gopkg/util/xxhash3"
26-
"golang.org/x/sync/singleflight"
27-
2826
"github.com/cloudwego/kitex/pkg/discovery"
2927
"github.com/cloudwego/kitex/pkg/utils"
3028
)
@@ -226,7 +224,7 @@ func buildConsistResult(info *consistInfo, key uint64) *consistResult {
226224
type consistBalancer struct {
227225
cachedConsistInfo sync.Map
228226
opt ConsistentHashOption
229-
sfg singleflight.Group
227+
sfg utils.SingleFlightGroup
230228
}
231229

232230
// NewConsistBalancer creates a new consist balancer with the given option.
@@ -247,13 +245,17 @@ func NewConsistBalancer(opt ConsistentHashOption) Loadbalancer {
247245
func (cb *consistBalancer) GetPicker(e discovery.Result) Picker {
248246
var ci *consistInfo
249247
if e.Cacheable {
250-
cii, ok := cb.cachedConsistInfo.Load(e.CacheKey)
251-
if !ok {
252-
cii, _, _ = cb.sfg.Do(e.CacheKey, func() (interface{}, error) {
253-
return cb.newConsistInfo(e), nil
254-
})
255-
cb.cachedConsistInfo.Store(e.CacheKey, cii)
256-
}
248+
cii, _, _ := cb.sfg.CheckAndDo(
249+
e.CacheKey,
250+
func() (any, bool) {
251+
return cb.cachedConsistInfo.Load(e.CacheKey)
252+
},
253+
func() (interface{}, error) {
254+
res := cb.newConsistInfo(e)
255+
cb.cachedConsistInfo.Store(e.CacheKey, res)
256+
return res, nil
257+
},
258+
)
257259
ci = cii.(*consistInfo)
258260
} else {
259261
ci = cb.newConsistInfo(e)

pkg/loadbalance/lbcache/cache.go

Lines changed: 30 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@ import (
2424
"sync/atomic"
2525
"time"
2626

27-
"golang.org/x/sync/singleflight"
28-
2927
"github.com/cloudwego/kitex/pkg/diagnosis"
3028
"github.com/cloudwego/kitex/pkg/discovery"
3129
"github.com/cloudwego/kitex/pkg/klog"
@@ -41,7 +39,7 @@ const (
4139

4240
var (
4341
balancerFactories sync.Map // key: resolver name + loadbalance name
44-
balancerFactoriesSfg singleflight.Group
42+
balancerFactoriesSfg utils.SingleFlightGroup
4543
)
4644

4745
// Options for create builder
@@ -88,7 +86,7 @@ type BalancerFactory struct {
8886
resolver discovery.Resolver
8987
balancer loadbalance.Loadbalancer
9088
rebalancer loadbalance.Rebalancer
91-
sfg singleflight.Group
89+
sfg utils.SingleFlightGroup
9290
}
9391

9492
func cacheKey(resolver, balancer string, opts Options) string {
@@ -120,15 +118,15 @@ func NewBalancerFactory(resolver discovery.Resolver, balancer loadbalance.Loadba
120118
return newBalancerFactory(resolver, balancer, opts)
121119
}
122120
uniqueKey := cacheKey(resolver.Name(), balancer.Name(), opts)
123-
val, ok := balancerFactories.Load(uniqueKey)
124-
if ok {
125-
return val.(*BalancerFactory)
126-
}
127-
val, _, _ = balancerFactoriesSfg.Do(uniqueKey, func() (interface{}, error) {
128-
b := newBalancerFactory(resolver, balancer, opts)
129-
balancerFactories.Store(uniqueKey, b)
130-
return b, nil
131-
})
121+
val, _, _ := balancerFactoriesSfg.CheckAndDo(uniqueKey,
122+
func() (any, bool) {
123+
return balancerFactories.Load(uniqueKey)
124+
},
125+
func() (interface{}, error) {
126+
b := newBalancerFactory(resolver, balancer, opts)
127+
balancerFactories.Store(uniqueKey, b)
128+
return b, nil
129+
})
132130
return val.(*BalancerFactory)
133131
}
134132

@@ -158,29 +156,25 @@ func renameResultCacheKey(res *discovery.Result, resolverName string) {
158156
// Get create a new balancer if not exists
159157
func (b *BalancerFactory) Get(ctx context.Context, target rpcinfo.EndpointInfo) (*Balancer, error) {
160158
desc := b.resolver.Target(ctx, target)
161-
val, ok := b.cache.Load(desc)
162-
if ok {
163-
return val.(*Balancer), nil
164-
}
165-
val, err, _ := b.sfg.Do(desc, func() (interface{}, error) {
166-
if v, ok := b.cache.Load(desc); ok {
167-
// cache may be set already
168-
return v.(*Balancer), nil
169-
}
170-
res, err := b.resolver.Resolve(ctx, desc)
171-
if err != nil {
172-
return nil, err
173-
}
174-
renameResultCacheKey(&res, b.resolver.Name())
175-
bl := &Balancer{
176-
b: b,
177-
target: desc,
178-
}
179-
bl.res.Store(res)
180-
bl.sharedTicker = getSharedTicker(bl, b.opts.RefreshInterval)
181-
b.cache.Store(desc, bl)
182-
return bl, nil
183-
})
159+
val, err, _ := b.sfg.CheckAndDo(desc,
160+
func() (any, bool) {
161+
return b.cache.Load(desc)
162+
},
163+
func() (interface{}, error) {
164+
res, err := b.resolver.Resolve(ctx, desc)
165+
if err != nil {
166+
return nil, err
167+
}
168+
renameResultCacheKey(&res, b.resolver.Name())
169+
bl := &Balancer{
170+
b: b,
171+
target: desc,
172+
}
173+
bl.res.Store(res)
174+
bl.sharedTicker = getSharedTicker(bl, b.opts.RefreshInterval)
175+
b.cache.Store(desc, bl)
176+
return bl, nil
177+
})
184178
if err != nil {
185179
return nil, err
186180
}

pkg/loadbalance/lbcache/shared_ticker.go

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,32 +20,27 @@ import (
2020
"sync"
2121
"time"
2222

23-
"golang.org/x/sync/singleflight"
24-
2523
"github.com/cloudwego/kitex/pkg/utils"
2624
)
2725

2826
var (
2927
// insert, not delete
3028
sharedTickers sync.Map
31-
sharedTickersSfg singleflight.Group
29+
sharedTickersSfg utils.SingleFlightGroup
3230
)
3331

3432
func getSharedTicker(b *Balancer, refreshInterval time.Duration) *utils.SharedTicker {
35-
sti, ok := sharedTickers.Load(refreshInterval)
36-
if ok {
37-
st := sti.(*utils.SharedTicker)
38-
st.Add(b)
39-
return st
40-
}
41-
v, _, _ := sharedTickersSfg.Do(refreshInterval.String(), func() (interface{}, error) {
42-
st := utils.NewSharedTicker(refreshInterval)
43-
sharedTickers.Store(refreshInterval, st)
44-
return st, nil
45-
})
33+
v, _, _ := sharedTickersSfg.CheckAndDo(
34+
refreshInterval.String(),
35+
func() (any, bool) {
36+
return sharedTickers.Load(refreshInterval)
37+
},
38+
func() (interface{}, error) {
39+
st := utils.NewSharedTicker(refreshInterval)
40+
sharedTickers.Store(refreshInterval, st)
41+
return st, nil
42+
})
4643
st := v.(*utils.SharedTicker)
47-
// Add without singleflight,
48-
// because we need all refreshers those call this function to add themselves to SharedTicker
4944
st.Add(b)
5045
return st
5146
}

pkg/loadbalance/weighted_balancer.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package loadbalance
1919
import (
2020
"sync"
2121

22-
"golang.org/x/sync/singleflight"
22+
"github.com/cloudwego/kitex/pkg/utils"
2323

2424
"github.com/cloudwego/kitex/pkg/discovery"
2525
"github.com/cloudwego/kitex/pkg/klog"
@@ -35,7 +35,7 @@ const (
3535
type weightedBalancer struct {
3636
kind int
3737
pickerCache sync.Map
38-
sfg singleflight.Group
38+
sfg utils.SingleFlightGroup
3939
}
4040

4141
// NewWeightedBalancer creates a loadbalancer using weighted-round-robin algorithm.
@@ -73,15 +73,15 @@ func (wb *weightedBalancer) GetPicker(e discovery.Result) Picker {
7373
picker := wb.createPicker(e)
7474
return picker
7575
}
76-
77-
picker, ok := wb.pickerCache.Load(e.CacheKey)
78-
if !ok {
79-
picker, _, _ = wb.sfg.Do(e.CacheKey, func() (interface{}, error) {
76+
picker, _, _ := wb.sfg.CheckAndDo(e.CacheKey,
77+
func() (any, bool) {
78+
return wb.pickerCache.Load(e.CacheKey)
79+
},
80+
func() (interface{}, error) {
8081
p := wb.createPicker(e)
8182
wb.pickerCache.Store(e.CacheKey, p)
8283
return p, nil
8384
})
84-
}
8585
return picker.(Picker)
8686
}
8787

pkg/remote/trans/nphttp2/conn_pool.go

Lines changed: 29 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import (
2626
"sync/atomic"
2727
"time"
2828

29-
"golang.org/x/sync/singleflight"
29+
"github.com/cloudwego/kitex/pkg/utils"
3030

3131
"github.com/cloudwego/kitex/pkg/klog"
3232
"github.com/cloudwego/kitex/pkg/remote"
@@ -57,7 +57,7 @@ func NewConnPool(remoteService string, size uint32, connOpts grpc.ConnectOptions
5757
// MuxPool manages a pool of long connections.
5858
type connPool struct {
5959
size uint32
60-
sfg singleflight.Group
60+
sfg utils.SingleFlightGroup
6161
conns sync.Map // key: address, value: *transports
6262
remoteService string // remote service name
6363
connOpts grpc.ConnectOptions
@@ -141,41 +141,39 @@ func (p *connPool) Get(ctx context.Context, network, address string, opt remote.
141141

142142
var (
143143
trans *transports
144-
conn *clientConn
145144
err error
146145
)
147146

148-
v, ok := p.conns.Load(address)
149-
if ok {
150-
trans = v.(*transports)
151-
if tr := trans.get(); tr != nil {
152-
if tr.(grpc.IsActive).IsActive() {
153-
// Actually new a stream, reuse the connection (grpc.ClientTransport)
154-
conn, err = newClientConn(ctx, tr, address)
155-
if err == nil {
156-
return conn, nil
147+
tr, err, _ := p.sfg.CheckAndDo(address,
148+
func() (any, bool) {
149+
v, ok := p.conns.Load(address)
150+
if ok {
151+
trans = v.(*transports)
152+
if tr := trans.get(); tr != nil {
153+
if tr.(grpc.IsActive).IsActive() {
154+
return tr, true
155+
}
157156
}
158-
klog.CtxDebugf(ctx, "KITEX: New grpc stream failed, network=%s, address=%s, error=%s", network, address, err.Error())
159157
}
160-
}
161-
}
162-
tr, err, _ := p.sfg.Do(address, func() (i interface{}, e error) {
163-
// Notice: newTransport means new a connection, the timeout of connection cannot be set,
164-
// so using context.Background() but not the ctx passed in as the parameter.
165-
tr, err := p.newTransport(context.Background(), opt.Dialer, network, address, opt.ConnectTimeout, p.connOpts)
166-
if err != nil {
167-
return nil, err
168-
}
169-
if trans == nil {
170-
trans = &transports{
171-
size: p.size,
172-
cliTransports: make([]grpc.ClientTransport, p.size),
158+
return nil, false
159+
},
160+
func() (i interface{}, e error) {
161+
// Notice: newTransport means new a connection, the timeout of connection cannot be set,
162+
// so using context.Background() but not the ctx passed in as the parameter.
163+
tr, err := p.newTransport(context.Background(), opt.Dialer, network, address, opt.ConnectTimeout, p.connOpts)
164+
if err != nil {
165+
return nil, err
173166
}
174-
}
175-
trans.put(tr) // the tr (connection) maybe not in the pool, but can be recycled by keepalive.
176-
p.conns.Store(address, trans)
177-
return tr, nil
178-
})
167+
if trans == nil {
168+
trans = &transports{
169+
size: p.size,
170+
cliTransports: make([]grpc.ClientTransport, p.size),
171+
}
172+
}
173+
trans.put(tr) // the tr (connection) maybe not in the pool, but can be recycled by keepalive.
174+
p.conns.Store(address, trans)
175+
return tr, nil
176+
})
179177
if err != nil {
180178
klog.CtxErrorf(ctx, "KITEX: New grpc client connection failed, network=%s, address=%s, error=%s", network, address, err.Error())
181179
return nil, err

pkg/utils/singleFlight.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
package utils
2+
3+
import "golang.org/x/sync/singleflight"
4+
5+
// SingleFlightGroup is a wrapper around singleflight.Group to provide a CheckAndDo functionality.
6+
// It is used to ensure that `fn` is executed only once for a given key.
7+
type SingleFlightGroup struct {
8+
singleflight.Group
9+
}
10+
11+
// CheckAndDo implements a double-checked pattern to ensure that `fn` is executed only once for a given key.
12+
// It performs an initial check with `checkFunc` before calling `singleflight.Do`. If the value is found, return.
13+
// Otherwise, call `Do` and execute `checkFunc` again before executing `fn`.
14+
func (g *SingleFlightGroup) CheckAndDo(key string, checkFunc func() (any, bool), fn func() (any, error)) (v any, err error, shared bool) {
15+
if value, loaded := checkFunc(); loaded {
16+
return value, nil, true
17+
}
18+
return g.Do(key, func() (any, error) {
19+
if value, loaded := checkFunc(); loaded {
20+
return value, nil
21+
}
22+
return fn()
23+
})
24+
}

0 commit comments

Comments
 (0)