Skip to content

Commit f2c4cae

Browse files
committed
race the tcp connections (#6898)
Signed-off-by: abhishek818 <[email protected]>
1 parent 88ca8fa commit f2c4cae

File tree

1 file changed

+48
-30
lines changed

1 file changed

+48
-30
lines changed

services/auth/source/ldap/source_search.go

Lines changed: 48 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -112,50 +112,68 @@ func (source *Source) findUserDN(l *ldap.Conn, name string) (string, bool) {
112112
func dial(source *Source) (*ldap.Conn, error) {
113113
log.Trace("Dialing LDAP with security protocol (%v) without verifying: %v", source.SecurityProtocol, source.SkipVerify)
114114

115-
ldap.DefaultTimeout = time.Second * 15
115+
ldap.DefaultTimeout = time.Second * 10
116116
// Remove any extra spaces in HostList string
117117
tempHostList := strings.ReplaceAll(source.HostList, " ", "")
118118
// HostList is a list of hosts separated by commas
119119
hostList := strings.Split(tempHostList, ",")
120-
// hostList := strings.Split(source.HostList, ",")
121120

122-
for _, host := range hostList {
123-
tlsConfig := &tls.Config{
124-
ServerName: host,
125-
InsecureSkipVerify: source.SkipVerify,
126-
}
121+
type result struct {
122+
conn *ldap.Conn
123+
err error
124+
}
127125

128-
if source.SecurityProtocol == SecurityProtocolLDAPS {
129-
conn, err := ldap.DialTLS("tcp", net.JoinHostPort(host, strconv.Itoa(source.Port)), tlsConfig)
130-
if err != nil {
131-
// Connection failed, try again with the next host.
132-
conn.Close()
133-
log.Trace("error during Dial for host %s: %w", host, err)
134-
continue
126+
results := make(chan result, len(hostList))
127+
128+
for _, host := range hostList {
129+
go func(host string) {
130+
tlsConfig := &tls.Config{
131+
ServerName: host,
132+
InsecureSkipVerify: source.SkipVerify,
135133
}
136-
conn.SetTimeout(time.Second * 10)
137134

138-
return conn, err
139-
}
135+
var conn *ldap.Conn
136+
var err error
140137

141-
conn, err := ldap.Dial("tcp", net.JoinHostPort(host, strconv.Itoa(source.Port)))
142-
if err != nil {
143-
conn.Close()
144-
log.Trace("error during Dial for host %s: %w", host, err)
145-
continue
146-
}
147-
conn.SetTimeout(time.Second * 10)
138+
if source.SecurityProtocol == SecurityProtocolLDAPS {
139+
conn, err = ldap.DialTLS("tcp", net.JoinHostPort(host, strconv.Itoa(source.Port)), tlsConfig)
140+
} else {
141+
conn, err = ldap.Dial("tcp", net.JoinHostPort(host, strconv.Itoa(source.Port)))
142+
if err == nil && source.SecurityProtocol == SecurityProtocolStartTLS {
143+
err = conn.StartTLS(tlsConfig)
144+
}
145+
}
148146

149-
if source.SecurityProtocol == SecurityProtocolStartTLS {
150-
if err = conn.StartTLS(tlsConfig); err != nil {
151-
conn.Close()
152-
log.Trace("error during StartTLS for host %s: %w", host, err)
153-
continue
147+
if err != nil {
148+
if conn != nil {
149+
conn.Close()
150+
}
151+
log.Trace("error during Dial for host %s: %w", host, err)
152+
results <- result{nil, err}
153+
return
154154
}
155+
156+
conn.SetTimeout(time.Second * 10)
157+
results <- result{conn, nil}
158+
}(host)
159+
}
160+
161+
for range hostList {
162+
r := <-results
163+
if r.err == nil {
164+
// Close other connections still in progress
165+
go func() {
166+
for range hostList {
167+
r := <-results
168+
if r.conn != nil {
169+
r.conn.Close()
170+
}
171+
}
172+
}()
173+
return r.conn, nil
155174
}
156175
}
157176

158-
// All servers were unreachable
159177
return nil, fmt.Errorf("dial failed for all provided servers: %s", hostList)
160178
}
161179

0 commit comments

Comments
 (0)