Skip to content

Commit 1fb370e

Browse files
committed
fix(oauth): OAuth clients not to inherit DefaultClient config
1 parent 54983b9 commit 1fb370e

File tree

4 files changed

+287
-143
lines changed

4 files changed

+287
-143
lines changed

auth_providers/auth_core.go

Lines changed: 56 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ type CommandAuthConfig struct {
141141

142142
// HttpClient is the http Client to be used for authentication to Keyfactor Command API
143143
HttpClient *http.Client
144+
//DefaultHttpClient *http.Client
144145
}
145146

146147
// cleanHostName cleans the hostname for authentication to Keyfactor Command API.
@@ -275,66 +276,53 @@ func (c *CommandAuthConfig) ValidateAuthConfig() error {
275276
// check if CommandCACert is set in environment
276277
if caCert, ok := os.LookupEnv(EnvKeyfactorCACert); ok {
277278
c.CommandCACert = caCert
278-
} else {
279-
return nil
280279
}
281280
}
282281

283282
// check for skip verify in environment
284283
if skipVerify, ok := os.LookupEnv(EnvKeyfactorSkipVerify); ok {
285284
c.SkipVerify = skipVerify == "true" || skipVerify == "1"
286285
}
287-
288-
//TODO: This should be part of BuildTransport
289-
//if c.SkipVerify {
290-
// c.HttpClient.Transport = &http.Transport{
291-
// TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
292-
// }
293-
// //return nil
294-
//}
295-
//
296-
//caErr := c.updateCACerts()
297-
//if caErr != nil {
298-
// return caErr
299-
//}
300-
301286
return nil
302287
}
303288

304289
// BuildTransport creates a custom http Transport for authentication to Keyfactor Command API.
305290
func (c *CommandAuthConfig) BuildTransport() (*http.Transport, error) {
306-
var output *http.Transport
307-
if c.HttpClient == nil {
308-
c.SetClient(nil)
309-
}
310-
// check if c already has a transport and if it does, assign it to output else create a new transport
311-
if c.HttpClient.Transport != nil {
312-
if transport, ok := c.HttpClient.Transport.(*http.Transport); ok {
313-
output = transport
314-
} else {
315-
output = &http.Transport{
316-
TLSClientConfig: &tls.Config{},
317-
}
318-
}
319-
} else {
320-
output = &http.Transport{
321-
Proxy: http.ProxyFromEnvironment,
322-
TLSClientConfig: &tls.Config{
323-
Renegotiation: tls.RenegotiateOnceAsClient,
324-
},
325-
TLSHandshakeTimeout: 10 * time.Second,
326-
}
291+
output := http.Transport{
292+
Proxy: http.ProxyFromEnvironment,
293+
TLSClientConfig: &tls.Config{
294+
Renegotiation: tls.RenegotiateOnceAsClient,
295+
},
296+
TLSHandshakeTimeout: 10 * time.Second,
327297
}
328298

329299
if c.SkipVerify {
330300
output.TLSClientConfig.InsecureSkipVerify = true
331301
}
332302

333303
if c.CommandCACert != "" {
334-
_ = c.updateCACerts()
304+
if _, err := os.Stat(c.CommandCACert); err == nil {
305+
cert, ioErr := os.ReadFile(c.CommandCACert)
306+
if ioErr != nil {
307+
return &output, ioErr
308+
}
309+
// check if output.TLSClientConfig.RootCAs is nil
310+
if output.TLSClientConfig.RootCAs == nil {
311+
output.TLSClientConfig.RootCAs = x509.NewCertPool()
312+
}
313+
// Append your custom cert to the pool
314+
if ok := output.TLSClientConfig.RootCAs.AppendCertsFromPEM(cert); !ok {
315+
return &output, fmt.Errorf("failed to append custom CA cert to pool")
316+
}
317+
} else {
318+
// Append your custom cert to the pool
319+
if ok := output.TLSClientConfig.RootCAs.AppendCertsFromPEM([]byte(c.CommandCACert)); !ok {
320+
return &output, fmt.Errorf("failed to append custom CA cert to pool")
321+
}
322+
}
335323
}
336324

337-
return output, nil
325+
return &output, nil
338326
}
339327

340328
// SetClient sets the http Client for authentication to Keyfactor Command API.
@@ -343,8 +331,34 @@ func (c *CommandAuthConfig) SetClient(client *http.Client) *http.Client {
343331
c.HttpClient = client
344332
}
345333
if c.HttpClient == nil {
346-
c.HttpClient = http.DefaultClient
334+
//// Copy the default transport and apply the custom TLS config
335+
//defaultTransport := http.DefaultTransport.(*http.Transport).Clone()
336+
////defaultTransport.TLSClientConfig = tlsConfig
337+
//c.HttpClient = &http.Client{Transport: defaultTransport}
338+
defaultTimeout := time.Duration(c.HttpClientTimeout) * time.Second
339+
c.HttpClient = &http.Client{
340+
Transport: &http.Transport{
341+
Proxy: http.ProxyFromEnvironment,
342+
TLSClientConfig: &tls.Config{
343+
Renegotiation: tls.RenegotiateOnceAsClient,
344+
},
345+
TLSHandshakeTimeout: defaultTimeout,
346+
DisableKeepAlives: false,
347+
DisableCompression: false,
348+
MaxIdleConns: 10,
349+
MaxIdleConnsPerHost: 10,
350+
MaxConnsPerHost: 10,
351+
IdleConnTimeout: defaultTimeout,
352+
ResponseHeaderTimeout: defaultTimeout,
353+
ExpectContinueTimeout: defaultTimeout,
354+
MaxResponseHeaderBytes: 0,
355+
WriteBufferSize: 0,
356+
ReadBufferSize: 0,
357+
ForceAttemptHTTP2: false,
358+
},
359+
}
347360
}
361+
348362
return c.HttpClient
349363
}
350364

@@ -356,6 +370,7 @@ func (c *CommandAuthConfig) updateCACerts() error {
356370
if caCert, ok := os.LookupEnv(EnvKeyfactorCACert); ok {
357371
c.CommandCACert = caCert
358372
} else {
373+
// nothing to do
359374
return nil
360375
}
361376
}
@@ -452,7 +467,6 @@ func (c *CommandAuthConfig) Authenticate() error {
452467
}
453468

454469
c.HttpClient.Timeout = time.Duration(c.HttpClientTimeout) * time.Second
455-
456470
cResp, cErr := c.HttpClient.Do(req)
457471
if cErr != nil {
458472
return cErr
@@ -645,7 +659,7 @@ func (c *CommandAuthConfig) LoadConfig(profile string, configFilePath string, si
645659
if c.CommandCACert == "" {
646660
c.CommandCACert = server.CACertPath
647661
}
648-
if c.SkipVerify {
662+
if !c.SkipVerify {
649663
c.SkipVerify = server.SkipTLSVerify
650664
}
651665

auth_providers/auth_oauth.go

Lines changed: 26 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package auth_providers
22

33
import (
44
"context"
5-
"crypto/tls"
65
"crypto/x509"
76
"fmt"
87
"net/http"
@@ -51,6 +50,11 @@ type OAuthAuthenticator struct {
5150
Client *http.Client
5251
}
5352

53+
type oauth2Transport struct {
54+
base http.RoundTripper
55+
src oauth2.TokenSource
56+
}
57+
5458
// GetHttpClient returns the http client
5559
func (a *OAuthAuthenticator) GetHttpClient() (*http.Client, error) {
5660
return a.Client, nil
@@ -162,24 +166,17 @@ func (b *CommandConfigOauth) WithHttpClient(httpClient *http.Client) *CommandCon
162166
// GetHttpClient returns an HTTP client for oAuth authentication.
163167
func (b *CommandConfigOauth) GetHttpClient() (*http.Client, error) {
164168
cErr := b.ValidateAuthConfig()
165-
var client http.Client
166-
if b.CommandAuthConfig.HttpClient != nil {
167-
client = *b.CommandAuthConfig.HttpClient
168-
}
169169
if cErr != nil {
170170
return nil, cErr
171171
}
172172

173-
if client.Transport == nil {
174-
transport, tErr := b.BuildTransport()
175-
if tErr != nil {
176-
return nil, tErr
177-
}
178-
client.Transport = transport
173+
var client http.Client
174+
baseTransport, tErr := b.BuildTransport()
175+
if tErr != nil {
176+
return nil, tErr
179177
}
180178

181179
if b.AccessToken != "" {
182-
baseTransport := cloneHTTPTransport(client.Transport.(*http.Transport))
183180
client.Transport = &oauth2.Transport{
184181
Base: baseTransport,
185182
Source: oauth2.StaticTokenSource(
@@ -209,15 +206,15 @@ func (b *CommandConfigOauth) GetHttpClient() (*http.Client, error) {
209206
}
210207
}
211208

212-
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, client)
213-
209+
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, &http.Client{Transport: baseTransport})
214210
tokenSource := config.TokenSource(ctx)
215-
baseTransport := cloneHTTPTransport(client.Transport.(*http.Transport))
216-
oauthTransport := oauth2.Transport{
217-
Base: baseTransport,
218-
Source: tokenSource,
211+
212+
client = http.Client{
213+
Transport: &oauth2Transport{
214+
base: baseTransport,
215+
src: tokenSource,
216+
},
219217
}
220-
client.Transport = &oauthTransport
221218

222219
return &client, nil
223220
}
@@ -375,6 +372,7 @@ func (b *CommandConfigOauth) Authenticate() error {
375372
}
376373

377374
b.SetClient(oauthy)
375+
//b.DefaultHttpClient = oauthy
378376

379377
aErr := b.CommandAuthConfig.Authenticate()
380378
if aErr != nil {
@@ -401,79 +399,16 @@ func (b *CommandConfigOauth) GetServerConfig() *Server {
401399
return &server
402400
}
403401

404-
// Example usage of CommandConfigOauth
405-
//
406-
// This example demonstrates how to use CommandConfigOauth to authenticate to the Keyfactor Command API using OAuth2.
407-
//
408-
// func ExampleCommandConfigOauth_Authenticate() {
409-
// authConfig := &CommandConfigOauth{
410-
// CommandAuthConfig: CommandAuthConfig{
411-
// ConfigFilePath: "/path/to/config.json",
412-
// ConfigProfile: "default",
413-
// CommandHostName: "exampleHost",
414-
// CommandPort: 443,
415-
// CommandAPIPath: "/api/v1",
416-
// CommandCACert: "/path/to/ca-cert.pem",
417-
// SkipVerify: true,
418-
// HttpClientTimeout: 60,
419-
// },
420-
// ClientID: "exampleClientID",
421-
// ClientSecret: "exampleClientSecret",
422-
// TokenURL: "https://example.com/oauth/token",
423-
// Scopes: []string{"openid", "profile", "email"},
424-
// Audience: "exampleAudience",
425-
// CACertificatePath: "/path/to/ca-cert.pem",
426-
// AccessToken: "exampleAccessToken",
427-
// }
428-
//
429-
// err := authConfig.Authenticate()
430-
// if err != nil {
431-
// fmt.Println("Authentication failed:", err)
432-
// } else {
433-
// fmt.Println("Authentication successful")
434-
// }
435-
// }
436-
437-
func cloneHTTPTransport(original *http.Transport) *http.Transport {
438-
if original == nil {
439-
return nil
402+
// RoundTrip executes a single HTTP transaction, adding the OAuth2 token to the request
403+
func (t *oauth2Transport) RoundTrip(req *http.Request) (*http.Response, error) {
404+
token, err := t.src.Token()
405+
if err != nil {
406+
return nil, fmt.Errorf("failed to retrieve OAuth token: %w", err)
440407
}
441408

442-
return &http.Transport{
443-
Proxy: original.Proxy,
444-
DialContext: original.DialContext,
445-
ForceAttemptHTTP2: original.ForceAttemptHTTP2,
446-
MaxIdleConns: original.MaxIdleConns,
447-
IdleConnTimeout: original.IdleConnTimeout,
448-
TLSHandshakeTimeout: original.TLSHandshakeTimeout,
449-
ExpectContinueTimeout: original.ExpectContinueTimeout,
450-
ResponseHeaderTimeout: original.ResponseHeaderTimeout,
451-
TLSClientConfig: cloneTLSConfig(original.TLSClientConfig),
452-
DialTLSContext: original.DialTLSContext,
453-
DisableKeepAlives: original.DisableKeepAlives,
454-
DisableCompression: original.DisableCompression,
455-
MaxIdleConnsPerHost: original.MaxIdleConnsPerHost,
456-
MaxConnsPerHost: original.MaxConnsPerHost,
457-
WriteBufferSize: original.WriteBufferSize,
458-
ReadBufferSize: original.ReadBufferSize,
459-
}
460-
}
409+
// Clone the request to avoid mutating the original
410+
reqCopy := req.Clone(req.Context())
411+
token.SetAuthHeader(reqCopy)
461412

462-
func cloneTLSConfig(original *tls.Config) *tls.Config {
463-
if original == nil {
464-
return nil
465-
}
466-
467-
return &tls.Config{
468-
InsecureSkipVerify: original.InsecureSkipVerify,
469-
MinVersion: original.MinVersion,
470-
MaxVersion: original.MaxVersion,
471-
CipherSuites: original.CipherSuites,
472-
PreferServerCipherSuites: original.PreferServerCipherSuites,
473-
NextProtos: original.NextProtos,
474-
ServerName: original.ServerName,
475-
ClientAuth: original.ClientAuth,
476-
RootCAs: original.RootCAs,
477-
// Deep copy the rest of the TLS fields as needed
478-
}
413+
return t.base.RoundTrip(reqCopy)
479414
}

0 commit comments

Comments
 (0)