Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,7 @@ type CloudFetchConfig struct {
MaxFilesInMemory int
MinTimeToExpiry time.Duration
CloudFetchSpeedThresholdMbps float64 // Minimum download speed in MBps before WARN logging (default: 0.1)
Transport http.RoundTripper
}

func (cfg CloudFetchConfig) WithDefaults() CloudFetchConfig {
Expand Down
17 changes: 14 additions & 3 deletions internal/rows/arrowbased/batchloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ func NewCloudIPCStreamIterator(
startRowOffset: startRowOffset,
pendingLinks: NewQueue[cli_service.TSparkArrowResultLink](),
downloadTasks: NewQueue[cloudFetchDownloadTask](),
transport: cfg.UserConfig.Transport,
}

for _, link := range files {
Expand Down Expand Up @@ -140,6 +141,7 @@ type cloudIPCStreamIterator struct {
startRowOffset int64
pendingLinks Queue[cli_service.TSparkArrowResultLink]
downloadTasks Queue[cloudFetchDownloadTask]
transport http.RoundTripper
}

var _ IPCStreamIterator = (*cloudIPCStreamIterator)(nil)
Expand All @@ -162,6 +164,7 @@ func (bi *cloudIPCStreamIterator) Next() (io.Reader, error) {
resultChan: make(chan cloudFetchDownloadTaskResult),
minTimeToExpiry: bi.cfg.MinTimeToExpiry,
speedThresholdMbps: bi.cfg.CloudFetchSpeedThresholdMbps,
transport: bi.transport,
}
task.Run()
bi.downloadTasks.Enqueue(task)
Expand Down Expand Up @@ -210,6 +213,7 @@ type cloudFetchDownloadTask struct {
link *cli_service.TSparkArrowResultLink
resultChan chan cloudFetchDownloadTaskResult
speedThresholdMbps float64
transport http.RoundTripper
}

func (cft *cloudFetchDownloadTask) GetResult() (io.Reader, error) {
Expand Down Expand Up @@ -252,7 +256,7 @@ func (cft *cloudFetchDownloadTask) Run() {
cft.link.StartRowOffset,
cft.link.RowCount,
)
data, err := fetchBatchBytes(cft.ctx, cft.link, cft.minTimeToExpiry, cft.speedThresholdMbps)
data, err := fetchBatchBytes(cft.ctx, cft.link, cft.minTimeToExpiry, cft.speedThresholdMbps, cft.transport)
if err != nil {
cft.resultChan <- cloudFetchDownloadTaskResult{data: nil, err: err}
return
Expand Down Expand Up @@ -300,6 +304,7 @@ func fetchBatchBytes(
link *cli_service.TSparkArrowResultLink,
minTimeToExpiry time.Duration,
speedThresholdMbps float64,
transport http.RoundTripper,
) (io.ReadCloser, error) {
if isLinkExpired(link.ExpiryTime, minTimeToExpiry) {
return nil, errors.New(dbsqlerr.ErrLinkExpired)
Expand All @@ -317,9 +322,15 @@ func fetchBatchBytes(
}
}

httpClient := http.DefaultClient
if transport != nil {
httpClient = &http.Client{
Transport: transport,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this mean a new `http.Client` is created on every fetch request when a custom transport is provided? If so, can we make this more efficient, for example by constructing the client once and reusing it across fetches?

}
}

startTime := time.Now()
client := http.DefaultClient
res, err := client.Do(req)
res, err := httpClient.Do(req)
if err != nil {
return nil, err
}
Expand Down
69 changes: 69 additions & 0 deletions internal/rows/arrowbased/batchloader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,75 @@
assert.NotNil(t, err3)
assert.ErrorContains(t, err3, fmt.Sprintf("%s %d", "HTTP error", http.StatusNotFound))
})

t.Run("should use custom Transport when provided", func(t *testing.T) {
handler = func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write(generateMockArrowBytes(generateArrowRecord()))

Check failure on line 260 in internal/rows/arrowbased/batchloader_test.go

View workflow job for this annotation

GitHub Actions / Lint

Error return value of `w.Write` is not checked (errcheck)
}

startRowOffset := int64(100)
customTransport := &http.Transport{MaxIdleConns: 10}

cfg := config.WithDefaults()
cfg.UseLz4Compression = false
cfg.MaxDownloadThreads = 1
cfg.UserConfig.Transport = customTransport

bi, err := NewCloudBatchIterator(
context.Background(),
[]*cli_service.TSparkArrowResultLink{{
FileLink: server.URL,
ExpiryTime: time.Now().Add(10 * time.Minute).Unix(),
StartRowOffset: startRowOffset,
RowCount: 1,
}},
startRowOffset,
cfg,
)
assert.Nil(t, err)

cbi := bi.(*batchIterator).ipcIterator.(*cloudIPCStreamIterator)
assert.Equal(t, customTransport, cbi.transport)

// Verify fetch works
sab, nextErr := bi.Next()
assert.Nil(t, nextErr)
assert.NotNil(t, sab)
})

t.Run("should fallback to http.DefaultClient when Transport is nil", func(t *testing.T) {
handler = func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write(generateMockArrowBytes(generateArrowRecord()))

Check failure on line 296 in internal/rows/arrowbased/batchloader_test.go

View workflow job for this annotation

GitHub Actions / Lint

Error return value of `w.Write` is not checked (errcheck)
}

startRowOffset := int64(100)
cfg := config.WithDefaults()
cfg.UseLz4Compression = false
cfg.MaxDownloadThreads = 1

bi, err := NewCloudBatchIterator(
context.Background(),
[]*cli_service.TSparkArrowResultLink{{
FileLink: server.URL,
ExpiryTime: time.Now().Add(10 * time.Minute).Unix(),
StartRowOffset: startRowOffset,
RowCount: 1,
}},
startRowOffset,
cfg,
)
assert.Nil(t, err)

cbi := bi.(*batchIterator).ipcIterator.(*cloudIPCStreamIterator)
assert.Nil(t, cbi.transport)

// Verify fetch works with default client
sab, nextErr := bi.Next()
assert.Nil(t, nextErr)
assert.NotNil(t, sab)
})
}

func generateArrowRecord() arrow.Record {
Expand Down
Loading