@@ -56,40 +56,59 @@ type HandshakeOptions struct {
56
56
PerformAuthentication func (description.Server ) bool
57
57
}
58
58
59
- // Handshaker creates a connection handshaker for the given authenticator.
60
- func Handshaker (h driver.Handshaker , options * HandshakeOptions ) driver.Handshaker {
61
- return driver .HandshakerFunc (func (ctx context.Context , addr address.Address , conn driver.Connection ) (description.Server , error ) {
62
- desc , err := operation .NewIsMaster ().
63
- AppName (options .AppName ).
64
- Compressors (options .Compressors ).
65
- SASLSupportedMechs (options .DBUser ).
66
- Handshake (ctx , addr , conn )
59
+ type authHandshaker struct {
60
+ wrapped driver.Handshaker
61
+ options * HandshakeOptions
62
+ }
67
63
68
- if err != nil {
69
- return description.Server {}, newAuthError ("handshake failure" , err )
70
- }
64
+ // GetDescription performs an isMaster to retrieve the initial description for conn.
65
+ func (ah * authHandshaker ) GetDescription (ctx context.Context , addr address.Address , conn driver.Connection ) (description.Server , error ) {
66
+ if ah .wrapped != nil {
67
+ return ah .wrapped .GetDescription (ctx , addr , conn )
68
+ }
71
69
72
- performAuth := options .PerformAuthentication
73
- if performAuth == nil {
74
- performAuth = func (serv description.Server ) bool {
75
- return serv .Kind == description .RSPrimary ||
76
- serv .Kind == description .RSSecondary ||
77
- serv .Kind == description .Mongos ||
78
- serv .Kind == description .Standalone
79
- }
80
- }
81
- if performAuth (desc ) && options .Authenticator != nil {
82
- err = options .Authenticator .Auth (ctx , desc , conn )
83
- if err != nil {
84
- return description.Server {}, newAuthError ("auth error" , err )
85
- }
70
+ desc , err := operation .NewIsMaster ().
71
+ AppName (ah .options .AppName ).
72
+ Compressors (ah .options .Compressors ).
73
+ SASLSupportedMechs (ah .options .DBUser ).
74
+ GetDescription (ctx , addr , conn )
75
+ if err != nil {
76
+ return description.Server {}, newAuthError ("handshake failure" , err )
77
+ }
78
+ return desc , nil
79
+ }
86
80
81
+ // FinishHandshake performs authentication for conn if necessary.
82
+ func (ah * authHandshaker ) FinishHandshake (ctx context.Context , conn driver.Connection ) error {
83
+ performAuth := ah .options .PerformAuthentication
84
+ if performAuth == nil {
85
+ performAuth = func (serv description.Server ) bool {
86
+ return serv .Kind == description .RSPrimary ||
87
+ serv .Kind == description .RSSecondary ||
88
+ serv .Kind == description .Mongos ||
89
+ serv .Kind == description .Standalone
87
90
}
88
- if h == nil {
89
- return desc , nil
91
+ }
92
+ desc := conn .Description ()
93
+ if performAuth (desc ) && ah .options .Authenticator != nil {
94
+ err := ah .options .Authenticator .Auth (ctx , desc , conn )
95
+ if err != nil {
96
+ return newAuthError ("auth error" , err )
90
97
}
91
- return h .Handshake (ctx , addr , conn )
92
- })
98
+ }
99
+
100
+ if ah .wrapped == nil {
101
+ return nil
102
+ }
103
+ return ah .wrapped .FinishHandshake (ctx , conn )
104
+ }
105
+
106
+ // Handshaker creates a connection handshaker for the given authenticator.
107
+ func Handshaker (h driver.Handshaker , options * HandshakeOptions ) driver.Handshaker {
108
+ return & authHandshaker {
109
+ wrapped : h ,
110
+ options : options ,
111
+ }
93
112
}
94
113
95
114
// Authenticator handles authenticating a connection.
0 commit comments