@@ -2,7 +2,6 @@ package auth_providers
22
33import (
44 "context"
5- "crypto/tls"
65 "crypto/x509"
76 "fmt"
87 "net/http"
@@ -51,6 +50,11 @@ type OAuthAuthenticator struct {
5150 Client * http.Client
5251}
5352
53+ type oauth2Transport struct {
54+ base http.RoundTripper
55+ src oauth2.TokenSource
56+ }
57+
5458// GetHttpClient returns the http client
5559func (a * OAuthAuthenticator ) GetHttpClient () (* http.Client , error ) {
5660 return a .Client , nil
@@ -162,24 +166,17 @@ func (b *CommandConfigOauth) WithHttpClient(httpClient *http.Client) *CommandCon
162166// GetHttpClient returns an HTTP client for oAuth authentication.
163167func (b * CommandConfigOauth ) GetHttpClient () (* http.Client , error ) {
164168 cErr := b .ValidateAuthConfig ()
165- var client http.Client
166- if b .CommandAuthConfig .HttpClient != nil {
167- client = * b .CommandAuthConfig .HttpClient
168- }
169169 if cErr != nil {
170170 return nil , cErr
171171 }
172172
173- if client .Transport == nil {
174- transport , tErr := b .BuildTransport ()
175- if tErr != nil {
176- return nil , tErr
177- }
178- client .Transport = transport
173+ var client http.Client
174+ baseTransport , tErr := b .BuildTransport ()
175+ if tErr != nil {
176+ return nil , tErr
179177 }
180178
181179 if b .AccessToken != "" {
182- baseTransport := cloneHTTPTransport (client .Transport .(* http.Transport ))
183180 client .Transport = & oauth2.Transport {
184181 Base : baseTransport ,
185182 Source : oauth2 .StaticTokenSource (
@@ -209,15 +206,15 @@ func (b *CommandConfigOauth) GetHttpClient() (*http.Client, error) {
209206 }
210207 }
211208
212- ctx := context .WithValue (context .Background (), oauth2 .HTTPClient , client )
213-
209+ ctx := context .WithValue (context .Background (), oauth2 .HTTPClient , & http.Client {Transport : baseTransport })
214210 tokenSource := config .TokenSource (ctx )
215- baseTransport := cloneHTTPTransport (client .Transport .(* http.Transport ))
216- oauthTransport := oauth2.Transport {
217- Base : baseTransport ,
218- Source : tokenSource ,
211+
212+ client = http.Client {
213+ Transport : & oauth2Transport {
214+ base : baseTransport ,
215+ src : tokenSource ,
216+ },
219217 }
220- client .Transport = & oauthTransport
221218
222219 return & client , nil
223220}
@@ -375,6 +372,7 @@ func (b *CommandConfigOauth) Authenticate() error {
375372 }
376373
377374 b .SetClient (oauthy )
375+ //b.DefaultHttpClient = oauthy
378376
379377 aErr := b .CommandAuthConfig .Authenticate ()
380378 if aErr != nil {
@@ -401,79 +399,16 @@ func (b *CommandConfigOauth) GetServerConfig() *Server {
401399 return & server
402400}
403401
404- // Example usage of CommandConfigOauth
405- //
406- // This example demonstrates how to use CommandConfigOauth to authenticate to the Keyfactor Command API using OAuth2.
407- //
408- // func ExampleCommandConfigOauth_Authenticate() {
409- // authConfig := &CommandConfigOauth{
410- // CommandAuthConfig: CommandAuthConfig{
411- // ConfigFilePath: "/path/to/config.json",
412- // ConfigProfile: "default",
413- // CommandHostName: "exampleHost",
414- // CommandPort: 443,
415- // CommandAPIPath: "/api/v1",
416- // CommandCACert: "/path/to/ca-cert.pem",
417- // SkipVerify: true,
418- // HttpClientTimeout: 60,
419- // },
420- // ClientID: "exampleClientID",
421- // ClientSecret: "exampleClientSecret",
422- // TokenURL: "https://example.com/oauth/token",
423- // Scopes: []string{"openid", "profile", "email"},
424- // Audience: "exampleAudience",
425- // CACertificatePath: "/path/to/ca-cert.pem",
426- // AccessToken: "exampleAccessToken",
427- // }
428- //
429- // err := authConfig.Authenticate()
430- // if err != nil {
431- // fmt.Println("Authentication failed:", err)
432- // } else {
433- // fmt.Println("Authentication successful")
434- // }
435- // }
436-
437- func cloneHTTPTransport (original * http.Transport ) * http.Transport {
438- if original == nil {
439- return nil
402+ // RoundTrip executes a single HTTP transaction, adding the OAuth2 token to the request
403+ func (t * oauth2Transport ) RoundTrip (req * http.Request ) (* http.Response , error ) {
404+ token , err := t .src .Token ()
405+ if err != nil {
406+ return nil , fmt .Errorf ("failed to retrieve OAuth token: %w" , err )
440407 }
441408
442- return & http.Transport {
443- Proxy : original .Proxy ,
444- DialContext : original .DialContext ,
445- ForceAttemptHTTP2 : original .ForceAttemptHTTP2 ,
446- MaxIdleConns : original .MaxIdleConns ,
447- IdleConnTimeout : original .IdleConnTimeout ,
448- TLSHandshakeTimeout : original .TLSHandshakeTimeout ,
449- ExpectContinueTimeout : original .ExpectContinueTimeout ,
450- ResponseHeaderTimeout : original .ResponseHeaderTimeout ,
451- TLSClientConfig : cloneTLSConfig (original .TLSClientConfig ),
452- DialTLSContext : original .DialTLSContext ,
453- DisableKeepAlives : original .DisableKeepAlives ,
454- DisableCompression : original .DisableCompression ,
455- MaxIdleConnsPerHost : original .MaxIdleConnsPerHost ,
456- MaxConnsPerHost : original .MaxConnsPerHost ,
457- WriteBufferSize : original .WriteBufferSize ,
458- ReadBufferSize : original .ReadBufferSize ,
459- }
460- }
409+ // Clone the request to avoid mutating the original
410+ reqCopy := req .Clone (req .Context ())
411+ token .SetAuthHeader (reqCopy )
461412
462- func cloneTLSConfig (original * tls.Config ) * tls.Config {
463- if original == nil {
464- return nil
465- }
466-
467- return & tls.Config {
468- InsecureSkipVerify : original .InsecureSkipVerify ,
469- MinVersion : original .MinVersion ,
470- MaxVersion : original .MaxVersion ,
471- CipherSuites : original .CipherSuites ,
472- PreferServerCipherSuites : original .PreferServerCipherSuites ,
473- NextProtos : original .NextProtos ,
474- ServerName : original .ServerName ,
475- ClientAuth : original .ClientAuth ,
476- RootCAs : original .RootCAs ,
477- // Deep copy the rest of the TLS fields as needed
478- }
413+ return t .base .RoundTrip (reqCopy )
479414}
0 commit comments