@@ -454,7 +454,7 @@ func TestErrorCleanUpBeforeFirstMessage(t *testing.T) {
454454 waitForLength (t , testClient .MessageLen , messageTypeInsert , 0 )
455455
456456 close (ch )
457- requireErrorCount (t , 1 , errCh )
457+ requireErrorCount (t , errCh , 1 , 1 )
458458}
459459
460460func TestErrorCleanUpFirstMessage (t * testing.T ) {
@@ -494,7 +494,7 @@ func TestErrorCleanUpFirstMessage(t *testing.T) {
494494 waitForLength (t , testClient .InflightLen , messageTypeInsert , 1 )
495495
496496 close (ch )
497- requireErrorCount (t , 1 , errCh )
497+ requireErrorCount (t , errCh , 1 , 1 )
498498}
499499
500500func TestErrorCleanUpSecondMessage (t * testing.T ) {
@@ -532,9 +532,9 @@ func TestErrorCleanUpSecondMessage(t *testing.T) {
532532 <- done
533533
534534 close (ch )
535- requireErrorCount (t , 1 , errCh )
535+ numErrs := requireErrorCount (t , errCh , 1 , 2 ) // can have 2 errors depending on processing order
536536
537- waitForLength (t , testClient .InflightLen , messageTypeInsert , 2 ) // testStreamingBatchClient doesn't commit the batch before erroring
537+ waitForLength (t , testClient .InflightLen , messageTypeInsert , 1 + numErrs ) // testStreamingBatchClient doesn't commit the batch before erroring
538538 waitForLength (t , testClient .MessageLen , messageTypeInsert , 0 )
539539}
540540
@@ -568,7 +568,7 @@ func TestErrorCleanUpAfterClose(t *testing.T) {
568568 waitForLength (t , testClient .InflightLen , messageTypeInsert , 10 )
569569 close (ch )
570570
571- requireErrorCount (t , 1 , errCh )
571+ requireErrorCount (t , errCh , 1 , 1 )
572572
573573 waitForLength (t , testClient .MessageLen , messageTypeInsert , 0 ) // batch size 1
574574}
@@ -602,7 +602,7 @@ func getRecord(sc *arrow.Schema, rows int) arrow.Record {
602602}
603603
604604// nolint:unparam
605- func requireErrorCount (t * testing.T , expected int , errCh chan error ) {
605+ func requireErrorCount (t * testing.T , errCh chan error , expectedMin , expectedMax int ) int {
606606 t .Helper ()
607607 select {
608608 case <- time .After (5 * time .Second ):
@@ -614,8 +614,13 @@ func requireErrorCount(t *testing.T, expected int, errCh chan error) {
614614 }
615615
616616 errs := jointErrs .Unwrap ()
617- if l := len (errs ); l != expected {
618- t .Fatalf ("expected %d errors, got %d: %v" , expected , l , errs )
617+ l := len (errs )
618+ if expectedMin == expectedMax && l != expectedMin {
619+ t .Fatalf ("expected %d errors, got %d: %v" , expectedMin , l , errs )
620+ } else if l < expectedMin || l > expectedMax {
621+ t .Fatalf ("expected between %d and %d errors, got %d: %v" , expectedMin , expectedMax , l , errs )
619622 }
623+ return l
620624 }
625+ return - 1
621626}
0 commit comments