Skip to content

Commit c4af6fa

Browse files
Cloudfetch: Allow configuration of httpclient for cloudfetch (#308)
[issues](#307) Hello, as per issue looking for ways to modify the transport layer of the httpclient that cloudfetch uses. Happy to go another way to solving this, this just seamed like the simplest. Thanks for your work on the driver, its been very useful 👍 --------- Authored-by: Tim Mulqueen <[email protected]> Co-authored-by: Samikshya Chand <[email protected]>
1 parent 29c881a commit c4af6fa

File tree

5 files changed

+118
-3
lines changed

5 files changed

+118
-3
lines changed

connector.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,12 @@ func WithAuthenticator(authr auth.Authenticator) ConnOption {
260260
func WithTransport(t http.RoundTripper) ConnOption {
261261
return func(c *config.Config) {
262262
c.Transport = t
263+
264+
if c.CloudFetchConfig.HTTPClient == nil {
265+
c.CloudFetchConfig.HTTPClient = &http.Client{
266+
Transport: t,
267+
}
268+
}
263269
}
264270
}
265271

connector_test.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ func TestNewConnector(t *testing.T) {
4848
MaxFilesInMemory: 10,
4949
MinTimeToExpiry: 0 * time.Second,
5050
CloudFetchSpeedThresholdMbps: 0.1,
51+
HTTPClient: &http.Client{Transport: roundTripper},
5152
}
5253
expectedUserConfig := config.UserConfig{
5354
Host: host,
@@ -246,6 +247,25 @@ func TestNewConnector(t *testing.T) {
246247
require.True(t, ok)
247248
assert.False(t, coni.cfg.EnableMetricViewMetadata)
248249
})
250+
251+
t.Run("Connector test WithTransport sets HTTPClient in CloudFetchConfig", func(t *testing.T) {
252+
host := "databricks-host"
253+
accessToken := "token"
254+
httpPath := "http-path"
255+
customTransport := &http.Transport{MaxIdleConns: 10}
256+
con, err := NewConnector(
257+
WithServerHostname(host),
258+
WithAccessToken(accessToken),
259+
WithHTTPPath(httpPath),
260+
WithTransport(customTransport),
261+
)
262+
assert.Nil(t, err)
263+
264+
coni, ok := con.(*connector)
265+
require.True(t, ok)
266+
assert.NotNil(t, coni.cfg.CloudFetchConfig.HTTPClient)
267+
assert.Equal(t, customTransport, coni.cfg.CloudFetchConfig.HTTPClient.Transport)
268+
})
249269
}
250270

251271
type mockRoundTripper struct{}

internal/config/config.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,7 @@ type CloudFetchConfig struct {
479479
MaxFilesInMemory int
480480
MinTimeToExpiry time.Duration
481481
CloudFetchSpeedThresholdMbps float64 // Minimum download speed in MBps before WARN logging (default: 0.1)
482+
HTTPClient *http.Client
482483
}
483484

484485
func (cfg CloudFetchConfig) WithDefaults() CloudFetchConfig {

internal/rows/arrowbased/batchloader.go

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,18 @@ func NewCloudIPCStreamIterator(
3434
startRowOffset int64,
3535
cfg *config.Config,
3636
) (IPCStreamIterator, dbsqlerr.DBError) {
37+
httpClient := http.DefaultClient
38+
if cfg.UserConfig.CloudFetchConfig.HTTPClient != nil {
39+
httpClient = cfg.UserConfig.CloudFetchConfig.HTTPClient
40+
}
41+
3742
bi := &cloudIPCStreamIterator{
3843
ctx: ctx,
3944
cfg: cfg,
4045
startRowOffset: startRowOffset,
4146
pendingLinks: NewQueue[cli_service.TSparkArrowResultLink](),
4247
downloadTasks: NewQueue[cloudFetchDownloadTask](),
48+
httpClient: httpClient,
4349
}
4450

4551
for _, link := range files {
@@ -140,6 +146,7 @@ type cloudIPCStreamIterator struct {
140146
startRowOffset int64
141147
pendingLinks Queue[cli_service.TSparkArrowResultLink]
142148
downloadTasks Queue[cloudFetchDownloadTask]
149+
httpClient *http.Client
143150
}
144151

145152
var _ IPCStreamIterator = (*cloudIPCStreamIterator)(nil)
@@ -162,6 +169,7 @@ func (bi *cloudIPCStreamIterator) Next() (io.Reader, error) {
162169
resultChan: make(chan cloudFetchDownloadTaskResult),
163170
minTimeToExpiry: bi.cfg.MinTimeToExpiry,
164171
speedThresholdMbps: bi.cfg.CloudFetchSpeedThresholdMbps,
172+
httpClient: bi.httpClient,
165173
}
166174
task.Run()
167175
bi.downloadTasks.Enqueue(task)
@@ -210,6 +218,7 @@ type cloudFetchDownloadTask struct {
210218
link *cli_service.TSparkArrowResultLink
211219
resultChan chan cloudFetchDownloadTaskResult
212220
speedThresholdMbps float64
221+
httpClient *http.Client
213222
}
214223

215224
func (cft *cloudFetchDownloadTask) GetResult() (io.Reader, error) {
@@ -252,7 +261,7 @@ func (cft *cloudFetchDownloadTask) Run() {
252261
cft.link.StartRowOffset,
253262
cft.link.RowCount,
254263
)
255-
data, err := fetchBatchBytes(cft.ctx, cft.link, cft.minTimeToExpiry, cft.speedThresholdMbps)
264+
data, err := fetchBatchBytes(cft.ctx, cft.link, cft.minTimeToExpiry, cft.speedThresholdMbps, cft.httpClient)
256265
if err != nil {
257266
cft.resultChan <- cloudFetchDownloadTaskResult{data: nil, err: err}
258267
return
@@ -300,6 +309,7 @@ func fetchBatchBytes(
300309
link *cli_service.TSparkArrowResultLink,
301310
minTimeToExpiry time.Duration,
302311
speedThresholdMbps float64,
312+
httpClient *http.Client,
303313
) (io.ReadCloser, error) {
304314
if isLinkExpired(link.ExpiryTime, minTimeToExpiry) {
305315
return nil, errors.New(dbsqlerr.ErrLinkExpired)
@@ -318,8 +328,7 @@ func fetchBatchBytes(
318328
}
319329

320330
startTime := time.Now()
321-
client := http.DefaultClient
322-
res, err := client.Do(req)
331+
res, err := httpClient.Do(req)
323332
if err != nil {
324333
return nil, err
325334
}

internal/rows/arrowbased/batchloader_test.go

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,85 @@ func TestCloudFetchIterator(t *testing.T) {
253253
assert.NotNil(t, err3)
254254
assert.ErrorContains(t, err3, fmt.Sprintf("%s %d", "HTTP error", http.StatusNotFound))
255255
})
256+
257+
t.Run("should use custom HTTPClient when provided", func(t *testing.T) {
258+
handler = func(w http.ResponseWriter, r *http.Request) {
259+
w.WriteHeader(http.StatusOK)
260+
_, err := w.Write(generateMockArrowBytes(generateArrowRecord()))
261+
if err != nil {
262+
panic(err)
263+
}
264+
}
265+
266+
startRowOffset := int64(100)
267+
customHTTPClient := &http.Client{
268+
Transport: &http.Transport{MaxIdleConns: 10},
269+
}
270+
271+
cfg := config.WithDefaults()
272+
cfg.UseLz4Compression = false
273+
cfg.MaxDownloadThreads = 1
274+
cfg.UserConfig.CloudFetchConfig.HTTPClient = customHTTPClient
275+
276+
bi, err := NewCloudBatchIterator(
277+
context.Background(),
278+
[]*cli_service.TSparkArrowResultLink{{
279+
FileLink: server.URL,
280+
ExpiryTime: time.Now().Add(10 * time.Minute).Unix(),
281+
StartRowOffset: startRowOffset,
282+
RowCount: 1,
283+
}},
284+
startRowOffset,
285+
cfg,
286+
)
287+
assert.Nil(t, err)
288+
289+
cbi := bi.(*batchIterator).ipcIterator.(*cloudIPCStreamIterator)
290+
assert.Equal(t, customHTTPClient, cbi.httpClient)
291+
292+
// Verify fetch works
293+
sab, nextErr := bi.Next()
294+
assert.Nil(t, nextErr)
295+
assert.NotNil(t, sab)
296+
})
297+
298+
t.Run("should fallback to http.DefaultClient when HTTPClient is nil", func(t *testing.T) {
299+
handler = func(w http.ResponseWriter, r *http.Request) {
300+
w.WriteHeader(http.StatusOK)
301+
_, err := w.Write(generateMockArrowBytes(generateArrowRecord()))
302+
if err != nil {
303+
panic(err)
304+
}
305+
}
306+
307+
startRowOffset := int64(100)
308+
cfg := config.WithDefaults()
309+
cfg.UseLz4Compression = false
310+
cfg.MaxDownloadThreads = 1
311+
// Explicitly set HTTPClient to nil to verify fallback behavior
312+
cfg.UserConfig.CloudFetchConfig.HTTPClient = nil
313+
314+
bi, err := NewCloudBatchIterator(
315+
context.Background(),
316+
[]*cli_service.TSparkArrowResultLink{{
317+
FileLink: server.URL,
318+
ExpiryTime: time.Now().Add(10 * time.Minute).Unix(),
319+
StartRowOffset: startRowOffset,
320+
RowCount: 1,
321+
}},
322+
startRowOffset,
323+
cfg,
324+
)
325+
assert.Nil(t, err)
326+
327+
cbi := bi.(*batchIterator).ipcIterator.(*cloudIPCStreamIterator)
328+
assert.Equal(t, http.DefaultClient, cbi.httpClient)
329+
330+
// Verify fetch works with default client
331+
sab, nextErr := bi.Next()
332+
assert.Nil(t, nextErr)
333+
assert.NotNil(t, sab)
334+
})
256335
}
257336

258337
func generateArrowRecord() arrow.Record {

0 commit comments

Comments
 (0)