|
19 | 19 |
|
20 | 20 | from lightning.data.streaming import Cache |
21 | 21 | from lightning.data.utilities.env import _DistributedEnv |
| 22 | +from lightning.data.utilities.shuffle import _associate_chunks_and_internals_to_ranks, _intra_node_chunk_shuffle |
22 | 23 |
|
23 | 24 |
|
24 | 25 | class Shuffle(ABC): |
@@ -129,76 +130,3 @@ def get_chunks_and_intervals_per_ranks(self, distributed_env: _DistributedEnv, c |
129 | 130 |
|
130 | 131 | def __call__(self, array: np.ndarray, num_chunks: int, current_epoch: int, chunk_index: int) -> List[int]: |
131 | 132 | return np.random.RandomState([self.seed, num_chunks * current_epoch, chunk_index]).permutation(array).tolist() |
132 | | - |
133 | | - |
134 | | -def _intra_node_chunk_shuffle( |
135 | | - distributed_env: _DistributedEnv, |
136 | | - chunks_per_ranks: List[List[int]], |
137 | | - seed: int, |
138 | | - current_epoch: int, |
139 | | -) -> List[int]: |
140 | | - chunk_indexes_per_nodes: Any = [[] for _ in range(distributed_env.num_nodes)] |
141 | | - for rank, chunks_per_rank in enumerate(chunks_per_ranks): |
142 | | - chunk_indexes_per_nodes[0 if distributed_env.num_nodes == 1 else rank // distributed_env.num_nodes].extend( |
143 | | - chunks_per_rank |
144 | | - ) |
145 | | - |
146 | | - # shuffle the chunks associated to the node |
147 | | - for i in range(len(chunk_indexes_per_nodes)): |
148 | | - # permute the indexes within the node |
149 | | - chunk_indexes_per_nodes[i] = np.random.RandomState(seed=seed + current_epoch).permutation( |
150 | | - chunk_indexes_per_nodes[i] |
151 | | - ) |
152 | | - |
153 | | - return [index for chunks in chunk_indexes_per_nodes for index in chunks] |
154 | | - |
155 | | - |
156 | | -def _associate_chunks_and_internals_to_ranks( |
157 | | - distributed_env: _DistributedEnv, |
158 | | - indexes: Any, |
159 | | - chunk_intervals: Any, |
160 | | - drop_last: bool, |
161 | | -) -> Tuple[List[List[int]], List[Any]]: |
162 | | - num_items = sum([(interval[-1] - interval[0]) for interval in chunk_intervals]) |
163 | | - num_items_per_ranks: List[int] = [ |
164 | | - num_items // distributed_env.world_size + num_items % distributed_env.world_size |
165 | | - if rank == distributed_env.world_size - 1 and not drop_last |
166 | | - else num_items // distributed_env.world_size |
167 | | - for rank in range(distributed_env.world_size) |
168 | | - ] |
169 | | - chunks_per_ranks: List[List[int]] = [[] for _ in range(distributed_env.world_size)] |
170 | | - intervals_per_ranks: List[List[List[int]]] = [[] for _ in range(distributed_env.world_size)] |
171 | | - |
172 | | - # 4. Assign the chunk & intervals to each rank |
173 | | - for chunk_index, chunk_interval in zip(indexes, chunk_intervals): |
174 | | - rank = 0 |
175 | | - |
176 | | - while True: |
177 | | - if rank == len(num_items_per_ranks): |
178 | | - break |
179 | | - |
180 | | - items_left_to_assign = num_items_per_ranks[rank] |
181 | | - |
182 | | - if items_left_to_assign == 0: |
183 | | - rank += 1 |
184 | | - continue |
185 | | - |
186 | | - items_in_chunk = chunk_interval[-1] - chunk_interval[0] |
187 | | - |
188 | | - if items_in_chunk == 0: |
189 | | - break |
190 | | - |
191 | | - if items_in_chunk > items_left_to_assign: |
192 | | - chunks_per_ranks[rank].append(chunk_index) |
193 | | - begin, end = chunk_interval |
194 | | - intervals_per_ranks[rank].append([begin, begin + items_left_to_assign]) |
195 | | - chunk_interval = (begin + items_left_to_assign, end) |
196 | | - num_items_per_ranks[rank] = 0 |
197 | | - rank += 1 |
198 | | - else: |
199 | | - chunks_per_ranks[rank].append(chunk_index) |
200 | | - intervals_per_ranks[rank].append(chunk_interval) |
201 | | - num_items_per_ranks[rank] -= items_in_chunk |
202 | | - break |
203 | | - |
204 | | - return chunks_per_ranks, intervals_per_ranks |
0 commit comments