Skip to content

Commit 7f3d33b

Browse files
authored
feat(csharp/src/Drivers/Apache): Add prefetch functionality to CloudFetch in Spark ADBC driver (apache#2678)
# Add Prefetch Functionality to CloudFetch in Spark ADBC Driver This PR enhances the CloudFetch feature in the Spark ADBC driver by implementing prefetch functionality, which improves performance by fetching multiple batches of results ahead of time. ## Changes ### CloudFetchResultFetcher Enhancements - **Initial Prefetch**: Added code to perform an initial prefetch of multiple batches when the fetcher starts, ensuring data is available immediately when needed. - **State Management**: Added tracking for current batch offset and size, with proper state reset when starting the fetcher. ### Interface Updates - Added new methods to the `ICloudFetchResultFetcher` interface. ### Testing Infrastructure - Created `ITestableHiveServer2Statement` interface to facilitate testing - Updated tests to account for prefetch behavior - Ensured all tests pass with the new prefetch functionality ## Benefits - **Improved Performance**: By prefetching multiple batches, data is available sooner, reducing wait times. - **Better Reliability**: Enhanced error handling and state management make the system more robust. - **More Efficient Resource Usage**: Link caching reduces unnecessary server requests. This implementation maintains backward compatibility while providing significant performance improvements for CloudFetch operations.
1 parent 1c780f1 commit 7f3d33b

15 files changed

+2660
-184
lines changed

csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,9 @@ protected internal int QueryTimeoutSeconds
284284

285285
public TOperationHandle? OperationHandle { get; private set; }
286286

287+
// Keep the original Client property for internal use
288+
public TCLIService.Client Client => Connection.Client;
289+
287290
private void UpdatePollTimeIfValid(string key, string value) => PollTimeMilliseconds = !string.IsNullOrEmpty(key) && int.TryParse(value, result: out int pollTimeMilliseconds) && pollTimeMilliseconds >= 0
288291
? pollTimeMilliseconds
289292
: throw new ArgumentOutOfRangeException(key, value, $"The value '{value}' for option '{key}' is invalid. Must be a numeric value greater than or equal to 0.");
Lines changed: 330 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,330 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
using System;
19+
using System.Collections.Concurrent;
20+
using System.Collections.Generic;
21+
using System.Net.Http;
22+
using System.Threading;
23+
using System.Threading.Tasks;
24+
using Apache.Arrow.Adbc.Drivers.Apache.Hive2;
25+
using Apache.Arrow.Adbc.Drivers.Databricks;
26+
27+
namespace Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch
28+
{
29+
/// <summary>
30+
/// Manages the CloudFetch download pipeline.
31+
/// </summary>
32+
internal sealed class CloudFetchDownloadManager : ICloudFetchDownloadManager
33+
{
34+
// Default values
35+
private const int DefaultParallelDownloads = 3;
36+
private const int DefaultPrefetchCount = 2;
37+
private const int DefaultMemoryBufferSizeMB = 200;
38+
private const bool DefaultPrefetchEnabled = true;
39+
private const int DefaultFetchBatchSize = 2000000;
40+
41+
private readonly DatabricksStatement _statement;
42+
private readonly Schema _schema;
43+
private readonly bool _isLz4Compressed;
44+
private readonly ICloudFetchMemoryBufferManager _memoryManager;
45+
private readonly BlockingCollection<IDownloadResult> _downloadQueue;
46+
private readonly BlockingCollection<IDownloadResult> _resultQueue;
47+
private readonly ICloudFetchResultFetcher _resultFetcher;
48+
private readonly ICloudFetchDownloader _downloader;
49+
private readonly HttpClient _httpClient;
50+
private bool _isDisposed;
51+
private bool _isStarted;
52+
private CancellationTokenSource? _cancellationTokenSource;
53+
54+
/// <summary>
55+
/// Initializes a new instance of the <see cref="CloudFetchDownloadManager"/> class.
56+
/// </summary>
57+
/// <param name="statement">The HiveServer2 statement.</param>
58+
/// <param name="schema">The Arrow schema.</param>
59+
/// <param name="isLz4Compressed">Whether the results are LZ4 compressed.</param>
60+
public CloudFetchDownloadManager(DatabricksStatement statement, Schema schema, bool isLz4Compressed)
61+
{
62+
_statement = statement ?? throw new ArgumentNullException(nameof(statement));
63+
_schema = schema ?? throw new ArgumentNullException(nameof(schema));
64+
_isLz4Compressed = isLz4Compressed;
65+
66+
// Get configuration values from connection properties
67+
var connectionProps = statement.Connection.Properties;
68+
69+
// Parse parallel downloads
70+
int parallelDownloads = DefaultParallelDownloads;
71+
if (connectionProps.TryGetValue(DatabricksParameters.CloudFetchParallelDownloads, out string? parallelDownloadsStr))
72+
{
73+
if (int.TryParse(parallelDownloadsStr, out int parsedParallelDownloads) && parsedParallelDownloads > 0)
74+
{
75+
parallelDownloads = parsedParallelDownloads;
76+
}
77+
else
78+
{
79+
throw new ArgumentException($"Invalid value for {DatabricksParameters.CloudFetchParallelDownloads}: {parallelDownloadsStr}. Expected a positive integer.");
80+
}
81+
}
82+
83+
// Parse prefetch count
84+
int prefetchCount = DefaultPrefetchCount;
85+
if (connectionProps.TryGetValue(DatabricksParameters.CloudFetchPrefetchCount, out string? prefetchCountStr))
86+
{
87+
if (int.TryParse(prefetchCountStr, out int parsedPrefetchCount) && parsedPrefetchCount > 0)
88+
{
89+
prefetchCount = parsedPrefetchCount;
90+
}
91+
else
92+
{
93+
throw new ArgumentException($"Invalid value for {DatabricksParameters.CloudFetchPrefetchCount}: {prefetchCountStr}. Expected a positive integer.");
94+
}
95+
}
96+
97+
// Parse memory buffer size
98+
int memoryBufferSizeMB = DefaultMemoryBufferSizeMB;
99+
if (connectionProps.TryGetValue(DatabricksParameters.CloudFetchMemoryBufferSize, out string? memoryBufferSizeStr))
100+
{
101+
if (int.TryParse(memoryBufferSizeStr, out int parsedMemoryBufferSize) && parsedMemoryBufferSize > 0)
102+
{
103+
memoryBufferSizeMB = parsedMemoryBufferSize;
104+
}
105+
else
106+
{
107+
throw new ArgumentException($"Invalid value for {DatabricksParameters.CloudFetchMemoryBufferSize}: {memoryBufferSizeStr}. Expected a positive integer.");
108+
}
109+
}
110+
111+
// Parse max retries
112+
int maxRetries = 3;
113+
if (connectionProps.TryGetValue(DatabricksParameters.CloudFetchMaxRetries, out string? maxRetriesStr))
114+
{
115+
if (int.TryParse(maxRetriesStr, out int parsedMaxRetries) && parsedMaxRetries > 0)
116+
{
117+
maxRetries = parsedMaxRetries;
118+
}
119+
else
120+
{
121+
throw new ArgumentException($"Invalid value for {DatabricksParameters.CloudFetchMaxRetries}: {maxRetriesStr}. Expected a positive integer.");
122+
}
123+
}
124+
125+
// Parse retry delay
126+
int retryDelayMs = 500;
127+
if (connectionProps.TryGetValue(DatabricksParameters.CloudFetchRetryDelayMs, out string? retryDelayStr))
128+
{
129+
if (int.TryParse(retryDelayStr, out int parsedRetryDelay) && parsedRetryDelay > 0)
130+
{
131+
retryDelayMs = parsedRetryDelay;
132+
}
133+
else
134+
{
135+
throw new ArgumentException($"Invalid value for {DatabricksParameters.CloudFetchRetryDelayMs}: {retryDelayStr}. Expected a positive integer.");
136+
}
137+
}
138+
139+
// Parse timeout minutes
140+
int timeoutMinutes = 5;
141+
if (connectionProps.TryGetValue(DatabricksParameters.CloudFetchTimeoutMinutes, out string? timeoutStr))
142+
{
143+
if (int.TryParse(timeoutStr, out int parsedTimeout) && parsedTimeout > 0)
144+
{
145+
timeoutMinutes = parsedTimeout;
146+
}
147+
else
148+
{
149+
throw new ArgumentException($"Invalid value for {DatabricksParameters.CloudFetchTimeoutMinutes}: {timeoutStr}. Expected a positive integer.");
150+
}
151+
}
152+
153+
// Initialize the memory manager
154+
_memoryManager = new CloudFetchMemoryBufferManager(memoryBufferSizeMB);
155+
156+
// Initialize the queues with bounded capacity
157+
_downloadQueue = new BlockingCollection<IDownloadResult>(new ConcurrentQueue<IDownloadResult>(), prefetchCount * 2);
158+
_resultQueue = new BlockingCollection<IDownloadResult>(new ConcurrentQueue<IDownloadResult>(), prefetchCount * 2);
159+
160+
// Initialize the HTTP client
161+
_httpClient = new HttpClient
162+
{
163+
Timeout = TimeSpan.FromMinutes(timeoutMinutes)
164+
};
165+
166+
// Initialize the result fetcher
167+
_resultFetcher = new CloudFetchResultFetcher(
168+
_statement,
169+
_memoryManager,
170+
_downloadQueue,
171+
DefaultFetchBatchSize);
172+
173+
// Initialize the downloader
174+
_downloader = new CloudFetchDownloader(
175+
_downloadQueue,
176+
_resultQueue,
177+
_memoryManager,
178+
_httpClient,
179+
parallelDownloads,
180+
_isLz4Compressed,
181+
maxRetries,
182+
retryDelayMs);
183+
}
184+
185+
/// <summary>
186+
/// Initializes a new instance of the <see cref="CloudFetchDownloadManager"/> class.
187+
/// This constructor is intended for testing purposes only.
188+
/// </summary>
189+
/// <param name="statement">The HiveServer2 statement.</param>
190+
/// <param name="schema">The Arrow schema.</param>
191+
/// <param name="isLz4Compressed">Whether the results are LZ4 compressed.</param>
192+
/// <param name="resultFetcher">The result fetcher.</param>
193+
/// <param name="downloader">The downloader.</param>
194+
internal CloudFetchDownloadManager(
195+
DatabricksStatement statement,
196+
Schema schema,
197+
bool isLz4Compressed,
198+
ICloudFetchResultFetcher resultFetcher,
199+
ICloudFetchDownloader downloader)
200+
{
201+
_statement = statement ?? throw new ArgumentNullException(nameof(statement));
202+
_schema = schema ?? throw new ArgumentNullException(nameof(schema));
203+
_isLz4Compressed = isLz4Compressed;
204+
_resultFetcher = resultFetcher ?? throw new ArgumentNullException(nameof(resultFetcher));
205+
_downloader = downloader ?? throw new ArgumentNullException(nameof(downloader));
206+
207+
// Create empty collections for the test
208+
_memoryManager = new CloudFetchMemoryBufferManager(DefaultMemoryBufferSizeMB);
209+
_downloadQueue = new BlockingCollection<IDownloadResult>(new ConcurrentQueue<IDownloadResult>(), 10);
210+
_resultQueue = new BlockingCollection<IDownloadResult>(new ConcurrentQueue<IDownloadResult>(), 10);
211+
_httpClient = new HttpClient();
212+
}
213+
214+
/// <inheritdoc />
215+
public bool HasMoreResults => !_downloader.IsCompleted || !_resultQueue.IsCompleted;
216+
217+
/// <inheritdoc />
218+
public async Task<IDownloadResult?> GetNextDownloadedFileAsync(CancellationToken cancellationToken)
219+
{
220+
ThrowIfDisposed();
221+
222+
if (!_isStarted)
223+
{
224+
throw new InvalidOperationException("Download manager has not been started.");
225+
}
226+
227+
try
228+
{
229+
return await _downloader.GetNextDownloadedFileAsync(cancellationToken).ConfigureAwait(false);
230+
}
231+
catch (Exception ex) when (_resultFetcher.HasError)
232+
{
233+
throw new AggregateException("Errors in download pipeline", new[] { ex, _resultFetcher.Error! });
234+
}
235+
}
236+
237+
/// <inheritdoc />
238+
public async Task StartAsync()
239+
{
240+
ThrowIfDisposed();
241+
242+
if (_isStarted)
243+
{
244+
throw new InvalidOperationException("Download manager is already started.");
245+
}
246+
247+
// Create a new cancellation token source
248+
_cancellationTokenSource = new CancellationTokenSource();
249+
250+
// Start the result fetcher
251+
await _resultFetcher.StartAsync(_cancellationTokenSource.Token).ConfigureAwait(false);
252+
253+
// Start the downloader
254+
await _downloader.StartAsync(_cancellationTokenSource.Token).ConfigureAwait(false);
255+
256+
_isStarted = true;
257+
}
258+
259+
/// <inheritdoc />
260+
public async Task StopAsync()
261+
{
262+
if (!_isStarted)
263+
{
264+
return;
265+
}
266+
267+
// Cancel the token to signal all operations to stop
268+
_cancellationTokenSource?.Cancel();
269+
270+
// Stop the downloader
271+
await _downloader.StopAsync().ConfigureAwait(false);
272+
273+
// Stop the result fetcher
274+
await _resultFetcher.StopAsync().ConfigureAwait(false);
275+
276+
// Dispose the cancellation token source
277+
_cancellationTokenSource?.Dispose();
278+
_cancellationTokenSource = null;
279+
280+
_isStarted = false;
281+
}
282+
283+
/// <inheritdoc />
284+
public void Dispose()
285+
{
286+
if (_isDisposed)
287+
{
288+
return;
289+
}
290+
291+
// Stop the pipeline
292+
StopAsync().GetAwaiter().GetResult();
293+
294+
// Dispose the HTTP client
295+
_httpClient.Dispose();
296+
297+
// Dispose the cancellation token source if it hasn't been disposed yet
298+
_cancellationTokenSource?.Dispose();
299+
_cancellationTokenSource = null;
300+
301+
// Mark the queues as completed to release any waiting threads
302+
_downloadQueue.CompleteAdding();
303+
_resultQueue.CompleteAdding();
304+
305+
// Dispose any remaining results
306+
foreach (var result in _resultQueue.GetConsumingEnumerable(CancellationToken.None))
307+
{
308+
result.Dispose();
309+
}
310+
311+
foreach (var result in _downloadQueue.GetConsumingEnumerable(CancellationToken.None))
312+
{
313+
result.Dispose();
314+
}
315+
316+
_downloadQueue.Dispose();
317+
_resultQueue.Dispose();
318+
319+
_isDisposed = true;
320+
}
321+
322+
private void ThrowIfDisposed()
323+
{
324+
if (_isDisposed)
325+
{
326+
throw new ObjectDisposedException(nameof(CloudFetchDownloadManager));
327+
}
328+
}
329+
}
330+
}

0 commit comments

Comments
 (0)