Skip to content

Commit 0574ff0

Browse files
cauyxy authored and lexierule committed
bugfix: correct node rank (#19437)
(cherry picked from commit 7b867c7)
1 parent 4284762 commit 0574ff0

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

src/lightning/data/utilities/shuffle.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -12,8 +12,9 @@ def _intra_node_chunk_shuffle(
1212
current_epoch: int,
1313
) -> List[int]:
1414
chunk_indexes_per_nodes: Any = [[] for _ in range(distributed_env.num_nodes)]
15+
process_per_node = distributed_env.world_size // distributed_env.num_nodes
1516
for rank, chunks_per_rank in enumerate(chunks_per_ranks):
16-
chunk_indexes_per_nodes[0 if distributed_env.num_nodes == 1 else rank // distributed_env.num_nodes].extend(
17+
chunk_indexes_per_nodes[0 if distributed_env.num_nodes == 1 else rank // process_per_node].extend(
1718
chunks_per_rank
1819
)
1920

tests/tests_data/streaming/test_shuffle.py renamed to tests/tests_data/utilities/test_shuffle.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,11 @@ def test_intra_node_chunk_shuffle():
1212
assert shuffled_indexes == [3, 2, 1, 0, 7, 6, 5, 4]
1313

1414

15+
chunks_per_ranks = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13], [14, 15]]
16+
shuffled_indexes = _intra_node_chunk_shuffle(_DistributedEnv(8, 7, 2), chunks_per_ranks, 42, 2)
17+
assert shuffled_indexes == [5, 2, 0, 7, 6, 1, 3, 4, 13, 10, 8, 15, 14, 9, 11, 12]
18+
19+
1520
def test_associate_chunks_and_internals_to_ranks():
1621
indexes = [0, 1, 2, 3, 4, 5, 6, 7]
1722
chunk_intervals = [[0, 50], [0, 50], [0, 50], [0, 50], [0, 50], [0, 50], [0, 50], [0, 50]]

0 commit comments

Comments
 (0)