Skip to content

Commit 3e67347

Browse files
msrathore-db and claude committed
feat(csharp): Implement straggler download mitigation for CloudFetch
Adds straggler download mitigation feature to improve CloudFetch performance by detecting and cancelling abnormally slow parallel downloads.

Implementation:
- New StragglerDownloadDetector class for detecting slow downloads
- New FileDownloadMetrics class for tracking download performance
- New CloudFetchStragglerMitigationConfig for configuration management
- Integration into CloudFetchDownloader with background monitoring thread
- Automatic fallback to sequential downloads after threshold

Configuration Parameters:
- adbc.databricks.cloudfetch.straggler_mitigation_enabled (default: false)
- adbc.databricks.cloudfetch.straggler_multiplier (default: 1.5)
- adbc.databricks.cloudfetch.straggler_quantile (default: 0.6)
- adbc.databricks.cloudfetch.straggler_padding_seconds (default: 5)
- adbc.databricks.cloudfetch.max_stragglers_per_query (default: 10)
- adbc.databricks.cloudfetch.synchronous_fallback_enabled (default: true)

Tests:
- 19 comprehensive unit tests covering basic functionality and advanced scenarios
- 19 E2E tests with mocked HTTP responses validating real-world scenarios
- All tests pass successfully

Documentation:
- straggler-mitigation-design.md: comprehensive design documentation

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 982d29f commit 3e67347

12 files changed

+2107
-1710
lines changed

csharp/src/Drivers/Databricks/Reader/CloudFetch/CloudFetchDownloader.cs

Lines changed: 84 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,11 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Reader.CloudFetch
3535
/// </summary>
3636
internal sealed class CloudFetchDownloader : ICloudFetchDownloader, IActivityTracer
3737
{
38+
// Straggler mitigation timing constants
39+
private static readonly TimeSpan StragglerMonitoringInterval = TimeSpan.FromSeconds(2);
40+
private static readonly TimeSpan MetricsCleanupDelay = TimeSpan.FromSeconds(5); // Must be > monitoring interval
41+
private static readonly TimeSpan CtsDisposalDelay = TimeSpan.FromSeconds(6); // Must be > metrics cleanup delay
42+
3843
private readonly ITracingStatement _statement;
3944
private readonly BlockingCollection<IDownloadResult> _downloadQueue;
4045
private readonly BlockingCollection<IDownloadResult> _resultQueue;
@@ -60,10 +65,11 @@ internal sealed class CloudFetchDownloader : ICloudFetchDownloader, IActivityTra
6065
private readonly ConcurrentDictionary<long, FileDownloadMetrics>? _activeDownloadMetrics;
6166
private readonly ConcurrentDictionary<long, CancellationTokenSource>? _perFileDownloadCancellationTokens;
6267
private readonly ConcurrentDictionary<long, bool>? _alreadyCountedStragglers; // Prevents duplicate counting of same file
68+
private readonly ConcurrentDictionary<long, Task>? _metricCleanupTasks; // Tracks cleanup tasks for proper shutdown
6369
private Task? _stragglerMonitoringTask;
6470
private CancellationTokenSource? _stragglerMonitoringCts;
6571
private volatile bool _hasTriggeredSequentialDownloadFallback;
66-
private SemaphoreSlim _sequentialSemaphore = new SemaphoreSlim(1, 1); // Not disposed - lightweight, safe to leave allocated
72+
private SemaphoreSlim _sequentialSemaphore = new SemaphoreSlim(1, 1);
6773
private volatile bool _isSequentialMode;
6874

6975
/// <summary>
@@ -81,6 +87,7 @@ internal sealed class CloudFetchDownloader : ICloudFetchDownloader, IActivityTra
8187
/// <param name="retryDelayMs">The delay between retry attempts in milliseconds.</param>
8288
/// <param name="maxUrlRefreshAttempts">The maximum number of URL refresh attempts.</param>
8389
/// <param name="urlExpirationBufferSeconds">Buffer time in seconds before URL expiration to trigger refresh.</param>
90+
/// <param name="stragglerConfig">Optional configuration for straggler mitigation (null = disabled).</param>
8491
public CloudFetchDownloader(
8592
ITracingStatement statement,
8693
BlockingCollection<IDownloadResult> downloadQueue,
@@ -93,7 +100,8 @@ public CloudFetchDownloader(
93100
int maxRetries = 3,
94101
int retryDelayMs = 500,
95102
int maxUrlRefreshAttempts = 3,
96-
int urlExpirationBufferSeconds = 60)
103+
int urlExpirationBufferSeconds = 60,
104+
CloudFetchStragglerMitigationConfig? stragglerConfig = null)
97105
{
98106
_statement = statement ?? throw new ArgumentNullException(nameof(statement));
99107
_downloadQueue = downloadQueue ?? throw new ArgumentNullException(nameof(downloadQueue));
@@ -110,28 +118,22 @@ public CloudFetchDownloader(
110118
_downloadSemaphore = new SemaphoreSlim(_maxParallelDownloads, _maxParallelDownloads);
111119
_isCompleted = false;
112120

113-
// Parse straggler mitigation configuration
114-
var hiveStatement = _statement as IHiveServer2Statement;
115-
var properties = hiveStatement?.Connection?.Properties;
116-
_isStragglerMitigationEnabled = properties != null && ParseBooleanProperty(properties, DatabricksParameters.CloudFetchStragglerMitigationEnabled, defaultValue: false);
121+
// Initialize straggler mitigation from config object
122+
var config = stragglerConfig ?? CloudFetchStragglerMitigationConfig.Disabled;
123+
_isStragglerMitigationEnabled = config.Enabled;
117124

118-
if (_isStragglerMitigationEnabled && properties != null)
125+
if (config.Enabled)
119126
{
120-
double stragglerMultiplier = ParseDoubleProperty(properties, DatabricksParameters.CloudFetchStragglerMultiplier, defaultValue: 1.5);
121-
double stragglerQuantile = ParseDoubleProperty(properties, DatabricksParameters.CloudFetchStragglerQuantile, defaultValue: 0.6);
122-
int stragglerPaddingSeconds = ParseIntProperty(properties, DatabricksParameters.CloudFetchStragglerPaddingSeconds, defaultValue: 5);
123-
int maxStragglersPerQuery = ParseIntProperty(properties, DatabricksParameters.CloudFetchMaxStragglersPerQuery, defaultValue: 10);
124-
bool synchronousFallbackEnabled = ParseBooleanProperty(properties, DatabricksParameters.CloudFetchSynchronousFallbackEnabled, defaultValue: false);
125-
126127
_stragglerDetector = new StragglerDownloadDetector(
127-
stragglerMultiplier,
128-
stragglerQuantile,
129-
TimeSpan.FromSeconds(stragglerPaddingSeconds),
130-
synchronousFallbackEnabled ? maxStragglersPerQuery : int.MaxValue);
128+
config.Multiplier,
129+
config.Quantile,
130+
config.Padding,
131+
config.SynchronousFallbackEnabled ? config.MaxStragglersBeforeFallback : int.MaxValue);
131132

132133
_activeDownloadMetrics = new ConcurrentDictionary<long, FileDownloadMetrics>();
133134
_perFileDownloadCancellationTokens = new ConcurrentDictionary<long, CancellationTokenSource>();
134135
_alreadyCountedStragglers = new ConcurrentDictionary<long, bool>();
136+
_metricCleanupTasks = new ConcurrentDictionary<long, Task>();
135137
_hasTriggeredSequentialDownloadFallback = false;
136138
}
137139
}
@@ -145,6 +147,27 @@ public CloudFetchDownloader(
145147
/// <inheritdoc />
146148
public Exception? Error => _error;
147149

150+
/// <summary>
151+
/// Internal property to check if straggler mitigation is enabled (for testing).
152+
/// </summary>
153+
internal bool IsStragglerMitigationEnabled => _isStragglerMitigationEnabled;
154+
155+
/// <summary>
156+
/// Internal method to get total stragglers detected (for testing).
157+
/// </summary>
158+
internal long GetTotalStragglersDetected() => _stragglerDetector?.GetTotalStragglersDetectedInQuery() ?? 0;
159+
160+
/// <summary>
161+
/// Internal method to get count of active downloads being tracked (for testing).
162+
/// </summary>
163+
internal int GetActiveDownloadCount() => _activeDownloadMetrics?.Count ?? 0;
164+
165+
/// <summary>
166+
/// Internal method to check if tracking dictionaries are initialized (for testing).
167+
/// </summary>
168+
internal bool AreTrackingDictionariesInitialized() => _activeDownloadMetrics != null && _perFileDownloadCancellationTokens != null;
169+
170+
148171
/// <inheritdoc />
149172
public async Task StartAsync(CancellationToken cancellationToken)
150173
{
@@ -219,6 +242,20 @@ public async Task StopAsync()
219242
_cancellationTokenSource = null;
220243
_downloadTask = null;
221244

245+
// Await all metric cleanup tasks before disposing resources
246+
if (_metricCleanupTasks != null && _metricCleanupTasks.Count > 0)
247+
{
248+
try
249+
{
250+
await Task.WhenAll(_metricCleanupTasks.Values).ConfigureAwait(false);
251+
}
252+
catch
253+
{
254+
// Ignore cleanup task exceptions during shutdown
255+
}
256+
_metricCleanupTasks.Clear();
257+
}
258+
222259
// Cleanup per-file cancellation tokens
223260
if (_perFileDownloadCancellationTokens != null)
224261
{
@@ -229,8 +266,8 @@ public async Task StopAsync()
229266
_perFileDownloadCancellationTokens.Clear();
230267
}
231268

232-
// Note: _sequentialSemaphore is intentionally not disposed to support restart scenarios
233-
// Semaphores are lightweight and safe to leave allocated
269+
// Dispose sequential semaphore
270+
_sequentialSemaphore?.Dispose();
234271
}
235272
}
236273

@@ -358,7 +395,6 @@ await this.TraceActivityAsync(async activity =>
358395
// Acquire a download slot
359396
await _downloadSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
360397

361-
// Capture mode atomically to avoid TOCTOU race with monitor thread
362398
bool shouldAcquireSequential = _isSequentialMode;
363399
bool acquiredSequential = false;
364400
if (shouldAcquireSequential)
@@ -762,25 +798,30 @@ await this.TraceActivityAsync(async activity =>
762798
}
763799
finally
764800
{
765-
// Cleanup per-file cancellation token (always runs, even on exception)
801+
// Delay CTS disposal to avoid race with monitoring thread
802+
// Monitoring thread may still be checking this CTS, so schedule disposal after monitoring can complete
766803
if (_perFileDownloadCancellationTokens != null)
767804
{
768805
if (_perFileDownloadCancellationTokens.TryRemove(fileOffset, out var cts))
769806
{
770-
cts?.Dispose();
807+
// Schedule disposal after delay to allow monitoring thread to finish
808+
_ = Task.Run(async () =>
809+
{
810+
await Task.Delay(CtsDisposalDelay);
811+
cts?.Dispose();
812+
});
771813
}
772814
}
773815

774-
// Remove from active metrics after a short delay to allow final detection cycle
775-
// Use fire-and-forget with exception handling to prevent unobserved task exceptions
776-
if (_activeDownloadMetrics != null)
816+
// Track cleanup task instead of fire-and-forget to ensure proper shutdown
817+
if (_activeDownloadMetrics != null && _metricCleanupTasks != null)
777818
{
778-
_ = Task.Run(async () =>
819+
var cleanupTask = Task.Run(async () =>
779820
{
780821
try
781822
{
782823
// Use cancellationToken to respect shutdown - removes immediately if cancelled
783-
await Task.Delay(TimeSpan.FromSeconds(3), cancellationToken);
824+
await Task.Delay(MetricsCleanupDelay, cancellationToken);
784825
_activeDownloadMetrics?.TryRemove(fileOffset, out _);
785826
}
786827
catch (OperationCanceledException)
@@ -792,7 +833,13 @@ await this.TraceActivityAsync(async activity =>
792833
{
793834
// Ignore other exceptions in cleanup task
794835
}
836+
finally
837+
{
838+
// Always remove from tracking dictionary
839+
_metricCleanupTasks?.TryRemove(fileOffset, out _);
840+
}
795841
});
842+
_metricCleanupTasks[fileOffset] = cleanupTask;
796843
}
797844
}
798845
}, activityName: "DownloadFile");
@@ -837,7 +884,7 @@ await this.TraceActivityAsync(async activity =>
837884
{
838885
try
839886
{
840-
await Task.Delay(TimeSpan.FromSeconds(2), cancellationToken).ConfigureAwait(false);
887+
await Task.Delay(StragglerMonitoringInterval, cancellationToken).ConfigureAwait(false);
841888

842889
if (_activeDownloadMetrics == null || _stragglerDetector == null || _perFileDownloadCancellationTokens == null)
843890
{
@@ -881,7 +928,15 @@ await this.TraceActivityAsync(async activity =>
881928
new("offset", offset)
882929
]);
883930

884-
cts.Cancel();
931+
try
932+
{
933+
cts.Cancel();
934+
}
935+
catch (ObjectDisposedException)
936+
{
937+
// Expected race condition: CTS was disposed between TryGetValue and Cancel
938+
// This is harmless - the download has already completed
939+
}
885940
}
886941
}
887942
}
@@ -914,35 +969,6 @@ private string SanitizeUrl(string url)
914969
return "cloud-storage-url";
915970
}
916971
}
917-
918-
// Helper methods for parsing configuration properties
919-
private static bool ParseBooleanProperty(IReadOnlyDictionary<string, string> properties, string key, bool defaultValue)
920-
{
921-
if (properties.TryGetValue(key, out string? value) && bool.TryParse(value, out bool result))
922-
{
923-
return result;
924-
}
925-
return defaultValue;
926-
}
927-
928-
private static int ParseIntProperty(IReadOnlyDictionary<string, string> properties, string key, int defaultValue)
929-
{
930-
if (properties.TryGetValue(key, out string? value) && int.TryParse(value, out int result))
931-
{
932-
return result;
933-
}
934-
return defaultValue;
935-
}
936-
937-
private static double ParseDoubleProperty(IReadOnlyDictionary<string, string> properties, string key, double defaultValue)
938-
{
939-
if (properties.TryGetValue(key, out string? value) && double.TryParse(value, out double result))
940-
{
941-
return result;
942-
}
943-
return defaultValue;
944-
}
945-
946972
// IActivityTracer implementation - delegates to statement
947973
ActivityTrace IActivityTracer.Trace => _statement.Trace;
948974

0 commit comments

Comments
 (0)