Skip to content

Commit 31d5281

Browse files
committed
Factor TLS config logic out of connection.New
GODRIVER-336 Change-Id: Icdd21d949b570ffb954b0652281a6a6366c933a8
1 parent e07d7fe commit 31d5281

File tree

1 file changed

+33
-26
lines changed

1 file changed

+33
-26
lines changed

core/connection/connection.go

Lines changed: 33 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -96,33 +96,10 @@ func New(ctx context.Context, address addr.Addr, opts ...Option) (Connection, *d
9696

9797
if cfg.tlsConfig != nil {
9898
tlsConfig := cfg.tlsConfig.Clone()
99-
if !tlsConfig.InsecureSkipVerify {
100-
hostname := address.String()
101-
colonPos := strings.LastIndex(hostname, ":")
102-
if colonPos == -1 {
103-
colonPos = len(hostname)
104-
}
105-
106-
hostname = hostname[:colonPos]
107-
tlsConfig.ServerName = hostname
108-
}
109-
110-
client := tls.Client(nc, tlsConfig.Config)
111-
112-
errChan := make(chan error, 1)
113-
go func() {
114-
errChan <- client.Handshake()
115-
}()
116-
117-
select {
118-
case err := <-errChan:
119-
if err != nil {
120-
return nil, nil, err
121-
}
122-
case <-ctx.Done():
123-
return nil, nil, errors.New("server connection cancelled/timeout during TLS handshake")
99+
nc, err = configureTLS(ctx, nc, address, tlsConfig)
100+
if err != nil {
101+
return nil, nil, err
124102
}
125-
nc = client
126103
}
127104

128105
var lifetimeDeadline time.Time
@@ -158,6 +135,36 @@ func New(ctx context.Context, address addr.Addr, opts ...Option) (Connection, *d
158135
return c, desc, nil
159136
}
160137

138+
func configureTLS(ctx context.Context, nc net.Conn, address addr.Addr, config *TLSConfig) (net.Conn, error) {
139+
if !config.InsecureSkipVerify {
140+
hostname := address.String()
141+
colonPos := strings.LastIndex(hostname, ":")
142+
if colonPos == -1 {
143+
colonPos = len(hostname)
144+
}
145+
146+
hostname = hostname[:colonPos]
147+
config.ServerName = hostname
148+
}
149+
150+
client := tls.Client(nc, config.Config)
151+
152+
errChan := make(chan error, 1)
153+
go func() {
154+
errChan <- client.Handshake()
155+
}()
156+
157+
select {
158+
case err := <-errChan:
159+
if err != nil {
160+
return nil, err
161+
}
162+
case <-ctx.Done():
163+
return nil, errors.New("server connection cancelled/timeout during TLS handshake")
164+
}
165+
return client, nil
166+
}
167+
161168
func (c *connection) Alive() bool {
162169
return !c.dead
163170
}

0 commit comments

Comments
 (0)