@@ -10,6 +10,7 @@ import (
10
10
"github.com/google/go-cmp/cmp"
11
11
"go.mongodb.org/mongo-driver/bson/bsontype"
12
12
"go.mongodb.org/mongo-driver/bson/primitive"
13
+ "go.mongodb.org/mongo-driver/internal/testutil/assert"
13
14
"go.mongodb.org/mongo-driver/mongo/readconcern"
14
15
"go.mongodb.org/mongo-driver/mongo/readpref"
15
16
"go.mongodb.org/mongo-driver/mongo/writeconcern"
@@ -518,6 +519,93 @@ func TestOperation(t *testing.T) {
518
519
})
519
520
}
520
521
})
522
+ t .Run ("ExecuteExhaust" , func (t * testing.T ) {
523
+ t .Run ("errors if connection is not streaming" , func (t * testing.T ) {
524
+ conn := & mockConnection {
525
+ rStreaming : false ,
526
+ }
527
+ err := Operation {}.ExecuteExhaust (context .TODO (), conn , nil )
528
+ assert .NotNil (t , err , "expected error, got nil" )
529
+ })
530
+ })
531
+ t .Run ("exhaustAllowed and moreToCome" , func (t * testing.T ) {
532
+ // Test the interaction between exhaustAllowed and moreToCome on requests/responses when using the Execute
533
+ // and ExecuteExhaust methods.
534
+
535
+ // Create a server response wire message that has moreToCome=false.
536
+ serverResponseDoc := bsoncore .BuildDocumentFromElements (nil ,
537
+ bsoncore .AppendInt32Element (nil , "ok" , 1 ),
538
+ )
539
+ nonStreamingResponse := createExhaustServerResponse (t , serverResponseDoc , false )
540
+
541
+ // Create a connection that reports that it cannot stream messages.
542
+ conn := & mockConnection {
543
+ rDesc : description.Server {
544
+ WireVersion : & description.VersionRange {
545
+ Max : 6 ,
546
+ },
547
+ },
548
+ rReadWM : nonStreamingResponse ,
549
+ rCanStream : false ,
550
+ }
551
+ op := Operation {
552
+ CommandFn : func (dst []byte , desc description.SelectedServer ) ([]byte , error ) {
553
+ return bsoncore .AppendInt32Element (dst , "isMaster" , 1 ), nil
554
+ },
555
+ Database : "admin" ,
556
+ Deployment : SingleConnectionDeployment {conn },
557
+ }
558
+ err := op .Execute (context .TODO (), nil )
559
+ assert .Nil (t , err , "Execute error: %v" , err )
560
+
561
+ // The wire message sent to the server should not have exhaustAllowed=true. After execution, the connection
562
+ // should not be in a streaming state.
563
+ assertExhaustAllowedSet (t , conn .pWriteWM , false )
564
+ assert .False (t , conn .CurrentlyStreaming (), "expected CurrentlyStreaming to be false" )
565
+
566
+ // Modify the connection to report that it can stream and create a new server response with moreToCome=true.
567
+ streamingResponse := createExhaustServerResponse (t , serverResponseDoc , true )
568
+ conn .rReadWM = streamingResponse
569
+ conn .rCanStream = true
570
+ err = op .Execute (context .TODO (), nil )
571
+ assert .Nil (t , err , "Execute error: %v" , err )
572
+ assertExhaustAllowedSet (t , conn .pWriteWM , true )
573
+ assert .True (t , conn .CurrentlyStreaming (), "expected CurrentlyStreaming to be true" )
574
+
575
+ // Reset the server response and go through ExecuteExhaust to mimic streaming the next response. After
576
+ // execution, the connection should still be in a streaming state.
577
+ conn .rReadWM = streamingResponse
578
+ err = op .ExecuteExhaust (context .TODO (), conn , nil )
579
+ assert .Nil (t , err , "ExecuteExhaust error: %v" , err )
580
+ assert .True (t , conn .CurrentlyStreaming (), "expected CurrentlyStreaming to be true" )
581
+ })
582
+ }
583
+
584
+ func createExhaustServerResponse (t * testing.T , response bsoncore.Document , moreToCome bool ) []byte {
585
+ idx , wm := wiremessage .AppendHeaderStart (nil , 0 , wiremessage .CurrentRequestID ()+ 1 , wiremessage .OpMsg )
586
+ var flags wiremessage.MsgFlag
587
+ if moreToCome {
588
+ flags = wiremessage .MoreToCome
589
+ }
590
+ wm = wiremessage .AppendMsgFlags (wm , flags )
591
+ wm = wiremessage .AppendMsgSectionType (wm , wiremessage .SingleDocument )
592
+ wm = bsoncore .AppendDocument (wm , response )
593
+ return bsoncore .UpdateLength (wm , idx , int32 (len (wm )))
594
+ }
595
+
596
+ func assertExhaustAllowedSet (t * testing.T , wm []byte , expected bool ) {
597
+ t .Helper ()
598
+ _ , _ , _ , _ , wm , ok := wiremessage .ReadHeader (wm )
599
+ if ! ok {
600
+ t .Fatal ("could not read wm header" )
601
+ }
602
+ flags , wm , ok := wiremessage .ReadMsgFlags (wm )
603
+ if ! ok {
604
+ t .Fatal ("could not read wm flags" )
605
+ }
606
+
607
+ actual := flags & wiremessage .ExhaustAllowed > 0
608
+ assert .Equal (t , expected , actual , "expected exhaustAllowed set %v, got %v" , expected , actual )
521
609
}
522
610
523
611
type mockDeployment struct {
@@ -554,19 +642,24 @@ type mockConnection struct {
554
642
pReadDst []byte
555
643
556
644
// returns
557
- rWriteErr error
558
- rReadWM []byte
559
- rReadErr error
560
- rDesc description.Server
561
- rCloseErr error
562
- rID string
563
- rAddr address.Address
645
+ rWriteErr error
646
+ rReadWM []byte
647
+ rReadErr error
648
+ rDesc description.Server
649
+ rCloseErr error
650
+ rID string
651
+ rAddr address.Address
652
+ rCanStream bool
653
+ rStreaming bool
564
654
}
565
655
566
656
func (m * mockConnection ) Description () description.Server { return m .rDesc }
567
657
func (m * mockConnection ) Close () error { return m .rCloseErr }
568
658
func (m * mockConnection ) ID () string { return m .rID }
569
659
func (m * mockConnection ) Address () address.Address { return m .rAddr }
660
+ func (m * mockConnection ) SupportsStreaming () bool { return m .rCanStream }
661
+ func (m * mockConnection ) CurrentlyStreaming () bool { return m .rStreaming }
662
+ func (m * mockConnection ) SetStreaming (streaming bool ) { m .rStreaming = streaming }
570
663
571
664
func (m * mockConnection ) WriteWireMessage (_ context.Context , wm []byte ) error {
572
665
m .pWriteWM = wm
0 commit comments