Skip to content

Commit 54aef94

Browse files
Exposes configuration function to be able to set the Transport - Added test
Signed-off-by: Leandro Deveikis <[email protected]>
1 parent e2855eb commit 54aef94

File tree

4 files changed

+36
-4
lines changed

4 files changed

+36
-4
lines changed

connector.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,3 +214,10 @@ func WithAuthenticator(authr auth.Authenticator) connOption {
214214
c.Authenticator = authr
215215
}
216216
}
217+
218+
// WithTransport sets up the transport configuration to be used by the httpclient.
219+
func WithTransport(t http.RoundTripper) connOption {
220+
return func(c *config.Config) {
221+
c.Transport = t
222+
}
223+
}

connector_test.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package dbsql
22

33
import (
4+
"net/http"
45
"testing"
56
"time"
67

@@ -22,6 +23,7 @@ func TestNewConnector(t *testing.T) {
2223
schema := "schema-string"
2324
userAgentEntry := "user-agent"
2425
sessionParams := map[string]string{"key": "value"}
26+
roundTripper := mockRoundTripper{}
2527
con, err := NewConnector(
2628
WithServerHostname(host),
2729
WithPort(port),
@@ -33,6 +35,7 @@ func TestNewConnector(t *testing.T) {
3335
WithUserAgentEntry(userAgentEntry),
3436
WithSessionParams(sessionParams),
3537
WithRetries(10, 3*time.Second, 60*time.Second),
38+
WithTransport(roundTripper),
3639
)
3740
expectedUserConfig := config.UserConfig{
3841
Host: host,
@@ -50,6 +53,7 @@ func TestNewConnector(t *testing.T) {
5053
RetryMax: 10,
5154
RetryWaitMin: 3 * time.Second,
5255
RetryWaitMax: 60 * time.Second,
56+
Transport: roundTripper,
5357
}
5458
expectedCfg := config.WithDefaults()
5559
expectedCfg.DriverVersion = DriverVersion
@@ -127,3 +131,11 @@ func TestNewConnector(t *testing.T) {
127131
assert.Equal(t, expectedCfg, coni.cfg)
128132
})
129133
}
134+
135+
type mockRoundTripper struct{}
136+
137+
var _ http.RoundTripper = mockRoundTripper{}
138+
139+
func (m mockRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
140+
return &http.Response{StatusCode: 200}, nil
141+
}

internal/client/client.go

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ func isRetryableServerResponse(resp *http.Response) bool {
319319
}
320320

321321
type Transport struct {
322-
Base *http.Transport
322+
Base http.RoundTripper
323323
Authr auth.Authenticator
324324
trace bool
325325
}
@@ -396,10 +396,20 @@ func PooledClient(cfg *config.Config) *http.Client {
396396
if cfg.Authenticator == nil {
397397
return nil
398398
}
399-
tr := &Transport{
400-
Base: PooledTransport(),
401-
Authr: cfg.Authenticator,
399+
400+
var tr *Transport
401+
if cfg.Transport != nil {
402+
tr = &Transport{
403+
Base: cfg.Transport,
404+
Authr: cfg.Authenticator,
405+
}
406+
} else {
407+
tr = &Transport{
408+
Base: PooledTransport(),
409+
Authr: cfg.Authenticator,
410+
}
402411
}
412+
403413
return &http.Client{
404414
Transport: tr,
405415
Timeout: cfg.ClientTimeout,

internal/config/config.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"crypto/tls"
66
"fmt"
7+
"net/http"
78
"net/url"
89
"strconv"
910
"strings"
@@ -97,6 +98,7 @@ type UserConfig struct {
9798
RetryWaitMin time.Duration
9899
RetryWaitMax time.Duration
99100
RetryMax int
101+
Transport http.RoundTripper
100102
}
101103

102104
// DeepCopy returns a true deep copy of UserConfig
@@ -135,6 +137,7 @@ func (ucfg UserConfig) DeepCopy() UserConfig {
135137
RetryWaitMin: ucfg.RetryWaitMin,
136138
RetryWaitMax: ucfg.RetryWaitMax,
137139
RetryMax: ucfg.RetryMax,
140+
Transport: ucfg.Transport,
138141
}
139142
}
140143

0 commit comments

Comments
 (0)