Skip to content

Commit 9f722b0

Browse files
committed
Cloudfetch: Allow configuration of httpclient for cloudfetch
1 parent 0d8b25b commit 9f722b0

File tree

5 files changed

+152
-3
lines changed

5 files changed

+152
-3
lines changed

connector.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,13 @@ func WithCloudFetch(useCloudFetch bool) ConnOption {
269269
}
270270
}
271271

272+
// WithCloudFetchHTTPClient sets the *http.Client used to download cloud fetch result files.
// It stores httpClient in the connector's CloudFetchConfig.HTTPClient. When this option is
// omitted, or httpClient is nil, cloud fetch downloads fall back to http.DefaultClient.
273+
func WithCloudFetchHTTPClient(httpClient *http.Client) ConnOption {
274+
return func(c *config.Config) {
275+
c.UserConfig.CloudFetchConfig.HTTPClient = httpClient
276+
}
277+
}
278+
272279
// WithMaxDownloadThreads sets up maximum download threads for cloud fetch. Default is 10.
273280
func WithMaxDownloadThreads(numThreads int) ConnOption {
274281
return func(c *config.Config) {

connector_test.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,42 @@ func TestNewConnector(t *testing.T) {
246246
require.True(t, ok)
247247
assert.False(t, coni.cfg.EnableMetricViewMetadata)
248248
})
249+
250+
t.Run("Connector test WithCloudFetchHTTPClient sets custom client", func(t *testing.T) {
251+
host := "databricks-host"
252+
accessToken := "token"
253+
httpPath := "http-path"
254+
customClient := &http.Client{Timeout: 5 * time.Second}
255+
256+
con, err := NewConnector(
257+
WithServerHostname(host),
258+
WithAccessToken(accessToken),
259+
WithHTTPPath(httpPath),
260+
WithCloudFetchHTTPClient(customClient),
261+
)
262+
assert.Nil(t, err)
263+
264+
coni, ok := con.(*connector)
265+
require.True(t, ok)
266+
assert.Equal(t, customClient, coni.cfg.UserConfig.CloudFetchConfig.HTTPClient)
267+
})
268+
269+
t.Run("Connector test WithCloudFetchHTTPClient with nil client is accepted", func(t *testing.T) {
270+
host := "databricks-host"
271+
accessToken := "token"
272+
httpPath := "http-path"
273+
274+
con, err := NewConnector(
275+
WithServerHostname(host),
276+
WithAccessToken(accessToken),
277+
WithHTTPPath(httpPath),
278+
)
279+
assert.Nil(t, err)
280+
281+
coni, ok := con.(*connector)
282+
require.True(t, ok)
283+
assert.Nil(t, coni.cfg.UserConfig.CloudFetchConfig.HTTPClient)
284+
})
249285
}
250286

251287
type mockRoundTripper struct{}

internal/config/config.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,7 @@ type CloudFetchConfig struct {
479479
MaxFilesInMemory int
480480
MinTimeToExpiry time.Duration
481481
CloudFetchSpeedThresholdMbps float64 // Minimum download speed in MBps before WARN logging (default: 0.1)
482+
HTTPClient *http.Client
482483
}
483484

484485
func (cfg CloudFetchConfig) WithDefaults() CloudFetchConfig {

internal/rows/arrowbased/batchloader.go

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ func NewCloudIPCStreamIterator(
4040
startRowOffset: startRowOffset,
4141
pendingLinks: NewQueue[cli_service.TSparkArrowResultLink](),
4242
downloadTasks: NewQueue[cloudFetchDownloadTask](),
43+
httpClient: cfg.UserConfig.CloudFetchConfig.HTTPClient,
4344
}
4445

4546
for _, link := range files {
@@ -140,6 +141,7 @@ type cloudIPCStreamIterator struct {
140141
startRowOffset int64
141142
pendingLinks Queue[cli_service.TSparkArrowResultLink]
142143
downloadTasks Queue[cloudFetchDownloadTask]
144+
httpClient *http.Client
143145
}
144146

145147
var _ IPCStreamIterator = (*cloudIPCStreamIterator)(nil)
@@ -162,6 +164,7 @@ func (bi *cloudIPCStreamIterator) Next() (io.Reader, error) {
162164
resultChan: make(chan cloudFetchDownloadTaskResult),
163165
minTimeToExpiry: bi.cfg.MinTimeToExpiry,
164166
speedThresholdMbps: bi.cfg.CloudFetchSpeedThresholdMbps,
167+
httpClient: bi.httpClient,
165168
}
166169
task.Run()
167170
bi.downloadTasks.Enqueue(task)
@@ -210,6 +213,7 @@ type cloudFetchDownloadTask struct {
210213
link *cli_service.TSparkArrowResultLink
211214
resultChan chan cloudFetchDownloadTaskResult
212215
speedThresholdMbps float64
216+
httpClient *http.Client
213217
}
214218

215219
func (cft *cloudFetchDownloadTask) GetResult() (io.Reader, error) {
@@ -252,7 +256,7 @@ func (cft *cloudFetchDownloadTask) Run() {
252256
cft.link.StartRowOffset,
253257
cft.link.RowCount,
254258
)
255-
data, err := fetchBatchBytes(cft.ctx, cft.link, cft.minTimeToExpiry, cft.speedThresholdMbps)
259+
data, err := fetchBatchBytes(cft.ctx, cft.link, cft.minTimeToExpiry, cft.speedThresholdMbps, cft.httpClient)
256260
if err != nil {
257261
cft.resultChan <- cloudFetchDownloadTaskResult{data: nil, err: err}
258262
return
@@ -300,6 +304,7 @@ func fetchBatchBytes(
300304
link *cli_service.TSparkArrowResultLink,
301305
minTimeToExpiry time.Duration,
302306
speedThresholdMbps float64,
307+
httpClient *http.Client,
303308
) (io.ReadCloser, error) {
304309
if isLinkExpired(link.ExpiryTime, minTimeToExpiry) {
305310
return nil, errors.New(dbsqlerr.ErrLinkExpired)
@@ -317,9 +322,12 @@ func fetchBatchBytes(
317322
}
318323
}
319324

325+
if httpClient == nil {
326+
httpClient = http.DefaultClient
327+
}
328+
320329
startTime := time.Now()
321-
client := http.DefaultClient
322-
res, err := client.Do(req)
330+
res, err := httpClient.Do(req)
323331
if err != nil {
324332
return nil, err
325333
}

internal/rows/arrowbased/batchloader_test.go

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,103 @@ func TestCloudFetchIterator(t *testing.T) {
253253
assert.NotNil(t, err3)
254254
assert.ErrorContains(t, err3, fmt.Sprintf("%s %d", "HTTP error", http.StatusNotFound))
255255
})
256+
257+
t.Run("should use custom HTTPClient when provided", func(t *testing.T) {
258+
customClient := &http.Client{Timeout: 5 * time.Second}
259+
requestCount := 0
260+
261+
handler = func(w http.ResponseWriter, r *http.Request) {
262+
requestCount++
263+
w.WriteHeader(http.StatusOK)
264+
_, err := w.Write(generateMockArrowBytes(generateArrowRecord()))
265+
if err != nil {
266+
panic(err)
267+
}
268+
}
269+
270+
startRowOffset := int64(100)
271+
272+
links := []*cli_service.TSparkArrowResultLink{
273+
{
274+
FileLink: server.URL,
275+
ExpiryTime: time.Now().Add(10 * time.Minute).Unix(),
276+
StartRowOffset: startRowOffset,
277+
RowCount: 1,
278+
},
279+
}
280+
281+
cfg := config.WithDefaults()
282+
cfg.UseLz4Compression = false
283+
cfg.MaxDownloadThreads = 1
284+
cfg.UserConfig.CloudFetchConfig.HTTPClient = customClient
285+
286+
bi, err := NewCloudBatchIterator(
287+
context.Background(),
288+
links,
289+
startRowOffset,
290+
cfg,
291+
)
292+
assert.Nil(t, err)
293+
294+
// Verify custom client is passed through the iterator chain
295+
wrapper, ok := bi.(*batchIterator)
296+
assert.True(t, ok)
297+
cbi, ok := wrapper.ipcIterator.(*cloudIPCStreamIterator)
298+
assert.True(t, ok)
299+
assert.Equal(t, customClient, cbi.httpClient)
300+
301+
// Fetch should work with custom client
302+
sab1, nextErr := bi.Next()
303+
assert.Nil(t, nextErr)
304+
assert.NotNil(t, sab1)
305+
assert.Greater(t, requestCount, 0) // Verify request was made
306+
})
307+
308+
t.Run("should use http.DefaultClient when HTTPClient is nil", func(t *testing.T) {
309+
handler = func(w http.ResponseWriter, r *http.Request) {
310+
w.WriteHeader(http.StatusOK)
311+
_, err := w.Write(generateMockArrowBytes(generateArrowRecord()))
312+
if err != nil {
313+
panic(err)
314+
}
315+
}
316+
317+
startRowOffset := int64(100)
318+
319+
links := []*cli_service.TSparkArrowResultLink{
320+
{
321+
FileLink: server.URL,
322+
ExpiryTime: time.Now().Add(10 * time.Minute).Unix(),
323+
StartRowOffset: startRowOffset,
324+
RowCount: 1,
325+
},
326+
}
327+
328+
cfg := config.WithDefaults()
329+
cfg.UseLz4Compression = false
330+
cfg.MaxDownloadThreads = 1
331+
// HTTPClient is nil by default
332+
333+
bi, err := NewCloudBatchIterator(
334+
context.Background(),
335+
links,
336+
startRowOffset,
337+
cfg,
338+
)
339+
assert.Nil(t, err)
340+
341+
// Verify nil client is passed through
342+
wrapper, ok := bi.(*batchIterator)
343+
assert.True(t, ok)
344+
cbi, ok := wrapper.ipcIterator.(*cloudIPCStreamIterator)
345+
assert.True(t, ok)
346+
assert.Nil(t, cbi.httpClient)
347+
348+
// Fetch should work (falls back to http.DefaultClient)
349+
sab1, nextErr := bi.Next()
350+
assert.Nil(t, nextErr)
351+
assert.NotNil(t, sab1)
352+
})
256353
}
257354

258355
func generateArrowRecord() arrow.Record {

0 commit comments

Comments
 (0)