3131using System . Threading . Tasks ;
3232using Apache . Arrow . Adbc . Drivers . Apache . Hive2 ;
3333using Apache . Arrow . Adbc . Drivers . Databricks . Reader . CloudFetch ;
34+ using Apache . Arrow . Adbc . Tracing ;
3435using Apache . Hive . Service . Rpc . Thrift ;
3536using Moq ;
3637using Moq . Protected ;
3738using Xunit ;
3839
3940namespace Apache . Arrow . Adbc . Tests . Drivers . Databricks . CloudFetch
4041{
41- public class CloudFetchDownloaderTest
42+ public class CloudFetchDownloaderTest : IDisposable
4243 {
4344 private readonly BlockingCollection < IDownloadResult > _downloadQueue ;
4445 private readonly BlockingCollection < IDownloadResult > _resultQueue ;
4546 private readonly Mock < ICloudFetchMemoryBufferManager > _mockMemoryManager ;
4647 private readonly Mock < IHiveServer2Statement > _mockStatement ;
4748 private readonly Mock < ICloudFetchResultFetcher > _mockResultFetcher ;
49+ private readonly ActivityTrace _activityTrace ;
4850
4951 public CloudFetchDownloaderTest ( )
5052 {
@@ -54,6 +56,13 @@ public CloudFetchDownloaderTest()
5456 _mockStatement = new Mock < IHiveServer2Statement > ( ) ;
5557 _mockResultFetcher = new Mock < ICloudFetchResultFetcher > ( ) ;
5658
59+ // Set up activity trace for tracing support
60+ _activityTrace = new ActivityTrace ( "TestActivitySource" ) ;
61+ _mockStatement . Setup ( s => s . Trace ) . Returns ( _activityTrace ) ;
62+ _mockStatement . Setup ( s => s . TraceParent ) . Returns ( ( string ? ) null ) ;
63+ _mockStatement . Setup ( s => s . AssemblyVersion ) . Returns ( "1.0.0" ) ;
64+ _mockStatement . Setup ( s => s . AssemblyName ) . Returns ( "TestAssembly" ) ;
65+
5766 // Set up memory manager defaults
5867 _mockMemoryManager . Setup ( m => m . TryAcquireMemory ( It . IsAny < long > ( ) ) ) . Returns ( true ) ;
5968 _mockMemoryManager . Setup ( m => m . AcquireMemoryAsync ( It . IsAny < long > ( ) , It . IsAny < CancellationToken > ( ) ) )
@@ -296,7 +305,10 @@ public async Task DownloadFileAsync_WithError_StopsProcessingRemainingFiles()
296305 // Create test download results
297306 var mockDownloadResult = new Mock < IDownloadResult > ( ) ;
298307 var resultLink = new TSparkArrowResultLink {
308+ StartRowOffset = 0 ,
299309 FileLink = "http://test.com/file1" ,
310+ RowCount = 100 ,
311+ BytesNum = 100 ,
300312 ExpiryTime = DateTimeOffset . UtcNow . AddMinutes ( 30 ) . ToUnixTimeMilliseconds ( ) // Set expiry 30 minutes in the future
301313 } ;
302314 mockDownloadResult . Setup ( r => r . Link ) . Returns ( resultLink ) ;
@@ -306,8 +318,13 @@ public async Task DownloadFileAsync_WithError_StopsProcessingRemainingFiles()
306318
307319 // Capture when SetFailed is called
308320 Exception ? capturedException = null ;
321+ bool setFailedCalled = false ;
309322 mockDownloadResult . Setup ( r => r . SetFailed ( It . IsAny < Exception > ( ) ) )
310- . Callback < Exception > ( ex => capturedException = ex ) ;
323+ . Callback < Exception > ( ex => {
324+ capturedException = ex ;
325+ setFailedCalled = true ;
326+ Console . WriteLine ( $ "SetFailed called with exception: { ex . Message } ") ;
327+ } ) ;
311328
312329 // Create the downloader
313330 var downloader = new CloudFetchDownloader (
@@ -324,30 +341,42 @@ public async Task DownloadFileAsync_WithError_StopsProcessingRemainingFiles()
324341
325342 // Act
326343 await downloader . StartAsync ( CancellationToken . None ) ;
344+ Console . WriteLine ( "Downloader started" ) ;
327345 _downloadQueue . Add ( mockDownloadResult . Object ) ;
346+ Console . WriteLine ( "Added download result to queue" ) ;
328347
329- // Wait for the download to be processed and fail
330- await Task . Delay ( 200 ) ;
331-
332- // Add the end of results guard
348+ // Add the end of results guard immediately
333349 _downloadQueue . Add ( EndOfResultsGuard . Instance ) ;
350+ Console . WriteLine ( "Added end guard" ) ;
334351
335- // Wait for all processing to complete
336- await Task . Delay ( 200 ) ;
352+ // Wait for the download to fail - use a timeout to avoid hanging
353+ int maxWaitMs = 2000 ;
354+ int waitedMs = 0 ;
355+ while ( ! downloader . HasError && ! setFailedCalled && waitedMs < maxWaitMs )
356+ {
357+ await Task . Delay ( 50 ) ;
358+ waitedMs += 50 ;
359+ }
360+
361+ Console . WriteLine ( $ "Finished waiting: HasError={ downloader . HasError } , SetFailedCalled={ setFailedCalled } , WaitedMs={ waitedMs } ") ;
337362
338363 // Assert
339364 // Verify the download failed
365+ Assert . True ( setFailedCalled , $ "SetFailed was not called. Downloader HasError={ downloader . HasError } ") ;
340366 mockDownloadResult . Verify ( r => r . SetFailed ( It . IsAny < Exception > ( ) ) , Times . Once ) ;
341367
368+ // Verify GetNextDownloadedFileAsync throws an exception
369+ await Assert . ThrowsAsync < AdbcException > ( ( ) => downloader . GetNextDownloadedFileAsync ( CancellationToken . None ) ) ;
370+
342371 // Verify the downloader has an error
343372 Assert . True ( downloader . HasError ) ;
344373 Assert . NotNull ( downloader . Error ) ;
345374
346- // Verify GetNextDownloadedFileAsync throws an exception
347- await Assert . ThrowsAsync < AdbcException > ( ( ) => downloader . GetNextDownloadedFileAsync ( CancellationToken . None ) ) ;
375+ // Cleanup with timeout and verify task completed
376+ var stopTask = downloader . StopAsync ( ) ;
377+ var completedTask = await Task . WhenAny ( stopTask , Task . Delay ( 2000 ) ) ;
378+ Assert . Same ( stopTask , completedTask ) ; // Ensure that StopAsync completed before the timeout
348379
349- // Cleanup
350- await downloader . StopAsync ( ) ;
351380 }
352381
353382 [ Fact ]
@@ -680,5 +709,12 @@ private static Mock<HttpMessageHandler> CreateMockHttpMessageHandler(
680709
681710 return mockHandler ;
682711 }
712+
713+ public void Dispose ( )
714+ {
715+ _activityTrace ? . Dispose ( ) ;
716+ _downloadQueue ? . Dispose ( ) ;
717+ _resultQueue ? . Dispose ( ) ;
718+ }
683719 }
684720}
0 commit comments