Skip to content

Commit 7ffcb9e

Browse files
authored
feat(enhancement)!: lb next method with context usage (#1117)
1 parent a2cc38b commit 7ffcb9e

File tree

3 files changed

+130
-25
lines changed

3 files changed

+130
-25
lines changed

load_balancer.go

Lines changed: 46 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
package resty
77

88
import (
9+
"context"
910
"errors"
1011
"fmt"
1112
"net"
@@ -15,10 +16,13 @@ import (
1516
"time"
1617
)
1718

19+
// ErrNoBaseURLs error returned when no base URLs are found
20+
var ErrNoBaseURLs = errors.New("resty: no base URLs found")
21+
1822
// LoadBalancer is the interface that wraps the HTTP client load-balancing
1923
// algorithm that returns the "Next" Base URL for the request to target
2024
type LoadBalancer interface {
21-
Next() (string, error)
25+
NextWithContext(ctx context.Context) (string, error)
2226
Feedback(*RequestFeedback)
2327
Close() error
2428
}
@@ -34,6 +38,10 @@ type RequestFeedback struct {
3438
// NewRoundRobin method creates the new Round-Robin(RR) request load balancer
3539
// instance with given base URLs
3640
func NewRoundRobin(baseURLs ...string) (*RoundRobin, error) {
41+
if len(baseURLs) == 0 {
42+
return nil, ErrNoBaseURLs
43+
}
44+
3745
rr := &RoundRobin{lock: new(sync.Mutex)}
3846
if err := rr.Refresh(baseURLs...); err != nil {
3947
return rr, err
@@ -51,11 +59,22 @@ type RoundRobin struct {
5159
current int
5260
}
5361

54-
// Next method returns the next Base URL based on the Round-Robin(RR) algorithm
55-
func (rr *RoundRobin) Next() (string, error) {
62+
// NextWithContext method returns the next Base URL based on the Round-Robin(RR) algorithm
63+
// with context support for cancellation
64+
func (rr *RoundRobin) NextWithContext(ctx context.Context) (string, error) {
65+
select {
66+
case <-ctx.Done():
67+
return "", ctx.Err()
68+
default:
69+
}
70+
5671
rr.lock.Lock()
5772
defer rr.lock.Unlock()
5873

74+
if len(rr.baseURLs) == 0 {
75+
return "", ErrNoBaseURLs
76+
}
77+
5978
baseURL := rr.baseURLs[rr.current]
6079
rr.current = (rr.current + 1) % len(rr.baseURLs)
6180
return baseURL, nil
@@ -135,7 +154,7 @@ func NewWeightedRoundRobin(recovery time.Duration, hosts ...*Host) (*WeightedRou
135154
recovery = 120 * time.Second // defaults to 120 seconds
136155
}
137156
wrr := &WeightedRoundRobin{
138-
lock: new(sync.Mutex),
157+
lock: new(sync.RWMutex),
139158
hosts: make([]*Host, 0),
140159
tick: time.NewTicker(recovery),
141160
recovery: recovery,
@@ -153,7 +172,7 @@ var _ LoadBalancer = (*WeightedRoundRobin)(nil)
153172
// WeightedRoundRobin struct used to represent the host details for
154173
// Weighted Round-Robin(WRR) algorithm implementation
155174
type WeightedRoundRobin struct {
156-
lock *sync.Mutex
175+
lock *sync.RWMutex
157176
hosts []*Host
158177
totalWeight int
159178
tick *time.Ticker
@@ -165,8 +184,15 @@ type WeightedRoundRobin struct {
165184
recovery time.Duration
166185
}
167186

168-
// Next method returns the next Base URL based on Weighted Round-Robin(WRR)
169-
func (wrr *WeightedRoundRobin) Next() (string, error) {
187+
// NextWithContext method returns the next Base URL based on Weighted Round-Robin(WRR)
188+
// with context support for cancellation
189+
func (wrr *WeightedRoundRobin) NextWithContext(ctx context.Context) (string, error) {
190+
select {
191+
case <-ctx.Done():
192+
return "", ctx.Err()
193+
default:
194+
}
195+
170196
wrr.lock.Lock()
171197
defer wrr.lock.Unlock()
172198

@@ -196,6 +222,10 @@ func (wrr *WeightedRoundRobin) Next() (string, error) {
196222
// Feedback method process the request feedback for Weighted Round-Robin(WRR)
197223
// request load balancer
198224
func (wrr *WeightedRoundRobin) Feedback(f *RequestFeedback) {
225+
if f == nil {
226+
return
227+
}
228+
199229
wrr.lock.Lock()
200230
defer wrr.lock.Unlock()
201231

@@ -273,7 +303,11 @@ func (wrr *WeightedRoundRobin) SetRecoveryDuration(d time.Duration) {
273303
func (wrr *WeightedRoundRobin) ticker() {
274304
for range wrr.tick.C {
275305
wrr.lock.Lock()
276-
for _, host := range wrr.hosts {
306+
hosts := make([]*Host, len(wrr.hosts))
307+
copy(hosts, wrr.hosts)
308+
wrr.lock.Unlock()
309+
310+
for _, host := range hosts {
277311
if host.state == HostStateInActive {
278312
host.state = HostStateActive
279313
host.failedRequests = 0
@@ -283,7 +317,6 @@ func (wrr *WeightedRoundRobin) ticker() {
283317
}
284318
}
285319
}
286-
wrr.lock.Unlock()
287320
}
288321
}
289322

@@ -334,9 +367,10 @@ type SRVWeightedRoundRobin struct {
334367
lookupSRV func() ([]*net.SRV, error)
335368
}
336369

337-
// Next method returns the next SRV Base URL based on Weighted Round-Robin(RR)
338-
func (swrr *SRVWeightedRoundRobin) Next() (string, error) {
339-
return swrr.wrr.Next()
370+
// NextWithContext method returns the next SRV Base URL based on Weighted Round-Robin(RR)
371+
// with context support for cancellation
372+
func (swrr *SRVWeightedRoundRobin) NextWithContext(ctx context.Context) (string, error) {
373+
return swrr.wrr.NextWithContext(ctx)
340374
}
341375

342376
// Feedback method does nothing in SRV Base URL based on Weighted Round-Robin(WRR)

load_balancer_test.go

Lines changed: 83 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
package resty
77

88
import (
9+
"context"
910
"errors"
1011
"net"
1112
"net/http"
@@ -23,8 +24,9 @@ func TestRoundRobin(t *testing.T) {
2324

2425
runCount := 5
2526
var result []string
27+
ctx := context.Background()
2628
for i := 0; i < runCount; i++ {
27-
baseURL, _ := rr.Next()
29+
baseURL, _ := rr.NextWithContext(ctx)
2830
result = append(result, baseURL)
2931
}
3032

@@ -49,8 +51,9 @@ func TestRoundRobin(t *testing.T) {
4951

5052
runCount := 30
5153
var result []string
54+
ctx := context.Background()
5255
for i := 0; i < runCount; i++ {
53-
baseURL, _ := rr.Next()
56+
baseURL, _ := rr.NextWithContext(ctx)
5457
result = append(result, baseURL)
5558
}
5659

@@ -76,8 +79,9 @@ func TestRoundRobin(t *testing.T) {
7679

7780
runCount := 5
7881
var result []string
82+
ctx := context.Background()
7983
for i := 0; i < runCount; i++ {
80-
baseURL, _ := rr.Next()
84+
baseURL, _ := rr.NextWithContext(ctx)
8185
result = append(result, baseURL)
8286
}
8387

@@ -93,6 +97,43 @@ func TestRoundRobin(t *testing.T) {
9397
rr.Feedback(&RequestFeedback{})
9498
rr.Close()
9599
})
100+
101+
t.Run("NextWithContext context cancellation", func(t *testing.T) {
102+
rr, _ := NewRoundRobin("https://example.com")
103+
ctx, cancel := context.WithCancel(context.Background())
104+
cancel()
105+
_, err := rr.NextWithContext(ctx)
106+
assertErrorIs(t, context.Canceled, err)
107+
})
108+
109+
t.Run("NextWithContext normal operation", func(t *testing.T) {
110+
rr, _ := NewRoundRobin("https://example1.com", "https://example2.com")
111+
ctx := context.Background()
112+
url1, err := rr.NextWithContext(ctx)
113+
assertNil(t, err)
114+
url2, err := rr.NextWithContext(ctx)
115+
assertNil(t, err)
116+
assertNotEqual(t, url1, url2)
117+
})
118+
}
119+
120+
func TestRoundRobinNoBaseURLs(t *testing.T) {
121+
t.Run("new round robin no base urls", func(t *testing.T) {
122+
rr, err := NewRoundRobin()
123+
assertErrorIs(t, ErrNoBaseURLs, err)
124+
assertNil(t, rr)
125+
})
126+
127+
t.Run("new round robin no base urls on next with context", func(t *testing.T) {
128+
rr, err := NewRoundRobin("https://example1.com")
129+
assertNil(t, err)
130+
assertNotNil(t, rr)
131+
132+
rr.Refresh()
133+
ctx := context.Background()
134+
_, err = rr.NextWithContext(ctx)
135+
assertErrorIs(t, ErrNoBaseURLs, err)
136+
})
96137
}
97138

98139
func TestWeightedRoundRobin(t *testing.T) {
@@ -109,8 +150,9 @@ func TestWeightedRoundRobin(t *testing.T) {
109150

110151
runCount := 5
111152
var result []string
153+
ctx := context.Background()
112154
for i := 0; i < runCount; i++ {
113-
baseURL, err := wrr.Next()
155+
baseURL, err := wrr.NextWithContext(ctx)
114156
assertNil(t, err)
115157
result = append(result, baseURL)
116158
}
@@ -123,6 +165,8 @@ func TestWeightedRoundRobin(t *testing.T) {
123165
assertEqual(t, runCount, len(expected))
124166
assertEqual(t, runCount, len(result))
125167
assertEqual(t, expected, result)
168+
169+
wrr.Feedback(nil)
126170
})
127171

128172
t.Run("3 hosts with weight {2,1,10}", func(t *testing.T) {
@@ -143,8 +187,9 @@ func TestWeightedRoundRobin(t *testing.T) {
143187

144188
runCount := 10
145189
var result []string
190+
ctx := context.Background()
146191
for i := 0; i < runCount; i++ {
147-
baseURL, err := wrr.Next()
192+
baseURL, err := wrr.NextWithContext(ctx)
148193
assertNil(t, err)
149194
result = append(result, baseURL)
150195
if baseURL == "https://example3.com" && i%2 != 0 {
@@ -184,8 +229,9 @@ func TestWeightedRoundRobin(t *testing.T) {
184229

185230
runCount := 5
186231
var result []string
232+
ctx := context.Background()
187233
for i := 0; i < runCount; i++ {
188-
baseURL, err := wrr.Next()
234+
baseURL, err := wrr.NextWithContext(ctx)
189235
assertNil(t, err)
190236
result = append(result, baseURL)
191237
}
@@ -205,9 +251,31 @@ func TestWeightedRoundRobin(t *testing.T) {
205251
assertNil(t, err)
206252
defer wrr.Close()
207253

208-
_, err = wrr.Next()
254+
_, err = wrr.NextWithContext(context.Background())
209255
assertErrorIs(t, ErrNoActiveHost, err)
210256
})
257+
258+
t.Run("NextWithContext context cancellation", func(t *testing.T) {
259+
wrr, _ := NewWeightedRoundRobin(0, &Host{BaseURL: "https://example.com", Weight: 1})
260+
ctx, cancel := context.WithCancel(context.Background())
261+
cancel()
262+
_, err := wrr.NextWithContext(ctx)
263+
assertErrorIs(t, context.Canceled, err)
264+
})
265+
266+
t.Run("NextWithContext normal operation", func(t *testing.T) {
267+
hosts := []*Host{
268+
{BaseURL: "https://example1.com", Weight: 1},
269+
{BaseURL: "https://example2.com", Weight: 1},
270+
}
271+
wrr, _ := NewWeightedRoundRobin(0, hosts...)
272+
ctx := context.Background()
273+
url1, err := wrr.NextWithContext(ctx)
274+
assertNil(t, err)
275+
url2, err := wrr.NextWithContext(ctx)
276+
assertNil(t, err)
277+
assertNotEqual(t, url1, url2)
278+
})
211279
}
212280

213281
func TestSRVWeightedRoundRobin(t *testing.T) {
@@ -233,8 +301,9 @@ func TestSRVWeightedRoundRobin(t *testing.T) {
233301

234302
runCount := 5
235303
var result []string
304+
ctx := context.Background()
236305
for i := 0; i < runCount; i++ {
237-
baseURL, err := srv.Next()
306+
baseURL, err := srv.NextWithContext(ctx)
238307
assertNil(t, err)
239308
result = append(result, baseURL)
240309
}
@@ -271,8 +340,9 @@ func TestSRVWeightedRoundRobin(t *testing.T) {
271340

272341
runCount := 5
273342
var result []string
343+
ctx := context.Background()
274344
for i := 0; i < runCount; i++ {
275-
baseURL, err := srv.Next()
345+
baseURL, err := srv.NextWithContext(ctx)
276346
assertNil(t, err)
277347
result = append(result, baseURL)
278348
}
@@ -315,8 +385,9 @@ func TestSRVWeightedRoundRobin(t *testing.T) {
315385

316386
runCount := 20
317387
var result []string
388+
ctx := context.Background()
318389
for i := 0; i < runCount; i++ {
319-
baseURL, err := srv.Next()
390+
baseURL, err := srv.NextWithContext(ctx)
320391
assertNil(t, err)
321392
result = append(result, baseURL)
322393

@@ -363,7 +434,7 @@ func TestSRVWeightedRoundRobin(t *testing.T) {
363434

364435
go func() {
365436
for i := 0; i < 10; i++ {
366-
baseURL, _ := srv.Next()
437+
baseURL, _ := srv.NextWithContext(context.Background())
367438
assertNotNil(t, baseURL)
368439
time.Sleep(15 * time.Millisecond)
369440
}
@@ -438,7 +509,7 @@ func TestLoadBalancerRequestFlowError(t *testing.T) {
438509
c.SetLoadBalancer(wrr)
439510

440511
resp, err := c.R().Get("/")
441-
assertEqual(t, ErrNoActiveHost, err)
512+
assertErrorIs(t, ErrNoActiveHost, err)
442513
assertNil(t, resp)
443514
})
444515

middleware.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ func parseRequestURL(c *Client, r *Request) error {
131131
}
132132

133133
if r.client.LoadBalancer() != nil {
134-
r.baseURL, err = r.client.LoadBalancer().Next()
134+
r.baseURL, err = r.client.LoadBalancer().NextWithContext(r.Context())
135135
if err != nil {
136136
return &invalidRequestError{Err: err}
137137
}

0 commit comments

Comments
 (0)