@@ -484,40 +484,29 @@ def test_distributed_interleaved_checkpointing(self):
484484 """
485485 rank = dist .get_rank ()
486486
487- # Create shared temp directory (only rank 0 creates it)
488- if rank == 0 :
489- temp_dir = tempfile .mkdtemp (prefix = "interleaved_test_" )
490- else :
491- temp_dir = None
492-
493- # Broadcast temp directory to all ranks
494- temp_dir_list = [temp_dir ] if temp_dir is not None else ["" ]
495- dist .broadcast_object_list (temp_dir_list , src = 0 )
496- temp_dir = temp_dir_list [0 ]
487+ # Each rank creates its own local temp dir and files (no broadcast/barrier needed for creation)
488+ temp_dir = tempfile .mkdtemp (prefix = f"interleaved_test_rank{ rank } _" )
497489 tmp_path = Path (temp_dir )
498490
499491 try :
500492 # ============================================
501- # SETUP: Create test files ONCE at the start
493+ # SETUP: Each rank creates its own test files
502494 # ============================================
503495 file1 = tmp_path / "ds1.json"
504496 file2 = tmp_path / "ds2.json"
505497 file3 = tmp_path / "ds3.json"
506498
507- # Only rank 0 creates the data files
508- if rank == 0 :
509- create_test_json_file (file1 , SMALL_DATASET_SIZE , offset = 0 )
510- create_test_json_file (file2 , MEDIUM_DATASET_SIZE , offset = 100 )
511- create_test_json_file (file3 , LARGE_DATASET_SIZE , offset = 1000 )
499+ create_test_json_file (file1 , SMALL_DATASET_SIZE , offset = 0 )
500+ create_test_json_file (file2 , MEDIUM_DATASET_SIZE , offset = 100 )
501+ create_test_json_file (file3 , LARGE_DATASET_SIZE , offset = 1000 )
512502
513- # Wait for all ranks to reach this point
514- dist .barrier ()
503+ # No barrier needed since files are local to each rank
515504
516505 # ============================================
517506 # TEST LOGIC: Functions that use the files
518507 # ============================================
519- def create_dataset ():
520- """Create interleaved dataset from pre-created files."""
508+ def create_dataset () -> InterleavedDataset :
509+ """Create interleaved dataset from local files."""
521510 ds1 = HfIterableDataset (
522511 path = "json" ,
523512 data_files = str (file1 ),
@@ -631,6 +620,5 @@ def create_dataloader(dataset):
631620 ), f"ds3 ratio { ds3_ratio :.2f} should be ~{ expected_ds3_ratio } "
632621
633622 finally :
634- # Clean up temp directory (only rank 0)
635- if rank == 0 :
636- shutil .rmtree (temp_dir )
623+ # Each rank cleans its own temp dir
624+ shutil .rmtree (temp_dir )