7
7
package integration
8
8
9
9
import (
10
+ "bytes"
11
+ "context"
10
12
"fmt"
13
+ "io"
14
+ "net"
11
15
"path"
12
16
"reflect"
13
17
"strings"
18
+ "sync"
14
19
"testing"
15
20
"time"
16
21
@@ -23,6 +28,8 @@ import (
23
28
"go.mongodb.org/mongo-driver/mongo/options"
24
29
"go.mongodb.org/mongo-driver/mongo/readpref"
25
30
"go.mongodb.org/mongo-driver/x/mongo/driver"
31
+ "go.mongodb.org/mongo-driver/x/mongo/driver/drivertest"
32
+ "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
26
33
)
27
34
28
35
const certificatesDir = "../../data/certificates"
@@ -333,4 +340,125 @@ func TestClient(t *testing.T) {
333
340
})
334
341
}
335
342
})
343
+
344
+ testAppName := "foo"
345
+ hosts := options .Client ().ApplyURI (mt .ConnString ()).Hosts
346
+ appNameProxyDialer := newProxyDialer ()
347
+ appNameDialerOpts := options .Client ().
348
+ SetDialer (appNameProxyDialer ).
349
+ SetHosts (hosts [:1 ]).
350
+ SetDirect (true ).
351
+ SetAppName (testAppName )
352
+ appNameMtOpts := mtest .NewOptions ().
353
+ ClientOptions (appNameDialerOpts ).
354
+ Topologies (mtest .Single )
355
+ mt .RunOpts ("app name is always sent" , appNameMtOpts , func (mt * mtest.T ) {
356
+ err := mt .Client .Ping (mtest .Background , mtest .PrimaryRp )
357
+ assert .Nil (mt , err , "Ping error: %v" , err )
358
+
359
+ sent := appNameProxyDialer .sent
360
+ assert .True (mt , len (sent ) >= 2 , "expected at least 2 events sent, got %v" , len (sent ))
361
+
362
+ // First two messages should be connection handshakes: one for the heartbeat connection and the other for the
363
+ // application connection.
364
+ for idx , wm := range sent [:2 ] {
365
+ cmd , err := drivertest .GetCommandFromQueryWireMessage (wm )
366
+ assert .Nil (mt , err , "GetCommandFromQueryWireMessage error at index %d: %v" , idx , err )
367
+ heartbeatCmdName := cmd .Index (0 ).Key ()
368
+ assert .Equal (mt , "isMaster" , heartbeatCmdName ,
369
+ "expected command name isMaster at index %d, got %v" , idx , heartbeatCmdName )
370
+
371
+ appNameVal , err := cmd .LookupErr ("client" , "application" , "name" )
372
+ assert .Nil (mt , err , "expected command %s at index %d to contain app name" , cmd , idx )
373
+ appName := appNameVal .StringValue ()
374
+ assert .Equal (mt , testAppName , appName , "expected app name %v at index %d, got %v" , testAppName , idx , appName )
375
+ }
376
+ })
377
+ }
378
+
379
+ // proxyDialer is a ContextDialer implementation that wraps a net.Dialer and records the messages sent and received
380
+ // using connections created through it.
381
+ type proxyDialer struct {
382
+ * net.Dialer
383
+ sync.Mutex
384
+ sent []wiremessage.WireMessage
385
+ received []wiremessage.WireMessage
386
+ }
387
+
388
+ var _ options.ContextDialer = (* proxyDialer )(nil )
389
+
390
+ func newProxyDialer () * proxyDialer {
391
+ return & proxyDialer {
392
+ Dialer : & net.Dialer {Timeout : 30 * time .Second },
393
+ }
394
+ }
395
+
396
+ // DialContext creates a new proxyConnection.
397
+ func (p * proxyDialer ) DialContext (ctx context.Context , network , address string ) (net.Conn , error ) {
398
+ netConn , err := p .Dialer .DialContext (ctx , network , address )
399
+ if err != nil {
400
+ return netConn , err
401
+ }
402
+
403
+ proxy := & proxyConn {
404
+ Conn : netConn ,
405
+ dialer : p ,
406
+ currentReading : bytes .NewBuffer (nil ),
407
+ }
408
+ return proxy , nil
409
+ }
410
+
411
+ // storeSentMessage stores a copy of the wire message being sent to the server.
412
+ func (p * proxyDialer ) storeSentMessage (msg []byte ) {
413
+ p .Lock ()
414
+ defer p .Unlock ()
415
+
416
+ msgCopy := make (wiremessage.WireMessage , len (msg ))
417
+ copy (msgCopy , msg )
418
+ p .sent = append (p .sent , msgCopy )
419
+ }
420
+
421
+ // storeReceivedMessage stores a copy of the wire message being received from the server.
422
+ func (p * proxyDialer ) storeReceivedMessage (msg []byte ) {
423
+ p .Lock ()
424
+ defer p .Unlock ()
425
+
426
+ msgCopy := make (wiremessage.WireMessage , len (msg ))
427
+ copy (msgCopy , msg )
428
+ p .received = append (p .received , msgCopy )
429
+ }
430
+
431
+ // proxyConn is a net.Conn that wraps a network connection. All messages sent/received through a proxyConn are stored
432
+ // in the associated proxyDialer and are forwarded over the wrapped connection.
433
+ type proxyConn struct {
434
+ net.Conn
435
+ dialer * proxyDialer
436
+ currentReading * bytes.Buffer // The current message being read.
437
+ }
438
+
439
+ // Write stores the given message in the proxyDialer associated with this connection and forwards the message to the
440
+ // server.
441
+ func (pc * proxyConn ) Write (msg []byte ) (n int , err error ) {
442
+ pc .dialer .storeSentMessage (msg )
443
+ return pc .Conn .Write (msg )
444
+ }
445
+
446
+ // Read reads the message from the server into the given buffer and stores the read message in the proxyDialer
447
+ // associated with this connection.
448
+ func (pc * proxyConn ) Read (buffer []byte ) (int , error ) {
449
+ n , err := pc .Conn .Read (buffer )
450
+ if err != nil {
451
+ return n , err
452
+ }
453
+
454
+ _ , err = io .Copy (pc .currentReading , bytes .NewReader (buffer ))
455
+ if err != nil {
456
+ return 0 , fmt .Errorf ("error copying to mock: %v" , err )
457
+ }
458
+ if len (buffer ) != 4 {
459
+ pc .dialer .storeReceivedMessage (pc .currentReading .Bytes ())
460
+ pc .currentReading .Reset ()
461
+ }
462
+
463
+ return n , err
336
464
}
0 commit comments