Skip to content

Commit 47ed901

Browse files
author
Felipe Mello
committed
will it work?
1 parent cba0257 commit 47ed901

File tree

2 files changed

+16
-40
lines changed

2 files changed

+16
-40
lines changed

tests/unit_tests/datasets/test_hf.py

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -287,25 +287,15 @@ def test_distributed_epoch_boundary_checkpointing(self):
287287
"""
288288
rank = dist.get_rank()
289289

290-
# Create shared temp directory (only rank 0 creates it)
291-
if rank == 0:
292-
temp_dir = tempfile.mkdtemp(prefix="epoch_test_")
293-
else:
294-
temp_dir = ""
295-
296-
# Broadcast temp directory path to all ranks
297-
temp_dir_list = [temp_dir]
298-
dist.broadcast_object_list(temp_dir_list, src=0)
299-
temp_dir = temp_dir_list[0]
290+
# Each rank creates its own local temp dir and files
291+
temp_dir = tempfile.mkdtemp(prefix=f"epoch_test_rank{rank}_")
300292
tmp_path = Path(temp_dir)
301293

302294
try:
303295
medium_dataset_file = tmp_path / "medium_data.json"
304296

305-
# Only rank 0 creates the data file, all ranks read from it
306-
if rank == 0:
307-
create_test_json_file(medium_dataset_file, MEDIUM_DATASET_SIZE)
308-
dist.barrier() # Wait for file creation
297+
# Each rank creates its own file
298+
create_test_json_file(medium_dataset_file, MEDIUM_DATASET_SIZE)
309299

310300
# Test multiple epoch boundaries
311301
for num_epochs in [0.9, 1.0, 2.5]:
@@ -373,6 +363,4 @@ def create_loader():
373363
), f"Epoch count incorrect for {num_epochs} epochs test scenario"
374364

375365
finally:
376-
# Clean up temp directory (only rank 0)
377-
if rank == 0:
378-
shutil.rmtree(temp_dir)
366+
shutil.rmtree(temp_dir)

tests/unit_tests/datasets/test_interleaved.py

Lines changed: 11 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -484,40 +484,29 @@ def test_distributed_interleaved_checkpointing(self):
484484
"""
485485
rank = dist.get_rank()
486486

487-
# Create shared temp directory (only rank 0 creates it)
488-
if rank == 0:
489-
temp_dir = tempfile.mkdtemp(prefix="interleaved_test_")
490-
else:
491-
temp_dir = None
492-
493-
# Broadcast temp directory to all ranks
494-
temp_dir_list = [temp_dir] if temp_dir is not None else [""]
495-
dist.broadcast_object_list(temp_dir_list, src=0)
496-
temp_dir = temp_dir_list[0]
487+
# Each rank creates its own local temp dir and files (no broadcast/barrier needed for creation)
488+
temp_dir = tempfile.mkdtemp(prefix=f"interleaved_test_rank{rank}_")
497489
tmp_path = Path(temp_dir)
498490

499491
try:
500492
# ============================================
501-
# SETUP: Create test files ONCE at the start
493+
# SETUP: Each rank creates its own test files
502494
# ============================================
503495
file1 = tmp_path / "ds1.json"
504496
file2 = tmp_path / "ds2.json"
505497
file3 = tmp_path / "ds3.json"
506498

507-
# Only rank 0 creates the data files
508-
if rank == 0:
509-
create_test_json_file(file1, SMALL_DATASET_SIZE, offset=0)
510-
create_test_json_file(file2, MEDIUM_DATASET_SIZE, offset=100)
511-
create_test_json_file(file3, LARGE_DATASET_SIZE, offset=1000)
499+
create_test_json_file(file1, SMALL_DATASET_SIZE, offset=0)
500+
create_test_json_file(file2, MEDIUM_DATASET_SIZE, offset=100)
501+
create_test_json_file(file3, LARGE_DATASET_SIZE, offset=1000)
512502

513-
# Wait for all ranks to reach this point
514-
dist.barrier()
503+
# No barrier needed since files are local to each rank
515504

516505
# ============================================
517506
# TEST LOGIC: Functions that use the files
518507
# ============================================
519-
def create_dataset():
520-
"""Create interleaved dataset from pre-created files."""
508+
def create_dataset() -> InterleavedDataset:
509+
"""Create interleaved dataset from local files."""
521510
ds1 = HfIterableDataset(
522511
path="json",
523512
data_files=str(file1),
@@ -631,6 +620,5 @@ def create_dataloader(dataset):
631620
), f"ds3 ratio {ds3_ratio:.2f} should be ~{expected_ds3_ratio}"
632621

633622
finally:
634-
# Clean up temp directory (only rank 0)
635-
if rank == 0:
636-
shutil.rmtree(temp_dir)
623+
# Each rank cleans its own temp dir
624+
shutil.rmtree(temp_dir)

0 commit comments

Comments
 (0)