@@ -112,50 +112,68 @@ func (source *Source) findUserDN(l *ldap.Conn, name string) (string, bool) {
112112func dial (source * Source ) (* ldap.Conn , error ) {
113113 log .Trace ("Dialing LDAP with security protocol (%v) without verifying: %v" , source .SecurityProtocol , source .SkipVerify )
114114
115- ldap .DefaultTimeout = time .Second * 15
115+ ldap .DefaultTimeout = time .Second * 10
116116 // Remove any extra spaces in HostList string
117117 tempHostList := strings .ReplaceAll (source .HostList , " " , "" )
118118 // HostList is a list of hosts separated by commas
119119 hostList := strings .Split (tempHostList , "," )
120- // hostList := strings.Split(source.HostList, ",")
121120
122- for _ , host := range hostList {
123- tlsConfig := & tls.Config {
124- ServerName : host ,
125- InsecureSkipVerify : source .SkipVerify ,
126- }
121+ type result struct {
122+ conn * ldap.Conn
123+ err error
124+ }
127125
128- if source . SecurityProtocol == SecurityProtocolLDAPS {
129- conn , err := ldap . DialTLS ( "tcp" , net . JoinHostPort ( host , strconv . Itoa ( source . Port )), tlsConfig )
130- if err != nil {
131- // Connection failed, try again with the next host.
132- conn . Close ()
133- log . Trace ( "error during Dial for host %s: %w" , host , err )
134- continue
126+ results := make ( chan result , len ( hostList ))
127+
128+ for _ , host := range hostList {
129+ go func ( host string ) {
130+ tlsConfig := & tls. Config {
131+ ServerName : host ,
132+ InsecureSkipVerify : source . SkipVerify ,
135133 }
136- conn .SetTimeout (time .Second * 10 )
137134
138- return conn , err
139- }
135+ var conn * ldap. Conn
136+ var err error
140137
141- conn , err := ldap .Dial ("tcp" , net .JoinHostPort (host , strconv .Itoa (source .Port )))
142- if err != nil {
143- conn .Close ()
144- log .Trace ("error during Dial for host %s: %w" , host , err )
145- continue
146- }
147- conn .SetTimeout (time .Second * 10 )
138+ if source .SecurityProtocol == SecurityProtocolLDAPS {
139+ conn , err = ldap .DialTLS ("tcp" , net .JoinHostPort (host , strconv .Itoa (source .Port )), tlsConfig )
140+ } else {
141+ conn , err = ldap .Dial ("tcp" , net .JoinHostPort (host , strconv .Itoa (source .Port )))
142+ if err == nil && source .SecurityProtocol == SecurityProtocolStartTLS {
143+ err = conn .StartTLS (tlsConfig )
144+ }
145+ }
148146
149- if source .SecurityProtocol == SecurityProtocolStartTLS {
150- if err = conn .StartTLS (tlsConfig ); err != nil {
151- conn .Close ()
152- log .Trace ("error during StartTLS for host %s: %w" , host , err )
153- continue
147+ if err != nil {
148+ if conn != nil {
149+ conn .Close ()
150+ }
151+ log .Trace ("error during Dial for host %s: %w" , host , err )
152+ results <- result {nil , err }
153+ return
154154 }
155+
156+ conn .SetTimeout (time .Second * 10 )
157+ results <- result {conn , nil }
158+ }(host )
159+ }
160+
161+ for range hostList {
162+ r := <- results
163+ if r .err == nil {
164+ // Close other connections still in progress
165+ go func () {
166+ for range hostList {
167+ r := <- results
168+ if r .conn != nil {
169+ r .conn .Close ()
170+ }
171+ }
172+ }()
173+ return r .conn , nil
155174 }
156175 }
157176
158- // All servers were unreachable
159177 return nil , fmt .Errorf ("dial failed for all provided servers: %s" , hostList )
160178}
161179
0 commit comments