@@ -10,6 +10,7 @@ import (
1010 "net"
1111 "strconv"
1212 "strings"
13+ "sync"
1314 "time"
1415
1516 "code.gitea.io/gitea/modules/container"
@@ -32,6 +33,12 @@ type SearchResult struct {
3233 Groups container.Set [string ]
3334}
3435
36+ // DialResult : dial response
37+ type DialResult struct {
38+ conn * ldap.Conn
39+ err error
40+ }
41+
3542func (source * Source ) sanitizedUserQuery (username string ) (string , bool ) {
3643 // See http://tools.ietf.org/search/rfc4515
3744 badCharacters := "\x00 ()*\\ "
@@ -109,6 +116,39 @@ func (source *Source) findUserDN(l *ldap.Conn, name string) (string, bool) {
109116 return userDN , true
110117}
111118
119+ func dialHost (host string , source * Source , results chan DialResult , wg * sync.WaitGroup ) {
120+ defer wg .Done ()
121+
122+ tlsConfig := & tls.Config {
123+ ServerName : host ,
124+ InsecureSkipVerify : source .SkipVerify ,
125+ }
126+
127+ var conn * ldap.Conn
128+ var err error
129+
130+ if source .SecurityProtocol == SecurityProtocolLDAPS {
131+ conn , err = ldap .DialTLS ("tcp" , net .JoinHostPort (host , strconv .Itoa (source .Port )), tlsConfig )
132+ } else {
133+ conn , err = ldap .Dial ("tcp" , net .JoinHostPort (host , strconv .Itoa (source .Port )))
134+ if err == nil && source .SecurityProtocol == SecurityProtocolStartTLS {
135+ err = conn .StartTLS (tlsConfig )
136+ }
137+ }
138+
139+ if err != nil {
140+ if conn != nil {
141+ conn .Close ()
142+ }
143+ log .Trace ("error during Dial for host %s: %w" , host , err )
144+ results <- DialResult {nil , err }
145+ return
146+ }
147+
148+ conn .SetTimeout (time .Second * 10 )
149+ results <- DialResult {conn , nil }
150+ }
151+
112152func dial (source * Source ) (* ldap.Conn , error ) {
113153 log .Trace ("Dialing LDAP with security protocol (%v) without verifying: %v" , source .SecurityProtocol , source .SkipVerify )
114154
@@ -118,46 +158,21 @@ func dial(source *Source) (*ldap.Conn, error) {
118158 // HostList is a list of hosts separated by commas
119159 hostList := strings .Split (tempHostList , "," )
120160
121- type result struct {
122- conn * ldap.Conn
123- err error
124- }
125-
126- results := make (chan result , len (hostList ))
161+ results := make (chan DialResult , len (hostList ))
162+ var wg sync.WaitGroup
127163
164+ // Race all connections
128165 for _ , host := range hostList {
129- go func (host string ) {
130- tlsConfig := & tls.Config {
131- ServerName : host ,
132- InsecureSkipVerify : source .SkipVerify ,
133- }
134-
135- var conn * ldap.Conn
136- var err error
137-
138- if source .SecurityProtocol == SecurityProtocolLDAPS {
139- conn , err = ldap .DialTLS ("tcp" , net .JoinHostPort (host , strconv .Itoa (source .Port )), tlsConfig )
140- } else {
141- conn , err = ldap .Dial ("tcp" , net .JoinHostPort (host , strconv .Itoa (source .Port )))
142- if err == nil && source .SecurityProtocol == SecurityProtocolStartTLS {
143- err = conn .StartTLS (tlsConfig )
144- }
145- }
146-
147- if err != nil {
148- if conn != nil {
149- conn .Close ()
150- }
151- log .Trace ("error during Dial for host %s: %w" , host , err )
152- results <- result {nil , err }
153- return
154- }
155-
156- conn .SetTimeout (time .Second * 10 )
157- results <- result {conn , nil }
158- }(host )
166+ wg .Add (1 )
167+ go dialHost (host , source , results , & wg )
159168 }
160169
170+ // Close the results channel after all goroutines finish
171+ go func () {
172+ wg .Wait ()
173+ close (results )
174+ }()
175+
161176 for range hostList {
162177 r := <- results
163178 if r .err == nil {
0 commit comments