Skip to content

Commit cba0257

Browse files
author
Felipe Mello
committed
fix race condition
1 parent e9b736a commit cba0257

File tree

1 file changed

+22
-18
lines changed

tests/unit_tests/datasets/test_interleaved.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -497,29 +497,33 @@ def test_distributed_interleaved_checkpointing(self):
497497
tmp_path = Path(temp_dir)
498498

499499
try:
500+
# ============================================
501+
# SETUP: Create test files ONCE at the start
502+
# ============================================
503+
file1 = tmp_path / "ds1.json"
504+
file2 = tmp_path / "ds2.json"
505+
file3 = tmp_path / "ds3.json"
506+
507+
# Only rank 0 creates the data files
508+
if rank == 0:
509+
create_test_json_file(file1, SMALL_DATASET_SIZE, offset=0)
510+
create_test_json_file(file2, MEDIUM_DATASET_SIZE, offset=100)
511+
create_test_json_file(file3, LARGE_DATASET_SIZE, offset=1000)
500512

501-
def create_dataset():
502-
file1 = tmp_path / "ds1.json"
503-
file2 = tmp_path / "ds2.json"
504-
file3 = tmp_path / "ds3.json"
505-
506-
# Only rank 0 creates the data files
507-
if rank == 0:
508-
create_test_json_file(file1, SMALL_DATASET_SIZE) # IDs 0-22
509-
create_test_json_file(
510-
file2, MEDIUM_DATASET_SIZE, offset=100
511-
) # IDs 100-134
512-
create_test_json_file(
513-
file3, LARGE_DATASET_SIZE, offset=1000
514-
) # IDs 1000-1046
515-
dist.barrier() # Wait for file creation
513+
# Wait for all ranks to reach this point
514+
dist.barrier()
516515

516+
# ============================================
517+
# TEST LOGIC: Functions that use the files
518+
# ============================================
519+
def create_dataset():
520+
"""Create interleaved dataset from pre-created files."""
517521
ds1 = HfIterableDataset(
518522
path="json",
519523
data_files=str(file1),
520524
split="train",
521525
dataset_name="ds1",
522-
shuffle_buffer_size=0, # No shuffle for determinism
526+
shuffle_buffer_size=0,
523527
metric_transform=DefaultDatasetMetricTransform(),
524528
num_shards_per_rank=2,
525529
weight=0.3,
@@ -529,7 +533,7 @@ def create_dataset():
529533
data_files=str(file2),
530534
split="train",
531535
dataset_name="ds2",
532-
shuffle_buffer_size=0, # No shuffle for determinism
536+
shuffle_buffer_size=0,
533537
metric_transform=DefaultDatasetMetricTransform(),
534538
num_shards_per_rank=2,
535539
weight=0.7,
@@ -539,7 +543,7 @@ def create_dataset():
539543
data_files=str(file3),
540544
split="train",
541545
dataset_name="ds3",
542-
shuffle_buffer_size=0, # No shuffle for determinism
546+
shuffle_buffer_size=0,
543547
metric_transform=DefaultDatasetMetricTransform(),
544548
num_shards_per_rank=2,
545549
weight=1.0,

0 commit comments

Comments (0)