@@ -2,6 +2,8 @@ package device
22
33import (
44 "bytes"
5+ "context"
6+ "errors"
57 "io/ioutil"
68 "net/http"
79 "net/url"
@@ -230,28 +232,16 @@ func TestRequestCode(t *testing.T) {
230232}
231233
232234func TestPollToken (t * testing.T ) {
233- var totalSlept time.Duration
234- mockSleep := func (d time.Duration ) {
235- totalSlept += d
236- }
237- duration := func (d string ) time.Duration {
238- res , _ := time .ParseDuration (d )
239- return res
240- }
241- clock := func (durations ... string ) func () time.Time {
242- count := 0
243- now := time .Now ()
244- return func () time.Time {
245- t := now .Add (duration (durations [count ]))
246- count ++
247- return t
235+ makeFakePoller := func (maxWaits int ) pollerFactory {
236+ return func (ctx context.Context , interval , expiresIn time.Duration ) (context.Context , poller ) {
237+ return ctx , & fakePoller {maxWaits : maxWaits }
248238 }
249239 }
250240
251241 type args struct {
252242 http apiClient
253243 url string
254- opts PollOptions
244+ opts WaitOptions
255245 }
256246 tests := []struct {
257247 name string
@@ -279,7 +269,7 @@ func TestPollToken(t *testing.T) {
279269 },
280270 },
281271 url : "https://github.com/oauth" ,
282- opts : PollOptions {
272+ opts : WaitOptions {
283273 ClientID : "CLIENT-ID" ,
284274 DeviceCode : & CodeResponse {
285275 DeviceCode : "DEVIC" ,
@@ -288,14 +278,12 @@ func TestPollToken(t *testing.T) {
288278 ExpiresIn : 99 ,
289279 Interval : 5 ,
290280 },
291- timeSleep : mockSleep ,
292- timeNow : clock ("0" , "5s" , "10s" ),
281+ newPoller : makeFakePoller (2 ),
293282 },
294283 },
295284 want : & api.AccessToken {
296285 Token : "123abc" ,
297286 },
298- slept : duration ("10s" ),
299287 posts : []postArgs {
300288 {
301289 url : "https://github.com/oauth" ,
@@ -328,7 +316,7 @@ func TestPollToken(t *testing.T) {
328316 },
329317 },
330318 url : "https://github.com/oauth" ,
331- opts : PollOptions {
319+ opts : WaitOptions {
332320 ClientID : "CLIENT-ID" ,
333321 ClientSecret : "SEKRIT" ,
334322 GrantType : "device_code" ,
@@ -339,14 +327,12 @@ func TestPollToken(t *testing.T) {
339327 ExpiresIn : 99 ,
340328 Interval : 5 ,
341329 },
342- timeSleep : mockSleep ,
343- timeNow : clock ("0" , "5s" , "10s" ),
330+ newPoller : makeFakePoller (1 ),
344331 },
345332 },
346333 want : & api.AccessToken {
347334 Token : "123abc" ,
348335 },
349- slept : duration ("5s" ),
350336 posts : []postArgs {
351337 {
352338 url : "https://github.com/oauth" ,
@@ -377,21 +363,19 @@ func TestPollToken(t *testing.T) {
377363 },
378364 },
379365 url : "https://github.com/oauth" ,
380- opts : PollOptions {
366+ opts : WaitOptions {
381367 ClientID : "CLIENT-ID" ,
382368 DeviceCode : & CodeResponse {
383369 DeviceCode : "DEVIC" ,
384370 UserCode : "123-abc" ,
385371 VerificationURI : "http://verify.me" ,
386- ExpiresIn : 99 ,
372+ ExpiresIn : 14 ,
387373 Interval : 5 ,
388374 },
389- timeSleep : mockSleep ,
390- timeNow : clock ("0" , "5s" , "15m" ),
375+ newPoller : makeFakePoller (2 ),
391376 },
392377 },
393- wantErr : "authentication timed out" ,
394- slept : duration ("10s" ),
378+ wantErr : "context deadline exceeded" ,
395379 posts : []postArgs {
396380 {
397381 url : "https://github.com/oauth" ,
@@ -424,7 +408,7 @@ func TestPollToken(t *testing.T) {
424408 },
425409 },
426410 url : "https://github.com/oauth" ,
427- opts : PollOptions {
411+ opts : WaitOptions {
428412 ClientID : "CLIENT-ID" ,
429413 DeviceCode : & CodeResponse {
430414 DeviceCode : "DEVIC" ,
@@ -433,12 +417,10 @@ func TestPollToken(t *testing.T) {
433417 ExpiresIn : 99 ,
434418 Interval : 5 ,
435419 },
436- timeSleep : mockSleep ,
437- timeNow : clock ("0" , "5s" ),
420+ newPoller : makeFakePoller (1 ),
438421 },
439422 },
440423 wantErr : "access_denied" ,
441- slept : duration ("5s" ),
442424 posts : []postArgs {
443425 {
444426 url : "https://github.com/oauth" ,
@@ -453,8 +435,7 @@ func TestPollToken(t *testing.T) {
453435 }
454436 for _ , tt := range tests {
455437 t .Run (tt .name , func (t * testing.T ) {
456- totalSlept = 0
457- got , err := PollTokenWithOptions (& tt .args .http , tt .args .url , tt .args .opts )
438+ got , err := Wait (context .Background (), & tt .args .http , tt .args .url , tt .args .opts )
458439 if (err != nil ) != (tt .wantErr != "" ) {
459440 t .Errorf ("PollToken() error = %v, wantErr %v" , err , tt .wantErr )
460441 return
@@ -468,9 +449,22 @@ func TestPollToken(t *testing.T) {
468449 if ! reflect .DeepEqual (tt .args .http .calls , tt .posts ) {
469450 t .Errorf ("PostForm() = %v, want %v" , tt .args .http .calls , tt .posts )
470451 }
471- if totalSlept != tt .slept {
472- t .Errorf ("slept %v, wanted %v" , totalSlept , tt .slept )
473- }
474452 })
475453 }
476454}
455+
456+ type fakePoller struct {
457+ maxWaits int
458+ count int
459+ }
460+
461+ func (p * fakePoller ) Wait () error {
462+ if p .count == p .maxWaits {
463+ return errors .New ("context deadline exceeded" )
464+ }
465+ p .count ++
466+ return nil
467+ }
468+
469+ func (p * fakePoller ) Cancel () {
470+ }
0 commit comments