@@ -86,13 +86,8 @@ def test_hide_token(self, serializer):
 
 
 class TestJobDownloadTask:
 
-    # Use a temporary directory for safe file handling
-    @pytest.fixture
-    def temp_dir(self):
-        with tempfile.TemporaryDirectory() as temp_dir:
-            yield Path(temp_dir)
 
-    def test_job_download_success(self, requests_mock: Mocker, temp_dir: Path):
+    def test_job_download_success(self, requests_mock: Mocker, tmp_path: Path):
         """
         Test a successful job download and verify file content and stats update.
         """
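pytest's built-in tmp_path fixture already yields a per-test pathlib.Path to a fresh, automatically cleaned-up temporary directory, so the hand-rolled temp_dir fixture removed above is redundant. A minimal sketch of the equivalence (test names and paths here are illustrative, not the ones in this file):

import tempfile
from pathlib import Path

import pytest


@pytest.fixture
def temp_dir():
    # What the removed fixture did by hand ...
    with tempfile.TemporaryDirectory() as temp_dir:
        yield Path(temp_dir)


def test_with_builtin_fixture(tmp_path: Path):
    # ... and what tmp_path gives for free: a unique, auto-cleaned directory.
    download_dir = tmp_path / "job-123" / "results"
    download_dir.mkdir(parents=True)
    (download_dir / "out.txt").write_bytes(b"content")
    assert (download_dir / "out.txt").read_bytes() == b"content"
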
@@ -107,7 +102,7 @@ def test_job_download_success(self, requests_mock: Mocker, temp_dir: Path):
         backend._set_job_status(job_id=job_id, status="finished")
         backend.batch_jobs[job_id]["status"] = "finished"
 
-        download_dir = temp_dir / job_id / "results"
+        download_dir = tmp_path / job_id / "results"
         download_dir.mkdir(parents=True)
 
         # Create the task instance
@@ -136,7 +131,7 @@ def test_job_download_success(self, requests_mock: Mocker, temp_dir: Path):
         assert downloaded_file.read_bytes() == b"The downloaded file content."
 
 
-    def test_job_download_failure(self, requests_mock: Mocker, temp_dir: Path):
+    def test_job_download_failure(self, requests_mock: Mocker, tmp_path: Path):
         """
         Test a failed download (e.g., bad connection) and verify error reporting.
         """
@@ -156,7 +151,7 @@ def test_job_download_failure(self, requests_mock: Mocker, temp_dir: Path):
         backend._set_job_status(job_id=job_id, status="finished")
         backend.batch_jobs[job_id]["finished"] = "error"
 
-        download_dir = temp_dir / job_id / "results"
+        download_dir = tmp_path / job_id / "results"
         download_dir.mkdir(parents=True)
 
         # Create the task instance
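Neither hunk shows how the HTTP layer is stubbed. With requests-mock, the success case typically registers a bytes payload for the asset URL and the failure case registers an exception instead; both are sketched below with made-up URLs and file names (the actual test bodies are not visible in this diff):

import pytest
import requests
from requests_mock import Mocker


def test_download_payload_stub(requests_mock: Mocker, tmp_path):
    # Success case: the asset URL returns the expected bytes.
    url = "https://oeo.test/jobs/job-123/results/assets/output.tif"
    requests_mock.get(url, content=b"The downloaded file content.")

    target = tmp_path / "output.tif"
    target.write_bytes(requests.get(url).content)
    assert target.read_bytes() == b"The downloaded file content."


def test_download_connection_error(requests_mock: Mocker):
    # Failure case: the same request raises instead of returning a payload.
    url = "https://oeo.test/jobs/job-456/results/assets/output.tif"
    requests_mock.get(url, exc=requests.exceptions.ConnectionError)

    with pytest.raises(requests.exceptions.ConnectionError):
        requests.get(url)
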
@@ -392,14 +387,6 @@ def test_job_start_task_failure(self, worker_pool, dummy_backend, caplog):
         ]
 
 
-import pytest
-import time
-import threading
-import logging
-from typing import Iterator
-
-_log = logging.getLogger(__name__)
-
 
 class TestJobManagerWorkerThreadPool:
     @pytest.fixture
@@ -447,18 +434,17 @@ def test_submit_task_creates_pool(self, thread_pool):
         """Test that submitting a task creates a pool dynamically."""
         task = NopTask(job_id="j-1", df_idx=1)
 
-        # No pools initially
         assert thread_pool.list_pools() == []
 
         # Submit task - should create pool
         thread_pool.submit_task(task)
 
-        # Pool should be created with default workers (1)
+        # Pool should be created
         assert thread_pool.list_pools() == ["NopTask"]
         assert "NopTask" in thread_pool._pools
 
         # Process to complete the task
-        results, remaining = thread_pool.process_all_updates(timeout=0.1)
+        results, remaining = thread_pool.process_futures(timeout=0.1)
         assert len(results) == 1
         assert results[0].job_id == "j-1"
         assert remaining == {"NopTask": 0}
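The tests in this class pin down a fairly small contract: submit_task() lazily creates one named pool per task class, list_pools() and num_pending_tasks() report on them, and process_futures(timeout) returns the finished results plus a per-pool count of what is still pending. A minimal stand-in that satisfies those assertions (the real _JobManagerWorkerThreadPool in the openeo code base may be organised differently; names and defaults below are assumptions):

import concurrent.futures
from typing import Dict, List, Optional, Tuple


class MiniWorkerPool:
    """Illustrative sketch only: one ThreadPoolExecutor per task class."""

    def __init__(self, max_workers_per_pool: int = 1):
        self._pools: Dict[str, concurrent.futures.ThreadPoolExecutor] = {}
        self._futures: Dict[str, list] = {}
        self._max_workers = max_workers_per_pool

    def submit_task(self, task) -> None:
        # Pools are created lazily, keyed by the task's class name.
        name = type(task).__name__
        if name not in self._pools:
            self._pools[name] = concurrent.futures.ThreadPoolExecutor(max_workers=self._max_workers)
            self._futures[name] = []
        self._futures[name].append(self._pools[name].submit(task.execute))

    def list_pools(self) -> List[str]:
        return list(self._pools)

    def num_pending_tasks(self, name: Optional[str] = None) -> int:
        if name is not None:
            return len(self._futures.get(name, []))
        return sum(len(fs) for fs in self._futures.values())

    def process_futures(self, timeout: float = 0) -> Tuple[list, Dict[str, int]]:
        # Collect whatever finished within the timeout and report what is left per pool.
        results, remaining = [], {}
        for name, futures in self._futures.items():
            done, pending = concurrent.futures.wait(futures, timeout=timeout)
            results.extend(f.result() for f in done)
            self._futures[name] = list(pending)
            remaining[name] = len(pending)
        return results, remaining

    def shutdown(self) -> None:
        for pool in self._pools.values():
            pool.shutdown(wait=True)
        self._pools.clear()
        self._futures.clear()
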
@@ -471,7 +457,6 @@ def test_submit_task_uses_config(self, configured_pool):
         configured_pool.submit_task(task)
 
         assert "NopTask" in configured_pool._pools
-        # Can't directly check max_workers, but pool should exist
         assert "NopTask" in configured_pool.list_pools()
 
     def test_submit_multiple_task_types(self, thread_pool):
@@ -495,25 +480,23 @@ def test_submit_multiple_task_types(self, thread_pool):
         assert thread_pool.num_pending_tasks("DummyTask") == 1
         assert thread_pool.num_pending_tasks("NonExistent") == 0
 
-    def test_process_all_updates_empty(self, thread_pool):
-        """Test processing updates with no pools."""
-        results, remaining = thread_pool.process_all_updates(timeout=0)
+    def test_process_futures_updates_empty(self, thread_pool):
+        """Test process futures with no pools."""
+        results, remaining = thread_pool.process_futures(timeout=0)
         assert results == []
         assert remaining == {}
 
-    def test_process_all_updates_multiple_pools(self, thread_pool):
+    def test_process_futures_updates_multiple_pools(self, thread_pool):
         """Test processing updates across multiple pools."""
         # Submit tasks to different pools
         thread_pool.submit_task(NopTask(job_id="j-1", df_idx=1))  # NopTask pool
         thread_pool.submit_task(NopTask(job_id="j-2", df_idx=2))  # NopTask pool
         thread_pool.submit_task(DummyTask(job_id="j-3", df_idx=3))  # DummyTask pool
 
-        # Process updates
-        results, remaining = thread_pool.process_all_updates(timeout=0.1)
+        results, remaining = thread_pool.process_futures(timeout=0.1)
 
-        # Should get 3 results
         assert len(results) == 3
-        # Check results by pool
+
         nop_results = [r for r in results if r.job_id in ["j-1", "j-2"]]
         dummy_results = [r for r in results if r.job_id == "j-3"]
         assert len(nop_results) == 2
@@ -522,7 +505,7 @@ def test_process_all_updates_multiple_pools(self, thread_pool):
         # All tasks should be completed
         assert remaining == {"NopTask": 0, "DummyTask": 0}
 
-    def test_process_all_updates_partial_completion(self):
+    def test_process_futures_updates_partial_completion(self):
         """Test processing when some tasks are still running."""
         # Use a pool with blocking tasks
         pool = _JobManagerWorkerThreadPool()
@@ -538,7 +521,7 @@ def test_process_all_updates_partial_completion(self):
         pool.submit_task(quick_task)  # NopTask pool
 
         # Process with timeout=0 - only quick task should complete
-        results, remaining = pool.process_all_updates(timeout=0)
+        results, remaining = pool.process_futures(timeout=0)
 
         # Only quick task completed
         assert len(results) == 1
@@ -551,7 +534,7 @@ def test_process_all_updates_partial_completion(self):
 
         # Release blocking task and process again
         event.set()
-        results2, remaining2 = pool.process_all_updates(timeout=0.1)
+        results2, remaining2 = pool.process_futures(timeout=0.1)
 
         assert len(results2) == 1
         assert results2[0].job_id == "j-block"
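The BlockingTask used in the partial-completion and parallelism tests is defined elsewhere in this test module. Given that the tests release it with event.set(), a plausible shape is a frozen dataclass whose execute() blocks on a threading.Event (field names here are assumptions):

import threading
from dataclasses import dataclass


@dataclass(frozen=True)
class BlockingTaskSketch:
    job_id: str
    df_idx: int
    event: threading.Event

    def execute(self):
        # Hold the worker thread until the test fires the event, so this task
        # keeps showing up as "remaining" in process_futures() until then.
        self.event.wait()
        return (self.job_id, self.df_idx)
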
@@ -577,7 +560,7 @@ def test_num_pending_tasks(self, thread_pool):
         assert thread_pool.num_pending_tasks("NonExistentPool") == 0
 
         # Process all
-        thread_pool.process_all_updates(timeout=0.1)
+        thread_pool.process_futures(timeout=0.1)
 
         # Should be empty
         assert thread_pool.num_pending_tasks() == 0
@@ -620,28 +603,20 @@ def test_shutdown_all(self):
 
         # Shutdown all
         pool.shutdown()
-
-        # All pools should be gone
+
         assert pool.list_pools() == []
-
-        # Can't submit any more tasks after shutdown
-        # Actually, shutdown() doesn't prevent creating new pools
-        # So we can test that shutdown clears existing pools
         assert len(pool._pools) == 0
 
     def test_custom_get_pool_name(self):
         """Test custom task class to verify pool name selection."""
 
         @dataclass(frozen=True)
-        class CustomTask(Task):
-            # Fields are inherited from Task: job_id, df_idx
-
+        class CustomTask(Task):
             def execute(self) -> _TaskResult:
                 return _TaskResult(job_id=self.job_id, df_idx=self.df_idx)
 
         pool = _JobManagerWorkerThreadPool()
 
-        # Submit custom task - must provide all required fields
         task = CustomTask(job_id="j-1", df_idx=1)
         pool.submit_task(task)
 
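The pool names asserted throughout ("NopTask", "DummyTask", "CustomTask") match the task class names, so pool selection presumably defaults to the class name. One way such a hook could look (hypothetical; the real Task base class may expose it differently):

class TaskNamingSketch:
    @classmethod
    def get_pool_name(cls) -> str:
        # Default: one worker pool per task class, keyed by class name.
        return cls.__name__


class CustomTaskSketch(TaskNamingSketch):
    pass


assert CustomTaskSketch.get_pool_name() == "CustomTaskSketch"
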
@@ -650,7 +625,7 @@ def execute(self) -> _TaskResult:
         assert pool.num_pending_tasks() == 1
 
         # Process it
-        results, remaining = pool.process_all_updates(timeout=0.1)
+        results, _ = pool.process_futures(timeout=0.1)
         assert len(results) == 1
         assert results[0].job_id == "j-1"
 
@@ -674,7 +649,7 @@ def submit_tasks(start_idx: int):
         assert thread_pool.num_pending_tasks() == 15
 
         # Process them all
-        results, remaining = thread_pool.process_all_updates(timeout=0.5)
+        results, remaining = thread_pool.process_futures(timeout=0.5)
 
         assert len(results) == 15
         assert remaining == {"NopTask": 0}
@@ -687,7 +662,6 @@ def test_pool_parallelism_with_blocking_tasks(self):
 
         # Create multiple blocking tasks
         events = [threading.Event() for _ in range(5)]
-        start_time = time.time()
 
         for i, event in enumerate(events):
             pool.submit_task(BlockingTask(
@@ -704,14 +678,10 @@ def test_pool_parallelism_with_blocking_tasks(self):
         for event in events:
             event.set()
 
-        # Process with timeout - all should complete
-        results, remaining = pool.process_all_updates(timeout=0.5)
-
-        # All should complete (if pool had enough workers)
+        results, remaining = pool.process_futures(timeout=0.5)
         assert len(results) == 5
         assert remaining == {"BlockingTask": 0}
 
-        # Check they all completed
         for result in results:
             assert result.job_id.startswith("j-block-")
 
@@ -723,7 +693,7 @@ def test_task_with_error_handling(self, thread_pool):
         thread_pool.submit_task(DummyTask(job_id="j-666", df_idx=0))
 
         # Process it
-        results, remaining = thread_pool.process_all_updates(timeout=0.1)
+        results, remaining = thread_pool.process_futures(timeout=0.1)
 
         # Should get error result
         assert len(results) == 1
@@ -741,7 +711,7 @@ def test_mixed_success_and_error_tasks(self, thread_pool):
         thread_pool.submit_task(DummyTask(job_id="j-3", df_idx=3))  # Success
 
         # Process all
-        results, remaining = thread_pool.process_all_updates(timeout=0.1)
+        results, remaining = thread_pool.process_futures(timeout=0.1)
 
         # Should get 3 results
         assert len(results) == 3
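The error-handling tests expect that a failing DummyTask (job id "j-666") still comes back as a result rather than as an exception, i.e. the worker catches errors from execute() and reports them through process_futures(). A sketch of that pattern (the field names are assumptions, not the real _TaskResult API):

from dataclasses import dataclass
from typing import Optional


@dataclass
class ResultSketch:
    job_id: str
    df_idx: int
    error: Optional[str] = None


def run_task_safely(task) -> ResultSketch:
    # Exceptions raised in the worker thread become error results instead of
    # propagating out of process_futures().
    try:
        return task.execute()
    except Exception as e:
        return ResultSketch(job_id=task.job_id, df_idx=task.df_idx, error=str(e))
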