7
7
package integration
8
8
9
9
import (
10
- "bytes"
11
- "context"
12
10
"fmt"
13
- "io"
14
- "net"
15
11
"path"
16
12
"reflect"
17
13
"strings"
18
- "sync"
19
14
"testing"
20
15
"time"
21
16
@@ -29,7 +24,6 @@ import (
29
24
"go.mongodb.org/mongo-driver/mongo/readpref"
30
25
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
31
26
"go.mongodb.org/mongo-driver/x/mongo/driver"
32
- "go.mongodb.org/mongo-driver/x/mongo/driver/drivertest"
33
27
"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
34
28
)
35
29
@@ -381,37 +375,31 @@ func TestClient(t *testing.T) {
381
375
})
382
376
383
377
testAppName := "foo"
384
- hosts := options .Client ().ApplyURI (mt .ConnString ()).Hosts
385
- appNameProxyDialer := newProxyDialer ()
386
- appNameDialerOpts := options .Client ().
387
- SetDialer (appNameProxyDialer ).
388
- SetHosts (hosts [:1 ]).
389
- SetDirect (true ).
378
+ appNameClientOpts := options .Client ().
390
379
SetAppName (testAppName )
391
380
appNameMtOpts := mtest .NewOptions ().
392
- ClientOptions ( appNameDialerOpts ).
393
- Topologies ( mtest . Single ).
394
- Auth ( false ) // Can't run with auth because the proxy dialer won't work with TLS enabled.
381
+ ClientType ( mtest . Proxy ).
382
+ ClientOptions ( appNameClientOpts ).
383
+ Topologies ( mtest . Single )
395
384
mt .RunOpts ("app name is always sent" , appNameMtOpts , func (mt * mtest.T ) {
396
385
err := mt .Client .Ping (mtest .Background , mtest .PrimaryRp )
397
386
assert .Nil (mt , err , "Ping error: %v" , err )
398
387
399
- msgPairs := appNameProxyDialer . messages
388
+ msgPairs := mt . GetProxiedMessages ()
400
389
assert .True (mt , len (msgPairs ) >= 2 , "expected at least 2 events sent, got %v" , len (msgPairs ))
401
390
402
391
// First two messages should be connection handshakes: one for the heartbeat connection and the other for the
403
392
// application connection.
404
393
for idx , pair := range msgPairs [:2 ] {
405
- cmd , err := drivertest .GetCommandFromQueryWireMessage (pair .sent )
406
- assert .Nil (mt , err , "GetCommandFromQueryWireMessage error at index %d: %v" , idx , err )
407
- heartbeatCmdName := cmd .Index (0 ).Key ()
408
- assert .Equal (mt , "isMaster" , heartbeatCmdName ,
409
- "expected command name isMaster at index %d, got %v" , idx , heartbeatCmdName )
410
-
411
- appNameVal , err := cmd .LookupErr ("client" , "application" , "name" )
412
- assert .Nil (mt , err , "expected command %s at index %d to contain app name" , cmd , idx )
394
+ assert .Equal (mt , pair .CommandName , "isMaster" , "expected command name isMaster at index %d, got %s" , idx ,
395
+ pair .CommandName )
396
+
397
+ sent := pair .Sent
398
+ appNameVal , err := sent .Command .LookupErr ("client" , "application" , "name" )
399
+ assert .Nil (mt , err , "expected command %s at index %d to contain app name" , sent .Command , idx )
413
400
appName := appNameVal .StringValue ()
414
- assert .Equal (mt , testAppName , appName , "expected app name %v at index %d, got %v" , testAppName , idx , appName )
401
+ assert .Equal (mt , testAppName , appName , "expected app name %v at index %d, got %v" , testAppName , idx ,
402
+ appName )
415
403
}
416
404
})
417
405
@@ -446,101 +434,3 @@ type proxyMessage struct {
446
434
sent wiremessage.WireMessage
447
435
received wiremessage.WireMessage
448
436
}
449
-
450
- // proxyDialer is a ContextDialer implementation that wraps a net.Dialer and records the messages sent and received
451
- // using connections created through it.
452
- type proxyDialer struct {
453
- * net.Dialer
454
- sync.Mutex
455
- messages []proxyMessage
456
- sentMap sync.Map
457
- }
458
-
459
- var _ options.ContextDialer = (* proxyDialer )(nil )
460
-
461
- func newProxyDialer () * proxyDialer {
462
- return & proxyDialer {
463
- Dialer : & net.Dialer {Timeout : 30 * time .Second },
464
- }
465
- }
466
-
467
- // DialContext creates a new proxyConnection.
468
- func (p * proxyDialer ) DialContext (ctx context.Context , network , address string ) (net.Conn , error ) {
469
- netConn , err := p .Dialer .DialContext (ctx , network , address )
470
- if err != nil {
471
- return netConn , err
472
- }
473
-
474
- proxy := & proxyConn {
475
- Conn : netConn ,
476
- dialer : p ,
477
- currentReading : bytes .NewBuffer (nil ),
478
- }
479
- return proxy , nil
480
- }
481
-
482
- // storeSentMessage stores a copy of the wire message being sent to the server.
483
- func (p * proxyDialer ) storeSentMessage (msg []byte ) {
484
- p .Lock ()
485
- defer p .Unlock ()
486
-
487
- msgCopy := make (wiremessage.WireMessage , len (msg ))
488
- copy (msgCopy , msg )
489
-
490
- _ , requestID , _ , _ , _ , _ := wiremessage .ReadHeader (msgCopy )
491
- p .sentMap .Store (requestID , msgCopy )
492
- }
493
-
494
- // storeReceivedMessage stores a copy of the wire message being received from the server.
495
- func (p * proxyDialer ) storeReceivedMessage (msg []byte ) {
496
- p .Lock ()
497
- defer p .Unlock ()
498
-
499
- msgCopy := make (wiremessage.WireMessage , len (msg ))
500
- copy (msgCopy , msg )
501
-
502
- _ , _ , responseTo , _ , _ , _ := wiremessage .ReadHeader (msgCopy )
503
- sentMsg , _ := p .sentMap .Load (responseTo )
504
- p .sentMap .Delete (responseTo )
505
-
506
- proxyMsg := proxyMessage {
507
- sent : sentMsg .(wiremessage.WireMessage ),
508
- received : msgCopy ,
509
- }
510
- p .messages = append (p .messages , proxyMsg )
511
- }
512
-
513
- // proxyConn is a net.Conn that wraps a network connection. All messages sent/received through a proxyConn are stored
514
- // in the associated proxyDialer and are forwarded over the wrapped connection.
515
- type proxyConn struct {
516
- net.Conn
517
- dialer * proxyDialer
518
- currentReading * bytes.Buffer // The current message being read.
519
- }
520
-
521
- // Write stores the given message in the proxyDialer associated with this connection and forwards the message to the
522
- // server.
523
- func (pc * proxyConn ) Write (msg []byte ) (n int , err error ) {
524
- pc .dialer .storeSentMessage (msg )
525
- return pc .Conn .Write (msg )
526
- }
527
-
528
- // Read reads the message from the server into the given buffer and stores the read message in the proxyDialer
529
- // associated with this connection.
530
- func (pc * proxyConn ) Read (buffer []byte ) (int , error ) {
531
- n , err := pc .Conn .Read (buffer )
532
- if err != nil {
533
- return n , err
534
- }
535
-
536
- _ , err = io .Copy (pc .currentReading , bytes .NewReader (buffer ))
537
- if err != nil {
538
- return 0 , fmt .Errorf ("error copying to mock: %v" , err )
539
- }
540
- if len (buffer ) != 4 {
541
- pc .dialer .storeReceivedMessage (pc .currentReading .Bytes ())
542
- pc .currentReading .Reset ()
543
- }
544
-
545
- return n , err
546
- }
0 commit comments