Skip to content

Commit 999df82

Browse files
committed
refactor code (#6898)
Signed-off-by: abhishek818 <[email protected]>
1 parent 8d9269f commit 999df82

File tree

1 file changed

+51
-36
lines changed

1 file changed

+51
-36
lines changed

services/auth/source/ldap/source_search.go

Lines changed: 51 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"net"
1111
"strconv"
1212
"strings"
13+
"sync"
1314
"time"
1415

1516
"code.gitea.io/gitea/modules/container"
@@ -32,6 +33,12 @@ type SearchResult struct {
3233
Groups container.Set[string]
3334
}
3435

36+
// DialResult : dial response
37+
type DialResult struct {
38+
conn *ldap.Conn
39+
err error
40+
}
41+
3542
func (source *Source) sanitizedUserQuery(username string) (string, bool) {
3643
// See http://tools.ietf.org/search/rfc4515
3744
badCharacters := "\x00()*\\"
@@ -109,6 +116,39 @@ func (source *Source) findUserDN(l *ldap.Conn, name string) (string, bool) {
109116
return userDN, true
110117
}
111118

119+
func dialHost(host string, source *Source, results chan DialResult, wg *sync.WaitGroup) {
120+
defer wg.Done()
121+
122+
tlsConfig := &tls.Config{
123+
ServerName: host,
124+
InsecureSkipVerify: source.SkipVerify,
125+
}
126+
127+
var conn *ldap.Conn
128+
var err error
129+
130+
if source.SecurityProtocol == SecurityProtocolLDAPS {
131+
conn, err = ldap.DialTLS("tcp", net.JoinHostPort(host, strconv.Itoa(source.Port)), tlsConfig)
132+
} else {
133+
conn, err = ldap.Dial("tcp", net.JoinHostPort(host, strconv.Itoa(source.Port)))
134+
if err == nil && source.SecurityProtocol == SecurityProtocolStartTLS {
135+
err = conn.StartTLS(tlsConfig)
136+
}
137+
}
138+
139+
if err != nil {
140+
if conn != nil {
141+
conn.Close()
142+
}
143+
log.Trace("error during Dial for host %s: %w", host, err)
144+
results <- DialResult{nil, err}
145+
return
146+
}
147+
148+
conn.SetTimeout(time.Second * 10)
149+
results <- DialResult{conn, nil}
150+
}
151+
112152
func dial(source *Source) (*ldap.Conn, error) {
113153
log.Trace("Dialing LDAP with security protocol (%v) without verifying: %v", source.SecurityProtocol, source.SkipVerify)
114154

@@ -118,46 +158,21 @@ func dial(source *Source) (*ldap.Conn, error) {
118158
// HostList is a list of hosts separated by commas
119159
hostList := strings.Split(tempHostList, ",")
120160

121-
type result struct {
122-
conn *ldap.Conn
123-
err error
124-
}
125-
126-
results := make(chan result, len(hostList))
161+
results := make(chan DialResult, len(hostList))
162+
var wg sync.WaitGroup
127163

164+
// Race all connections
128165
for _, host := range hostList {
129-
go func(host string) {
130-
tlsConfig := &tls.Config{
131-
ServerName: host,
132-
InsecureSkipVerify: source.SkipVerify,
133-
}
134-
135-
var conn *ldap.Conn
136-
var err error
137-
138-
if source.SecurityProtocol == SecurityProtocolLDAPS {
139-
conn, err = ldap.DialTLS("tcp", net.JoinHostPort(host, strconv.Itoa(source.Port)), tlsConfig)
140-
} else {
141-
conn, err = ldap.Dial("tcp", net.JoinHostPort(host, strconv.Itoa(source.Port)))
142-
if err == nil && source.SecurityProtocol == SecurityProtocolStartTLS {
143-
err = conn.StartTLS(tlsConfig)
144-
}
145-
}
146-
147-
if err != nil {
148-
if conn != nil {
149-
conn.Close()
150-
}
151-
log.Trace("error during Dial for host %s: %w", host, err)
152-
results <- result{nil, err}
153-
return
154-
}
155-
156-
conn.SetTimeout(time.Second * 10)
157-
results <- result{conn, nil}
158-
}(host)
166+
wg.Add(1)
167+
go dialHost(host, source, results, &wg)
159168
}
160169

170+
// Close the results channel after all goroutines finish
171+
go func() {
172+
wg.Wait()
173+
close(results)
174+
}()
175+
161176
for range hostList {
162177
r := <-results
163178
if r.err == nil {

0 commit comments

Comments
 (0)