@@ -400,20 +400,49 @@ func TestSimpleRetryPolicy(t *testing.T) {
400400 // this should allow a total of 3 tries.
401401 rt := & SimpleRetryPolicy {NumRetries : 2 }
402402
403+ regular_error := errors .New ("regular error" )
404+
405+ qe1 := & QueryError {
406+ err : errors .New ("connection error" ),
407+ potentiallyExecuted : false ,
408+ isIdempotent : false ,
409+ }
410+
411+ qe2 := & QueryError {
412+ err : errors .New ("timeout error" ),
413+ potentiallyExecuted : true ,
414+ isIdempotent : true ,
415+ }
416+
417+ qe3 := & QueryError {
418+ err : errors .New ("write timeout" ),
419+ potentiallyExecuted : true ,
420+ isIdempotent : false ,
421+ }
422+
403423 cases := []struct {
404- attempts int
405- allow bool
424+ attempts int
425+ allow bool
426+ err error
427+ retryType RetryType
428+ LWTRetryType RetryType
406429 }{
407- {0 , true },
408- {1 , true },
409- {2 , true },
410- {3 , false },
411- {4 , false },
412- {5 , false },
430+ {0 , true , qe1 , RetryNextHost , Retry },
431+ {1 , true , qe2 , RetryNextHost , Retry },
432+ {2 , true , qe3 , Rethrow , Rethrow },
433+ {3 , false , regular_error , RetryNextHost , Retry },
434+ {4 , false , regular_error , RetryNextHost , Retry },
435+ {5 , false , regular_error , RetryNextHost , Retry },
413436 }
414437
415438 for _ , c := range cases {
416439 q .metrics = preFilledQueryMetrics (map [string ]* hostMetrics {"127.0.0.1" : {Attempts : c .attempts }})
440+ if c .retryType != rt .GetRetryType (c .err ) {
441+ t .Fatalf ("retry type for %v should be %v" , c .err , c .retryType )
442+ }
443+ if c .LWTRetryType != rt .GetRetryTypeLWT (c .err ) {
444+ t .Fatalf ("LWT retry type for %v should be %v" , c .err , c .LWTRetryType )
445+ }
417446 if c .allow && ! rt .Attempt (q ) {
418447 t .Fatalf ("should allow retry after %d attempts" , c .attempts )
419448 }
@@ -439,17 +468,45 @@ func TestExponentialBackoffPolicy(t *testing.T) {
439468 // test with defaults
440469 sut := & ExponentialBackoffRetryPolicy {NumRetries : 2 }
441470
471+ regular_error := errors .New ("regular error" )
472+
473+ qe1 := & QueryError {
474+ err : errors .New ("connection error" ),
475+ potentiallyExecuted : false ,
476+ isIdempotent : false ,
477+ }
478+
479+ qe2 := & QueryError {
480+ err : errors .New ("timeout error" ),
481+ potentiallyExecuted : true ,
482+ isIdempotent : true ,
483+ }
484+
485+ qe3 := & QueryError {
486+ err : errors .New ("write timeout" ),
487+ potentiallyExecuted : true ,
488+ isIdempotent : false ,
489+ }
490+
442491 cases := []struct {
443- attempts int
444- delay time.Duration
492+ attempts int
493+ delay time.Duration
494+ err error
495+ retryType RetryType
496+ LWTRetryType RetryType
445497 }{
446-
447- {1 , 100 * time .Millisecond },
448- {2 , (2 ) * 100 * time .Millisecond },
449- {3 , (2 * 2 ) * 100 * time .Millisecond },
450- {4 , (2 * 2 * 2 ) * 100 * time .Millisecond },
498+ {1 , 100 * time .Millisecond , qe1 , RetryNextHost , Retry },
499+ {2 , (2 ) * 100 * time .Millisecond , qe2 , RetryNextHost , Retry },
500+ {3 , (2 * 2 ) * 100 * time .Millisecond , qe3 , Rethrow , Rethrow },
501+ {4 , (2 * 2 * 2 ) * 100 * time .Millisecond , regular_error , RetryNextHost , Retry },
451502 }
452503 for _ , c := range cases {
504+ if c .retryType != sut .GetRetryType (c .err ) {
505+ t .Fatalf ("retry type for %v should be %v" , c .err , c .retryType )
506+ }
507+ if c .LWTRetryType != sut .GetRetryTypeLWT (c .err ) {
508+ t .Fatalf ("LWT retry type for %v should be %v" , c .err , c .LWTRetryType )
509+ }
453510 // test 100 times for each case
454511 for i := 0 ; i < 100 ; i ++ {
455512 d := sut .napTime (c .attempts )
0 commit comments