@@ -497,29 +497,33 @@ def test_distributed_interleaved_checkpointing(self):
497497 tmp_path = Path (temp_dir )
498498
499499 try :
500+ # ============================================
501+ # SETUP: Create test files ONCE at the start
502+ # ============================================
503+ file1 = tmp_path / "ds1.json"
504+ file2 = tmp_path / "ds2.json"
505+ file3 = tmp_path / "ds3.json"
506+
507+ # Only rank 0 creates the data files
508+ if rank == 0 :
509+ create_test_json_file (file1 , SMALL_DATASET_SIZE , offset = 0 )
510+ create_test_json_file (file2 , MEDIUM_DATASET_SIZE , offset = 100 )
511+ create_test_json_file (file3 , LARGE_DATASET_SIZE , offset = 1000 )
500512
501- def create_dataset ():
502- file1 = tmp_path / "ds1.json"
503- file2 = tmp_path / "ds2.json"
504- file3 = tmp_path / "ds3.json"
505-
506- # Only rank 0 creates the data files
507- if rank == 0 :
508- create_test_json_file (file1 , SMALL_DATASET_SIZE ) # IDs 0-22
509- create_test_json_file (
510- file2 , MEDIUM_DATASET_SIZE , offset = 100
511- ) # IDs 100-134
512- create_test_json_file (
513- file3 , LARGE_DATASET_SIZE , offset = 1000
514- ) # IDs 1000-1046
515- dist .barrier () # Wait for file creation
513+ # Wait until rank 0 has finished writing the data files before any rank reads them
514+ dist .barrier ()
516515
516+ # ============================================
517+ # TEST LOGIC: Functions that use the files
518+ # ============================================
519+ def create_dataset ():
520+ """Create interleaved dataset from pre-created files."""
517521 ds1 = HfIterableDataset (
518522 path = "json" ,
519523 data_files = str (file1 ),
520524 split = "train" ,
521525 dataset_name = "ds1" ,
522- shuffle_buffer_size = 0 , # No shuffle for determinism
526+ shuffle_buffer_size = 0 ,
523527 metric_transform = DefaultDatasetMetricTransform (),
524528 num_shards_per_rank = 2 ,
525529 weight = 0.3 ,
@@ -529,7 +533,7 @@ def create_dataset():
529533 data_files = str (file2 ),
530534 split = "train" ,
531535 dataset_name = "ds2" ,
532- shuffle_buffer_size = 0 , # No shuffle for determinism
536+ shuffle_buffer_size = 0 ,
533537 metric_transform = DefaultDatasetMetricTransform (),
534538 num_shards_per_rank = 2 ,
535539 weight = 0.7 ,
@@ -539,7 +543,7 @@ def create_dataset():
539543 data_files = str (file3 ),
540544 split = "train" ,
541545 dataset_name = "ds3" ,
542- shuffle_buffer_size = 0 , # No shuffle for determinism
546+ shuffle_buffer_size = 0 ,
543547 metric_transform = DefaultDatasetMetricTransform (),
544548 num_shards_per_rank = 2 ,
545549 weight = 1.0 ,
0 commit comments