@@ -401,11 +401,24 @@ type ErrorHandler func(resp *http.Response, err error, numTries int) (*http.Resp
401401// PrepareRetry is called before retry operation. It can be used for example to re-sign the request
402402type PrepareRetry func (req * http.Request ) error
403403
404+ type HTTPClient interface {
405+ // Do performs an HTTP request and returns an HTTP response.
406+ Do (* http.Request ) (* http.Response , error )
407+ // Done is called when the client is no longer needed.
408+ Done ()
409+ }
410+
411+ type HTTPClientFactory interface {
412+ // New returns an HTTP client to use for a request, including retries.
413+ New () HTTPClient
414+ }
415+
404416// Client is used to make HTTP requests. It adds additional functionality
405417// like automatic retries to tolerate minor outages.
406418type Client struct {
407- HTTPClient * http.Client // Internal HTTP client.
408- Logger interface {} // Customer logger instance. Can be either Logger or LeveledLogger
419+ HTTPClient * http.Client // Internal HTTP client. This field is used if set, otherwise HTTPClientFactory is used.
420+ HTTPClientFactory HTTPClientFactory
421+ Logger interface {} // Customer logger instance. Can be either Logger or LeveledLogger
409422
410423 RetryWaitMin time.Duration // Minimum time to wait
411424 RetryWaitMax time.Duration // Maximum time to wait
@@ -433,19 +446,18 @@ type Client struct {
433446 PrepareRetry PrepareRetry
434447
435448 loggerInit sync.Once
436- clientInit sync.Once
437449}
438450
439451// NewClient creates a new Client with default settings.
440452func NewClient () * Client {
441453 return & Client {
442- HTTPClient : cleanhttp . DefaultPooledClient () ,
443- Logger : defaultLogger ,
444- RetryWaitMin : defaultRetryWaitMin ,
445- RetryWaitMax : defaultRetryWaitMax ,
446- RetryMax : defaultRetryMax ,
447- CheckRetry : DefaultRetryPolicy ,
448- Backoff : DefaultBackoff ,
454+ HTTPClientFactory : & CleanPooledClientFactory {} ,
455+ Logger : defaultLogger ,
456+ RetryWaitMin : defaultRetryWaitMin ,
457+ RetryWaitMax : defaultRetryWaitMax ,
458+ RetryMax : defaultRetryMax ,
459+ CheckRetry : DefaultRetryPolicy ,
460+ Backoff : DefaultBackoff ,
449461 }
450462}
451463
@@ -647,12 +659,6 @@ func PassthroughErrorHandler(resp *http.Response, err error, _ int) (*http.Respo
647659
648660// Do wraps calling an HTTP method with retries.
649661func (c * Client ) Do (req * Request ) (* http.Response , error ) {
650- c .clientInit .Do (func () {
651- if c .HTTPClient == nil {
652- c .HTTPClient = cleanhttp .DefaultPooledClient ()
653- }
654- })
655-
656662 logger := c .logger ()
657663
658664 if logger != nil {
@@ -664,6 +670,9 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
664670 }
665671 }
666672
673+ httpClient := c .getHTTPClient ()
674+ defer httpClient .Done ()
675+
667676 var resp * http.Response
668677 var attempt int
669678 var shouldRetry bool
@@ -677,7 +686,6 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
677686 if req .body != nil {
678687 body , err := req .body ()
679688 if err != nil {
680- c .HTTPClient .CloseIdleConnections ()
681689 return resp , err
682690 }
683691 if c , ok := body .(io.ReadCloser ); ok {
@@ -699,7 +707,8 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
699707 }
700708
701709 // Attempt the request
702- resp , doErr = c .HTTPClient .Do (req .Request )
710+
711+ resp , doErr = httpClient .Do (req .Request )
703712
704713 // Check if we should continue with retries.
705714 shouldRetry , checkErr = c .CheckRetry (req .Context (), resp , doErr )
@@ -768,7 +777,6 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
768777 select {
769778 case <- req .Context ().Done ():
770779 timer .Stop ()
771- c .HTTPClient .CloseIdleConnections ()
772780 return nil , req .Context ().Err ()
773781 case <- timer .C :
774782 }
@@ -791,8 +799,6 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
791799 return resp , nil
792800 }
793801
794- defer c .HTTPClient .CloseIdleConnections ()
795-
796802 var err error
797803 if prepareErr != nil {
798804 err = prepareErr
@@ -841,6 +847,19 @@ func (c *Client) drainBody(body io.ReadCloser) {
841847 }
842848}
843849
850+ func (c * Client ) getHTTPClient () HTTPClient {
851+ if c .HTTPClient != nil {
852+ return & idleConnectionsClosingClient {
853+ httpClient : c .HTTPClient ,
854+ }
855+ }
856+ clientFactory := c .HTTPClientFactory
857+ if clientFactory == nil {
858+ clientFactory = & CleanPooledClientFactory {}
859+ }
860+ return clientFactory .New ()
861+ }
862+
844863// Get is a shortcut for doing a GET request without making a new client.
845864func Get (url string ) (* http.Response , error ) {
846865 return defaultClient .Get (url )
@@ -917,3 +936,29 @@ func redactURL(u *url.URL) string {
917936 }
918937 return ru .String ()
919938}
939+
940+ var (
941+ _ HTTPClientFactory = & CleanPooledClientFactory {}
942+ _ HTTPClient = & idleConnectionsClosingClient {}
943+ )
944+
945+ type CleanPooledClientFactory struct {
946+ }
947+
948+ func (f * CleanPooledClientFactory ) New () HTTPClient {
949+ return & idleConnectionsClosingClient {
950+ httpClient : cleanhttp .DefaultPooledClient (),
951+ }
952+ }
953+
954+ type idleConnectionsClosingClient struct {
955+ httpClient * http.Client
956+ }
957+
958+ func (c * idleConnectionsClosingClient ) Do (req * http.Request ) (* http.Response , error ) {
959+ return c .httpClient .Do (req )
960+ }
961+
962+ func (c * idleConnectionsClosingClient ) Done () {
963+ c .httpClient .CloseIdleConnections ()
964+ }
0 commit comments