@@ -38,6 +38,7 @@ import (
3838const (
3939 headerNameUserAgent = "User-Agent"
4040 sdkName = "ibm-go-sdk-core"
41+ maxRedirects = 10
4142)
4243
4344// ServiceOptions is a struct of configuration values for a service.
@@ -117,7 +118,7 @@ func (service *BaseService) Clone() *BaseService {
117118 // First, copy the service options struct.
118119 serviceOptions := * service .Options
119120
120- // Next, make a copy the service struct, then use the copy of the service options.
121+ // Next, make a copy of the service struct, then use the copy of the service options.
121122 // Note, we'll re-use the "Client" instance from the original BaseService instance.
122123 clone := * service
123124 clone .Options = & serviceOptions
@@ -234,7 +235,7 @@ func (service *BaseService) SetDefaultHeaders(headers http.Header) {
234235// the retryable client; otherwise "client" will be stored
235236// directly on "service".
236237func (service * BaseService ) SetHTTPClient (client * http.Client ) {
237- setMinimumTLSVersion (client )
238+ setupHTTPClient (client )
238239
239240 if isRetryableClient (service .Client ) {
240241 // If "service" is currently holding a retryable client,
@@ -298,15 +299,83 @@ func (service *BaseService) IsSSLDisabled() bool {
298299 return false
299300}
300301
301- // setMinimumTLSVersion sets the minimum TLS version required by the client to TLS v1.2
302- func setMinimumTLSVersion (client * http.Client ) {
302+ // setupHTTPClient will configure "client" for use with the BaseService object.
303+ func setupHTTPClient (client * http.Client ) {
304+ // Set the minimum TLS version to be 1.2
303305 if tr , ok := client .Transport .(* http.Transport ); tr != nil && ok {
304306 if tr .TLSClientConfig == nil {
305307 tr .TLSClientConfig = & tls.Config {} // #nosec G402
306308 }
307309
308310 tr .TLSClientConfig .MinVersion = tls .VersionTLS12
309311 }
312+
313+ // Set our "CheckRedirect" function to allow safe headers to be included
314+ // in redirected requests under certain conditions.
315+ if client .CheckRedirect == nil {
316+ client .CheckRedirect = checkRedirect
317+ }
318+ }
319+
320+ // checkRedirect is used as an override for the default "CheckRedirect" function supplied
321+ // by the net/http package and implements some additional logic required by IBM SDKs.
322+ func checkRedirect (req * http.Request , via []* http.Request ) error {
323+
324+ // The net/http module is implemented such that it will only include "safe" headers
325+ // ("Authorization", "WWW-Authenticate", "Cookie", "Cookie2") when redirecting a request
326+ // if the redirected host is the same host or a sub-domain of the original request's host.
327+ // Example: foo.com redirected to foo.com or bar.foo.com would work, but bar.com would not.
328+ // This "CheckRedirect" implementation will propagate "safe" headers in a redirected request
329+ // only in situations where the hosts associated with the original and redirected request URLs
330+ // are both located within the ".cloud.ibm.com" domain.
331+
332+ // First, perform the check that is done by the default CheckRedirect function
333+ // to ensure we don't exhaust our max redirect limit.
334+ if len (via ) >= maxRedirects {
335+ GetLogger ().Debug ("Exceeded max redirects: %d" , maxRedirects )
336+ return fmt .Errorf ("stopped after %d redirects" , maxRedirects )
337+ }
338+
339+ if len (via ) > 0 {
340+ GetLogger ().Debug ("Detected %d prior request(s)" , len (via ))
341+ originalReq := via [0 ]
342+ redirectedReq := req
343+ GetLogger ().Debug ("Redirecting request from %s to %s" , originalReq .URL .String (), redirectedReq .URL .String ())
344+ redirectedHeader := req .Header
345+ originalHeader := via [0 ].Header
346+
347+ originalHost := originalReq .URL .Hostname ()
348+ redirectedHost := redirectedReq .URL .Hostname ()
349+
350+ if shouldCopySafeHeadersOnRedirect (originalHost , redirectedHost ) {
351+
352+ // We're only concerned with "safe" headers since these are the ones that are not
353+ // propagated automatically by net/http for a "cross-site" redirect.
354+ for _ , headerKey := range []string {"Authorization" , "WWW-Authenticate" , "Cookie" , "Cookie2" } {
355+ // If the original request contains a value for "headerKey"
356+ // *and* this header is not already present in the redirected request,
357+ // then copy the value from the original request to the redirected request.
358+ if v , inOriginalRequest := originalHeader [headerKey ]; inOriginalRequest {
359+ if _ , inRedirectedRequest := redirectedHeader [headerKey ]; ! inRedirectedRequest {
360+ redirectedHeader [headerKey ] = v
361+ GetLogger ().Debug ("Propagating header '%s' in redirected request" , headerKey )
362+ }
363+ }
364+ }
365+ } else {
366+ GetLogger ().Debug ("Redirected request is not within the trusted domain." )
367+ }
368+ } else {
369+ GetLogger ().Debug ("Detected no prior requests!" )
370+ }
371+ return nil
372+ }
373+
374+ // shouldCopySafeHeadersOnRedirect returns true iff safe headers should be copied
375+ // to a redirected request.
376+ func shouldCopySafeHeadersOnRedirect (fromHost , toHost string ) bool {
377+ GetLogger ().Debug ("hosts: %s %s" , fromHost , toHost )
378+ return strings .HasSuffix (fromHost , ".cloud.ibm.com" ) && strings .HasSuffix (toHost , ".cloud.ibm.com" )
310379}
311380
312381// SetEnableGzipCompression sets the service's EnableGzipCompression field
@@ -693,7 +762,7 @@ func (service *BaseService) DisableRetries() {
693762// DefaultHTTPClient returns a non-retryable http client with default configuration.
694763func DefaultHTTPClient () * http.Client {
695764 client := cleanhttp .DefaultPooledClient ()
696- setMinimumTLSVersion (client )
765+ setupHTTPClient (client )
697766 return client
698767}
699768
@@ -731,7 +800,7 @@ func NewRetryableClientWithHTTPClient(httpClient *http.Client) *retryablehttp.Cl
731800 // as our embedded client used to invoke individual requests.
732801 client .HTTPClient = httpClient
733802 } else {
734- // Otherwise, we'll use construct a default HTTP client and use that
803+ // Otherwise, we'll construct a default HTTP client and use that
735804 client .HTTPClient = DefaultHTTPClient ()
736805 }
737806
0 commit comments