Skip to content

Commit 5ca9b51

Browse files
committed
Swap implementation to use existing transport setting
1 parent 35320c6 commit 5ca9b51

File tree

5 files changed

+25
-62
lines changed

5 files changed

+25
-62
lines changed

connector.go

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -269,13 +269,6 @@ func WithCloudFetch(useCloudFetch bool) ConnOption {
269269
}
270270
}
271271

272-
// WithHTTPClient allows a custom http client to be used for cloud fetch. Default is http.DefaultClient.
273-
func WithHTTPClient(httpClient *http.Client) ConnOption {
274-
return func(c *config.Config) {
275-
c.UserConfig.CloudFetchConfig.HTTPClient = httpClient
276-
}
277-
}
278-
279272
// WithMaxDownloadThreads sets up maximum download threads for cloud fetch. Default is 10.
280273
func WithMaxDownloadThreads(numThreads int) ConnOption {
281274
return func(c *config.Config) {

connector_test.go

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -246,42 +246,6 @@ func TestNewConnector(t *testing.T) {
246246
require.True(t, ok)
247247
assert.False(t, coni.cfg.EnableMetricViewMetadata)
248248
})
249-
250-
t.Run("Connector test WithCloudFetchHTTPClient sets custom client", func(t *testing.T) {
251-
host := "databricks-host"
252-
accessToken := "token"
253-
httpPath := "http-path"
254-
customClient := &http.Client{Timeout: 5 * time.Second}
255-
256-
con, err := NewConnector(
257-
WithServerHostname(host),
258-
WithAccessToken(accessToken),
259-
WithHTTPPath(httpPath),
260-
WithHTTPClient(customClient),
261-
)
262-
assert.Nil(t, err)
263-
264-
coni, ok := con.(*connector)
265-
require.True(t, ok)
266-
assert.Equal(t, customClient, coni.cfg.UserConfig.CloudFetchConfig.HTTPClient)
267-
})
268-
269-
t.Run("Connector test WithCloudFetchHTTPClient with nil client is accepted", func(t *testing.T) {
270-
host := "databricks-host"
271-
accessToken := "token"
272-
httpPath := "http-path"
273-
274-
con, err := NewConnector(
275-
WithServerHostname(host),
276-
WithAccessToken(accessToken),
277-
WithHTTPPath(httpPath),
278-
)
279-
assert.Nil(t, err)
280-
281-
coni, ok := con.(*connector)
282-
require.True(t, ok)
283-
assert.Nil(t, coni.cfg.UserConfig.CloudFetchConfig.HTTPClient)
284-
})
285249
}
286250

287251
type mockRoundTripper struct{}

internal/config/config.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,7 @@ type CloudFetchConfig struct {
479479
MaxFilesInMemory int
480480
MinTimeToExpiry time.Duration
481481
CloudFetchSpeedThresholdMbps float64 // Minimum download speed in MBps before WARN logging (default: 0.1)
482-
HTTPClient *http.Client
482+
Transport http.RoundTripper
483483
}
484484

485485
func (cfg CloudFetchConfig) WithDefaults() CloudFetchConfig {

internal/rows/arrowbased/batchloader.go

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ func NewCloudIPCStreamIterator(
4040
startRowOffset: startRowOffset,
4141
pendingLinks: NewQueue[cli_service.TSparkArrowResultLink](),
4242
downloadTasks: NewQueue[cloudFetchDownloadTask](),
43-
httpClient: cfg.UserConfig.CloudFetchConfig.HTTPClient,
43+
transport: cfg.UserConfig.Transport,
4444
}
4545

4646
for _, link := range files {
@@ -141,7 +141,7 @@ type cloudIPCStreamIterator struct {
141141
startRowOffset int64
142142
pendingLinks Queue[cli_service.TSparkArrowResultLink]
143143
downloadTasks Queue[cloudFetchDownloadTask]
144-
httpClient *http.Client
144+
transport http.RoundTripper
145145
}
146146

147147
var _ IPCStreamIterator = (*cloudIPCStreamIterator)(nil)
@@ -164,7 +164,7 @@ func (bi *cloudIPCStreamIterator) Next() (io.Reader, error) {
164164
resultChan: make(chan cloudFetchDownloadTaskResult),
165165
minTimeToExpiry: bi.cfg.MinTimeToExpiry,
166166
speedThresholdMbps: bi.cfg.CloudFetchSpeedThresholdMbps,
167-
httpClient: bi.httpClient,
167+
transport: bi.transport,
168168
}
169169
task.Run()
170170
bi.downloadTasks.Enqueue(task)
@@ -213,7 +213,7 @@ type cloudFetchDownloadTask struct {
213213
link *cli_service.TSparkArrowResultLink
214214
resultChan chan cloudFetchDownloadTaskResult
215215
speedThresholdMbps float64
216-
httpClient *http.Client
216+
transport http.RoundTripper
217217
}
218218

219219
func (cft *cloudFetchDownloadTask) GetResult() (io.Reader, error) {
@@ -256,7 +256,7 @@ func (cft *cloudFetchDownloadTask) Run() {
256256
cft.link.StartRowOffset,
257257
cft.link.RowCount,
258258
)
259-
data, err := fetchBatchBytes(cft.ctx, cft.link, cft.minTimeToExpiry, cft.speedThresholdMbps, cft.httpClient)
259+
data, err := fetchBatchBytes(cft.ctx, cft.link, cft.minTimeToExpiry, cft.speedThresholdMbps, cft.transport)
260260
if err != nil {
261261
cft.resultChan <- cloudFetchDownloadTaskResult{data: nil, err: err}
262262
return
@@ -304,7 +304,7 @@ func fetchBatchBytes(
304304
link *cli_service.TSparkArrowResultLink,
305305
minTimeToExpiry time.Duration,
306306
speedThresholdMbps float64,
307-
httpClient *http.Client,
307+
transport http.RoundTripper,
308308
) (io.ReadCloser, error) {
309309
if isLinkExpired(link.ExpiryTime, minTimeToExpiry) {
310310
return nil, errors.New(dbsqlerr.ErrLinkExpired)
@@ -322,8 +322,11 @@ func fetchBatchBytes(
322322
}
323323
}
324324

325-
if httpClient == nil {
326-
httpClient = http.DefaultClient
325+
httpClient := http.DefaultClient
326+
if transport != nil {
327+
httpClient = &http.Client{
328+
Transport: transport,
329+
}
327330
}
328331

329332
startTime := time.Now()

internal/rows/arrowbased/batchloader_test.go

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -254,8 +254,11 @@ func TestCloudFetchIterator(t *testing.T) {
254254
assert.ErrorContains(t, err3, fmt.Sprintf("%s %d", "HTTP error", http.StatusNotFound))
255255
})
256256

257-
t.Run("should use custom HTTPClient when provided", func(t *testing.T) {
258-
customClient := &http.Client{Timeout: 5 * time.Second}
257+
t.Run("should use custom Transport when provided", func(t *testing.T) {
258+
customTransport := &http.Transport{
259+
MaxIdleConns: 10,
260+
MaxIdleConnsPerHost: 5,
261+
}
259262
requestCount := 0
260263

261264
handler = func(w http.ResponseWriter, r *http.Request) {
@@ -281,7 +284,7 @@ func TestCloudFetchIterator(t *testing.T) {
281284
cfg := config.WithDefaults()
282285
cfg.UseLz4Compression = false
283286
cfg.MaxDownloadThreads = 1
284-
cfg.UserConfig.CloudFetchConfig.HTTPClient = customClient
287+
cfg.UserConfig.Transport = customTransport
285288

286289
bi, err := NewCloudBatchIterator(
287290
context.Background(),
@@ -291,21 +294,21 @@ func TestCloudFetchIterator(t *testing.T) {
291294
)
292295
assert.Nil(t, err)
293296

294-
// Verify custom client is passed through the iterator chain
297+
// Verify custom transport is passed through the iterator chain
295298
wrapper, ok := bi.(*batchIterator)
296299
assert.True(t, ok)
297300
cbi, ok := wrapper.ipcIterator.(*cloudIPCStreamIterator)
298301
assert.True(t, ok)
299-
assert.Equal(t, customClient, cbi.httpClient)
302+
assert.Equal(t, customTransport, cbi.transport)
300303

301-
// Fetch should work with custom client
304+
// Fetch should work with custom transport
302305
sab1, nextErr := bi.Next()
303306
assert.Nil(t, nextErr)
304307
assert.NotNil(t, sab1)
305308
assert.Greater(t, requestCount, 0) // Verify request was made
306309
})
307310

308-
t.Run("should use http.DefaultClient when HTTPClient is nil", func(t *testing.T) {
311+
t.Run("should use http.DefaultClient when Transport is nil", func(t *testing.T) {
309312
handler = func(w http.ResponseWriter, r *http.Request) {
310313
w.WriteHeader(http.StatusOK)
311314
_, err := w.Write(generateMockArrowBytes(generateArrowRecord()))
@@ -328,7 +331,7 @@ func TestCloudFetchIterator(t *testing.T) {
328331
cfg := config.WithDefaults()
329332
cfg.UseLz4Compression = false
330333
cfg.MaxDownloadThreads = 1
331-
// HTTPClient is nil by default
334+
// Transport is nil by default
332335

333336
bi, err := NewCloudBatchIterator(
334337
context.Background(),
@@ -338,12 +341,12 @@ func TestCloudFetchIterator(t *testing.T) {
338341
)
339342
assert.Nil(t, err)
340343

341-
// Verify nil client is passed through
344+
// Verify nil transport is passed through
342345
wrapper, ok := bi.(*batchIterator)
343346
assert.True(t, ok)
344347
cbi, ok := wrapper.ipcIterator.(*cloudIPCStreamIterator)
345348
assert.True(t, ok)
346-
assert.Nil(t, cbi.httpClient)
349+
assert.Nil(t, cbi.transport)
347350

348351
// Fetch should work (falls back to http.DefaultClient)
349352
sab1, nextErr := bi.Next()

0 commit comments

Comments
 (0)