diff --git a/internal/config/config.go b/internal/config/config.go
index 67437a9c..6956ab37 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -479,6 +479,7 @@ type CloudFetchConfig struct {
 	MaxFilesInMemory             int
 	MinTimeToExpiry              time.Duration
 	CloudFetchSpeedThresholdMbps float64 // Minimum download speed in MBps before WARN logging (default: 0.1)
+	Transport                    http.RoundTripper
 }
 
 func (cfg CloudFetchConfig) WithDefaults() CloudFetchConfig {
diff --git a/internal/rows/arrowbased/batchloader.go b/internal/rows/arrowbased/batchloader.go
index d26d8a4a..545fe9c0 100644
--- a/internal/rows/arrowbased/batchloader.go
+++ b/internal/rows/arrowbased/batchloader.go
@@ -40,6 +40,7 @@ func NewCloudIPCStreamIterator(
 		startRowOffset: startRowOffset,
 		pendingLinks:   NewQueue[cli_service.TSparkArrowResultLink](),
 		downloadTasks:  NewQueue[cloudFetchDownloadTask](),
+		transport:      cfg.UserConfig.Transport,
 	}
 
 	for _, link := range files {
@@ -140,6 +141,7 @@ type cloudIPCStreamIterator struct {
 	startRowOffset int64
 	pendingLinks   Queue[cli_service.TSparkArrowResultLink]
 	downloadTasks  Queue[cloudFetchDownloadTask]
+	transport      http.RoundTripper
 }
 
 var _ IPCStreamIterator = (*cloudIPCStreamIterator)(nil)
@@ -162,6 +164,7 @@ func (bi *cloudIPCStreamIterator) Next() (io.Reader, error) {
 			resultChan:         make(chan cloudFetchDownloadTaskResult),
 			minTimeToExpiry:    bi.cfg.MinTimeToExpiry,
 			speedThresholdMbps: bi.cfg.CloudFetchSpeedThresholdMbps,
+			transport:          bi.transport,
 		}
 		task.Run()
 		bi.downloadTasks.Enqueue(task)
@@ -210,6 +213,7 @@ type cloudFetchDownloadTask struct {
 	link               *cli_service.TSparkArrowResultLink
 	resultChan         chan cloudFetchDownloadTaskResult
 	speedThresholdMbps float64
+	transport          http.RoundTripper
 }
 
 func (cft *cloudFetchDownloadTask) GetResult() (io.Reader, error) {
@@ -252,7 +256,7 @@ func (cft *cloudFetchDownloadTask) Run() {
 			cft.link.StartRowOffset,
 			cft.link.RowCount,
 		)
-		data, err := fetchBatchBytes(cft.ctx, cft.link, cft.minTimeToExpiry, cft.speedThresholdMbps)
+		data, err := fetchBatchBytes(cft.ctx, cft.link, cft.minTimeToExpiry, cft.speedThresholdMbps, cft.transport)
 		if err != nil {
 			cft.resultChan <- cloudFetchDownloadTaskResult{data: nil, err: err}
 			return
@@ -300,6 +304,7 @@ func fetchBatchBytes(
 	link *cli_service.TSparkArrowResultLink,
 	minTimeToExpiry time.Duration,
 	speedThresholdMbps float64,
+	transport http.RoundTripper,
 ) (io.ReadCloser, error) {
 	if isLinkExpired(link.ExpiryTime, minTimeToExpiry) {
 		return nil, errors.New(dbsqlerr.ErrLinkExpired)
@@ -317,9 +322,15 @@ func fetchBatchBytes(
 		}
 	}
 
+	httpClient := http.DefaultClient
+	if transport != nil {
+		httpClient = &http.Client{
+			Transport: transport,
+		}
+	}
+
 	startTime := time.Now()
-	client := http.DefaultClient
-	res, err := client.Do(req)
+	res, err := httpClient.Do(req)
 	if err != nil {
 		return nil, err
 	}
diff --git a/internal/rows/arrowbased/batchloader_test.go b/internal/rows/arrowbased/batchloader_test.go
index b018eb6d..70636fa8 100644
--- a/internal/rows/arrowbased/batchloader_test.go
+++ b/internal/rows/arrowbased/batchloader_test.go
@@ -253,6 +253,75 @@ func TestCloudFetchIterator(t *testing.T) {
 		assert.NotNil(t, err3)
 		assert.ErrorContains(t, err3, fmt.Sprintf("%s %d", "HTTP error", http.StatusNotFound))
 	})
+
+	t.Run("should use custom Transport when provided", func(t *testing.T) {
+		handler = func(w http.ResponseWriter, r *http.Request) {
+			w.WriteHeader(http.StatusOK)
+			w.Write(generateMockArrowBytes(generateArrowRecord()))
+		}
+
+		startRowOffset := int64(100)
+		customTransport := &http.Transport{MaxIdleConns: 10}
+
+		cfg := config.WithDefaults()
+		cfg.UseLz4Compression = false
+		cfg.MaxDownloadThreads = 1
+		cfg.UserConfig.Transport = customTransport
+
+		bi, err := NewCloudBatchIterator(
+			context.Background(),
+			[]*cli_service.TSparkArrowResultLink{{
+				FileLink:       server.URL,
+				ExpiryTime:     time.Now().Add(10 * time.Minute).Unix(),
+				StartRowOffset: startRowOffset,
+				RowCount:       1,
+			}},
+			startRowOffset,
+			cfg,
+		)
+		assert.Nil(t, err)
+
+		cbi := bi.(*batchIterator).ipcIterator.(*cloudIPCStreamIterator)
+		assert.Equal(t, customTransport, cbi.transport)
+
+		// Verify fetch works
+		sab, nextErr := bi.Next()
+		assert.Nil(t, nextErr)
+		assert.NotNil(t, sab)
+	})
+
+	t.Run("should fallback to http.DefaultClient when Transport is nil", func(t *testing.T) {
+		handler = func(w http.ResponseWriter, r *http.Request) {
+			w.WriteHeader(http.StatusOK)
+			w.Write(generateMockArrowBytes(generateArrowRecord()))
+		}
+
+		startRowOffset := int64(100)
+		cfg := config.WithDefaults()
+		cfg.UseLz4Compression = false
+		cfg.MaxDownloadThreads = 1
+
+		bi, err := NewCloudBatchIterator(
+			context.Background(),
+			[]*cli_service.TSparkArrowResultLink{{
+				FileLink:       server.URL,
+				ExpiryTime:     time.Now().Add(10 * time.Minute).Unix(),
+				StartRowOffset: startRowOffset,
+				RowCount:       1,
+			}},
+			startRowOffset,
+			cfg,
+		)
+		assert.Nil(t, err)
+
+		cbi := bi.(*batchIterator).ipcIterator.(*cloudIPCStreamIterator)
+		assert.Nil(t, cbi.transport)
+
+		// Verify fetch works with default client
+		sab, nextErr := bi.Next()
+		assert.Nil(t, nextErr)
+		assert.NotNil(t, sab)
+	})
 }
 
 func generateArrowRecord() arrow.Record {