66package resty
77
88import (
9+ "context"
910 "errors"
1011 "net"
1112 "net/http"
@@ -23,8 +24,9 @@ func TestRoundRobin(t *testing.T) {
2324
2425 runCount := 5
2526 var result []string
27+ ctx := context .Background ()
2628 for i := 0 ; i < runCount ; i ++ {
27- baseURL , _ := rr .Next ( )
29+ baseURL , _ := rr .NextWithContext ( ctx )
2830 result = append (result , baseURL )
2931 }
3032
@@ -49,8 +51,9 @@ func TestRoundRobin(t *testing.T) {
4951
5052 runCount := 30
5153 var result []string
54+ ctx := context .Background ()
5255 for i := 0 ; i < runCount ; i ++ {
53- baseURL , _ := rr .Next ( )
56+ baseURL , _ := rr .NextWithContext ( ctx )
5457 result = append (result , baseURL )
5558 }
5659
@@ -76,8 +79,9 @@ func TestRoundRobin(t *testing.T) {
7679
7780 runCount := 5
7881 var result []string
82+ ctx := context .Background ()
7983 for i := 0 ; i < runCount ; i ++ {
80- baseURL , _ := rr .Next ( )
84+ baseURL , _ := rr .NextWithContext ( ctx )
8185 result = append (result , baseURL )
8286 }
8387
@@ -93,6 +97,43 @@ func TestRoundRobin(t *testing.T) {
9397 rr .Feedback (& RequestFeedback {})
9498 rr .Close ()
9599 })
100+
101+ t .Run ("NextWithContext context cancellation" , func (t * testing.T ) {
102+ rr , _ := NewRoundRobin ("https://example.com" )
103+ ctx , cancel := context .WithCancel (context .Background ())
104+ cancel ()
105+ _ , err := rr .NextWithContext (ctx )
106+ assertErrorIs (t , context .Canceled , err )
107+ })
108+
109+ t .Run ("NextWithContext normal operation" , func (t * testing.T ) {
110+ rr , _ := NewRoundRobin ("https://example1.com" , "https://example2.com" )
111+ ctx := context .Background ()
112+ url1 , err := rr .NextWithContext (ctx )
113+ assertNil (t , err )
114+ url2 , err := rr .NextWithContext (ctx )
115+ assertNil (t , err )
116+ assertNotEqual (t , url1 , url2 )
117+ })
118+ }
119+
120+ func TestRoundRobinNoBaseURLs (t * testing.T ) {
121+ t .Run ("new round robin no base urls" , func (t * testing.T ) {
122+ rr , err := NewRoundRobin ()
123+ assertErrorIs (t , ErrNoBaseURLs , err )
124+ assertNil (t , rr )
125+ })
126+
127+ t .Run ("new round robin no base urls on next with context" , func (t * testing.T ) {
128+ rr , err := NewRoundRobin ("https://example1.com" )
129+ assertNil (t , err )
130+ assertNotNil (t , rr )
131+
132+ rr .Refresh ()
133+ ctx := context .Background ()
134+ _ , err = rr .NextWithContext (ctx )
135+ assertErrorIs (t , ErrNoBaseURLs , err )
136+ })
96137}
97138
98139func TestWeightedRoundRobin (t * testing.T ) {
@@ -109,8 +150,9 @@ func TestWeightedRoundRobin(t *testing.T) {
109150
110151 runCount := 5
111152 var result []string
153+ ctx := context .Background ()
112154 for i := 0 ; i < runCount ; i ++ {
113- baseURL , err := wrr .Next ( )
155+ baseURL , err := wrr .NextWithContext ( ctx )
114156 assertNil (t , err )
115157 result = append (result , baseURL )
116158 }
@@ -123,6 +165,8 @@ func TestWeightedRoundRobin(t *testing.T) {
123165 assertEqual (t , runCount , len (expected ))
124166 assertEqual (t , runCount , len (result ))
125167 assertEqual (t , expected , result )
168+
169+ wrr .Feedback (nil )
126170 })
127171
128172 t .Run ("3 hosts with weight {2,1,10}" , func (t * testing.T ) {
@@ -143,8 +187,9 @@ func TestWeightedRoundRobin(t *testing.T) {
143187
144188 runCount := 10
145189 var result []string
190+ ctx := context .Background ()
146191 for i := 0 ; i < runCount ; i ++ {
147- baseURL , err := wrr .Next ( )
192+ baseURL , err := wrr .NextWithContext ( ctx )
148193 assertNil (t , err )
149194 result = append (result , baseURL )
150195 if baseURL == "https://example3.com" && i % 2 != 0 {
@@ -184,8 +229,9 @@ func TestWeightedRoundRobin(t *testing.T) {
184229
185230 runCount := 5
186231 var result []string
232+ ctx := context .Background ()
187233 for i := 0 ; i < runCount ; i ++ {
188- baseURL , err := wrr .Next ( )
234+ baseURL , err := wrr .NextWithContext ( ctx )
189235 assertNil (t , err )
190236 result = append (result , baseURL )
191237 }
@@ -205,9 +251,31 @@ func TestWeightedRoundRobin(t *testing.T) {
205251 assertNil (t , err )
206252 defer wrr .Close ()
207253
208- _ , err = wrr .Next ( )
254+ _ , err = wrr .NextWithContext ( context . Background () )
209255 assertErrorIs (t , ErrNoActiveHost , err )
210256 })
257+
258+ t .Run ("NextWithContext context cancellation" , func (t * testing.T ) {
259+ wrr , _ := NewWeightedRoundRobin (0 , & Host {BaseURL : "https://example.com" , Weight : 1 })
260+ ctx , cancel := context .WithCancel (context .Background ())
261+ cancel ()
262+ _ , err := wrr .NextWithContext (ctx )
263+ assertErrorIs (t , context .Canceled , err )
264+ })
265+
266+ t .Run ("NextWithContext normal operation" , func (t * testing.T ) {
267+ hosts := []* Host {
268+ {BaseURL : "https://example1.com" , Weight : 1 },
269+ {BaseURL : "https://example2.com" , Weight : 1 },
270+ }
271+ wrr , _ := NewWeightedRoundRobin (0 , hosts ... )
272+ ctx := context .Background ()
273+ url1 , err := wrr .NextWithContext (ctx )
274+ assertNil (t , err )
275+ url2 , err := wrr .NextWithContext (ctx )
276+ assertNil (t , err )
277+ assertNotEqual (t , url1 , url2 )
278+ })
211279}
212280
213281func TestSRVWeightedRoundRobin (t * testing.T ) {
@@ -233,8 +301,9 @@ func TestSRVWeightedRoundRobin(t *testing.T) {
233301
234302 runCount := 5
235303 var result []string
304+ ctx := context .Background ()
236305 for i := 0 ; i < runCount ; i ++ {
237- baseURL , err := srv .Next ( )
306+ baseURL , err := srv .NextWithContext ( ctx )
238307 assertNil (t , err )
239308 result = append (result , baseURL )
240309 }
@@ -271,8 +340,9 @@ func TestSRVWeightedRoundRobin(t *testing.T) {
271340
272341 runCount := 5
273342 var result []string
343+ ctx := context .Background ()
274344 for i := 0 ; i < runCount ; i ++ {
275- baseURL , err := srv .Next ( )
345+ baseURL , err := srv .NextWithContext ( ctx )
276346 assertNil (t , err )
277347 result = append (result , baseURL )
278348 }
@@ -315,8 +385,9 @@ func TestSRVWeightedRoundRobin(t *testing.T) {
315385
316386 runCount := 20
317387 var result []string
388+ ctx := context .Background ()
318389 for i := 0 ; i < runCount ; i ++ {
319- baseURL , err := srv .Next ( )
390+ baseURL , err := srv .NextWithContext ( ctx )
320391 assertNil (t , err )
321392 result = append (result , baseURL )
322393
@@ -363,7 +434,7 @@ func TestSRVWeightedRoundRobin(t *testing.T) {
363434
364435 go func () {
365436 for i := 0 ; i < 10 ; i ++ {
366- baseURL , _ := srv .Next ( )
437+ baseURL , _ := srv .NextWithContext ( context . Background () )
367438 assertNotNil (t , baseURL )
368439 time .Sleep (15 * time .Millisecond )
369440 }
@@ -438,7 +509,7 @@ func TestLoadBalancerRequestFlowError(t *testing.T) {
438509 c .SetLoadBalancer (wrr )
439510
440511 resp , err := c .R ().Get ("/" )
441- assertEqual (t , ErrNoActiveHost , err )
512+ assertErrorIs (t , ErrNoActiveHost , err )
442513 assertNil (t , resp )
443514 })
444515
0 commit comments