@@ -412,6 +412,39 @@ func errorFromResult(t testing.TB, result interface{}) *operationError {
412
412
return & expected
413
413
}
414
414
415
+ // errorDetails is a helper type that holds information that can be returned by driver functions in different error
416
+ // types.
417
+ type errorDetails struct {
418
+ name string
419
+ labels []string
420
+ }
421
+
422
+ // extractErrorDetails creates an errorDetails instance based on the provided error. It returns the details and an "ok"
423
+ // value which is true if the provided error is of a known type that can be processed.
424
+ func extractErrorDetails (err error ) (errorDetails , bool ) {
425
+ var details errorDetails
426
+
427
+ switch converted := err .(type ) {
428
+ case mongo.CommandError :
429
+ details .name = converted .Name
430
+ details .labels = converted .Labels
431
+ case mongo.WriteException :
432
+ if converted .WriteConcernError != nil {
433
+ details .name = converted .WriteConcernError .Name
434
+ }
435
+ details .labels = converted .Labels
436
+ case mongo.BulkWriteException :
437
+ if converted .WriteConcernError != nil {
438
+ details .name = converted .WriteConcernError .Name
439
+ }
440
+ details .labels = converted .Labels
441
+ default :
442
+ return errorDetails {}, false
443
+ }
444
+
445
+ return details , true
446
+ }
447
+
415
448
// verify that an error returned by an operation matches the expected error.
416
449
func verifyError (expected * operationError , actual error ) error {
417
450
// The spec test format doesn't treat ErrNoDocuments or ErrUnacknowledgedWrite as errors, so set actual to nil
@@ -439,23 +472,28 @@ func verifyError(expected *operationError, actual error) error {
439
472
}
440
473
}
441
474
442
- cerr , ok := actual .(mongo.CommandError )
475
+ // Get an errorDetails instance for the error. If this fails but the test has expectations about the error name or
476
+ // labels, fail because we can't verify them.
477
+ details , ok := extractErrorDetails (actual )
443
478
if ! ok {
479
+ if expected .ErrorCodeName != nil || len (expected .ErrorLabelsContain ) > 0 || len (expected .ErrorLabelsOmit ) > 0 {
480
+ return fmt .Errorf ("failed to extract details from error %v of type %T" , actual , actual )
481
+ }
444
482
return nil
445
483
}
446
484
447
485
if expected .ErrorCodeName != nil {
448
- if * expected .ErrorCodeName != cerr . Name {
449
- return fmt .Errorf ("expected error name %v, got %v" , * expected .ErrorCodeName , cerr . Name )
486
+ if * expected .ErrorCodeName != details . name {
487
+ return fmt .Errorf ("expected error name %v, got %v" , * expected .ErrorCodeName , details . name )
450
488
}
451
489
}
452
490
for _ , label := range expected .ErrorLabelsContain {
453
- if ! cerr . HasErrorLabel ( label ) {
491
+ if ! stringSliceContains ( details . labels , label ) {
454
492
return fmt .Errorf ("expected error %v to contain label %q" , actual , label )
455
493
}
456
494
}
457
495
for _ , label := range expected .ErrorLabelsOmit {
458
- if cerr . HasErrorLabel ( label ) {
496
+ if stringSliceContains ( details . labels , label ) {
459
497
return fmt .Errorf ("expected error %v to not contain label %q" , actual , label )
460
498
}
461
499
}
@@ -552,3 +590,12 @@ func convertValueToMilliseconds(t testing.TB, val bson.RawValue) time.Duration {
552
590
}
553
591
return time .Duration (int32Val ) * time .Millisecond
554
592
}
593
+
594
+ func stringSliceContains (stringSlice []string , target string ) bool {
595
+ for _ , str := range stringSlice {
596
+ if str == target {
597
+ return true
598
+ }
599
+ }
600
+ return false
601
+ }
0 commit comments