@@ -71,7 +71,7 @@ func TestClient(t *testing.T) {
71
71
72
72
assert .Equal (mt , int64 (- 10 ), got .ID , "expected ID -10, got %v" , got .ID )
73
73
})
74
- mt .RunOpts ("tls connection" , mtest .NewOptions ().MinServerVersion ("3.0" ).Auth (true ), func (mt * mtest.T ) {
74
+ mt .RunOpts ("tls connection" , mtest .NewOptions ().MinServerVersion ("3.0" ).SSL (true ), func (mt * mtest.T ) {
75
75
var result bson.Raw
76
76
err := mt .Coll .Database ().RunCommand (mtest .Background , bson.D {
77
77
{"serverStatus" , 1 },
@@ -86,7 +86,7 @@ func TestClient(t *testing.T) {
86
86
_ , found = security .Document ().LookupErr ("SSLServerHasCertificateAuthority" )
87
87
assert .Nil (mt , found , "SSLServerHasCertificateAuthority not found in result" )
88
88
})
89
- mt .RunOpts ("x509" , mtest .NewOptions ().Auth (true ), func (mt * mtest.T ) {
89
+ mt .RunOpts ("x509" , mtest .NewOptions ().Auth (true ). SSL ( true ) , func (mt * mtest.T ) {
90
90
const user = "C=US,ST=New York,L=New York City,O=MongoDB,OU=other,CN=external"
91
91
db := mt .Client .Database ("$external" )
92
92
@@ -396,13 +396,13 @@ func TestClient(t *testing.T) {
396
396
err := mt .Client .Ping (mtest .Background , mtest .PrimaryRp )
397
397
assert .Nil (mt , err , "Ping error: %v" , err )
398
398
399
- sent := appNameProxyDialer .sent
400
- assert .True (mt , len (sent ) >= 2 , "expected at least 2 events sent, got %v" , len (sent ))
399
+ msgPairs := appNameProxyDialer .messages
400
+ assert .True (mt , len (msgPairs ) >= 2 , "expected at least 2 events sent, got %v" , len (msgPairs ))
401
401
402
402
// First two messages should be connection handshakes: one for the heartbeat connection and the other for the
403
403
// application connection.
404
- for idx , wm := range sent [:2 ] {
405
- cmd , err := drivertest .GetCommandFromQueryWireMessage (wm )
404
+ for idx , pair := range msgPairs [:2 ] {
405
+ cmd , err := drivertest .GetCommandFromQueryWireMessage (pair . sent )
406
406
assert .Nil (mt , err , "GetCommandFromQueryWireMessage error at index %d: %v" , idx , err )
407
407
heartbeatCmdName := cmd .Index (0 ).Key ()
408
408
assert .Equal (mt , "isMaster" , heartbeatCmdName ,
@@ -441,13 +441,19 @@ func TestClient(t *testing.T) {
441
441
})
442
442
}
443
443
444
+ type proxyMessage struct {
445
+ serverAddress string
446
+ sent wiremessage.WireMessage
447
+ received wiremessage.WireMessage
448
+ }
449
+
444
450
// proxyDialer is a ContextDialer implementation that wraps a net.Dialer and records the messages sent and received
445
451
// using connections created through it.
446
452
type proxyDialer struct {
447
453
* net.Dialer
448
454
sync.Mutex
449
- sent []wiremessage. WireMessage
450
- received []wiremessage. WireMessage
455
+ messages [] proxyMessage
456
+ sentMap sync. Map
451
457
}
452
458
453
459
var _ options.ContextDialer = (* proxyDialer )(nil )
@@ -480,7 +486,9 @@ func (p *proxyDialer) storeSentMessage(msg []byte) {
480
486
481
487
msgCopy := make (wiremessage.WireMessage , len (msg ))
482
488
copy (msgCopy , msg )
483
- p .sent = append (p .sent , msgCopy )
489
+
490
+ _ , requestID , _ , _ , _ , _ := wiremessage .ReadHeader (msgCopy )
491
+ p .sentMap .Store (requestID , msgCopy )
484
492
}
485
493
486
494
// storeReceivedMessage stores a copy of the wire message being received from the server.
@@ -490,7 +498,16 @@ func (p *proxyDialer) storeReceivedMessage(msg []byte) {
490
498
491
499
msgCopy := make (wiremessage.WireMessage , len (msg ))
492
500
copy (msgCopy , msg )
493
- p .received = append (p .received , msgCopy )
501
+
502
+ _ , _ , responseTo , _ , _ , _ := wiremessage .ReadHeader (msgCopy )
503
+ sentMsg , _ := p .sentMap .Load (responseTo )
504
+ p .sentMap .Delete (responseTo )
505
+
506
+ proxyMsg := proxyMessage {
507
+ sent : sentMsg .(wiremessage.WireMessage ),
508
+ received : msgCopy ,
509
+ }
510
+ p .messages = append (p .messages , proxyMsg )
494
511
}
495
512
496
513
// proxyConn is a net.Conn that wraps a network connection. All messages sent/received through a proxyConn are stored
0 commit comments