@@ -10,6 +10,7 @@ import (
1010 "net"
1111 "strconv"
1212 "strings"
13+ "time"
1314
1415 "code.gitea.io/gitea/modules/container"
1516 "code.gitea.io/gitea/modules/log"
@@ -111,28 +112,47 @@ func (source *Source) findUserDN(l *ldap.Conn, name string) (string, bool) {
111112func dial (source * Source ) (* ldap.Conn , error ) {
112113 log .Trace ("Dialing LDAP with security protocol (%v) without verifying: %v" , source .SecurityProtocol , source .SkipVerify )
113114
114- tlsConfig := & tls.Config {
115- ServerName : source .Host ,
116- InsecureSkipVerify : source .SkipVerify ,
117- }
115+ ldap .DefaultTimeout = time .Second * 15
116+ // HostList is a list of hosts separated by commas
117+ hostList := strings .Split (source .HostList , "," )
118118
119- if source .SecurityProtocol == SecurityProtocolLDAPS {
120- return ldap .DialTLS ("tcp" , net .JoinHostPort (source .Host , strconv .Itoa (source .Port )), tlsConfig )
121- }
119+ for _ , host := range hostList {
120+ tlsConfig := & tls.Config {
121+ ServerName : host ,
122+ InsecureSkipVerify : source .SkipVerify ,
123+ }
122124
123- conn , err := ldap .Dial ("tcp" , net .JoinHostPort (source .Host , strconv .Itoa (source .Port )))
124- if err != nil {
125- return nil , fmt .Errorf ("error during Dial: %w" , err )
126- }
125+ if source .SecurityProtocol == SecurityProtocolLDAPS {
126+ conn , err := ldap .DialTLS ("tcp" , net .JoinHostPort (host , strconv .Itoa (source .Port )), tlsConfig )
127+
128+ if err != nil {
129+ // Connection failed, try again with the next host.
130+ log .Trace ("error during Dial for host %s: %w" , host , err )
131+ continue
132+ }
133+ conn .SetTimeout (time .Second * 10 )
127134
128- if source .SecurityProtocol == SecurityProtocolStartTLS {
129- if err = conn .StartTLS (tlsConfig ); err != nil {
130- conn .Close ()
131- return nil , fmt .Errorf ("error during StartTLS: %w" , err )
135+ return conn , err
136+ }
137+
138+ conn , err := ldap .Dial ("tcp" , net .JoinHostPort (host , strconv .Itoa (source .Port )))
139+ if err != nil {
140+ log .Trace ("error during Dial for host %s: %w" , host , err )
141+ continue
142+ }
143+ conn .SetTimeout (time .Second * 10 )
144+
145+ if source .SecurityProtocol == SecurityProtocolStartTLS {
146+ if err = conn .StartTLS (tlsConfig ); err != nil {
147+ conn .Close ()
148+ log .Trace ("error during StartTLS for host %s: %w" , host , err )
149+ continue
150+ }
132151 }
133152 }
134153
135- return conn , nil
154+ // All servers were unreachable
155+ return nil , fmt .Errorf ("dial failed for all provided servers: %s" , hostList )
136156}
137157
138158func bindUser (l * ldap.Conn , userDN , passwd string ) error {
@@ -257,7 +277,7 @@ func (source *Source) SearchEntry(name, passwd string, directBind bool) *SearchR
257277 }
258278 l , err := dial (source )
259279 if err != nil {
260- log .Error ("LDAP Connect error, %s:%v" , source .Host , err )
280+ log .Error ("LDAP Connect error, %s:%v" , source .HostList , err )
261281 source .Enabled = false
262282 return nil
263283 }
@@ -421,7 +441,7 @@ func (source *Source) UsePagedSearch() bool {
421441func (source * Source ) SearchEntries () ([]* SearchResult , error ) {
422442 l , err := dial (source )
423443 if err != nil {
424- log .Error ("LDAP Connect error, %s:%v" , source .Host , err )
444+ log .Error ("LDAP Connect error, %s:%v" , source .HostList , err )
425445 source .Enabled = false
426446 return nil , err
427447 }
0 commit comments