@@ -97,55 +97,108 @@ def get_chunks_and_intervals_per_ranks(self, distributed_env: _DistributedEnv, c
97
97
# 2. Shuffle them
98
98
indexes = range (len (chunk_intervals ))
99
99
100
- # FIXME: Shuffling should be done only within the nodes to benefit
101
- # from cache if the dataset doesn't fit on the node.
102
- shuffled_indexes = np .random .RandomState (seed = self .seed + current_epoch ).permutation (indexes )
103
- shuffled_chunk_intervals = np .asarray (chunk_intervals )[shuffled_indexes ]
100
+ # If we have multiple nodes, the seed_shift is constant here.
101
+ # Here is why. When you are running epoch 1, we need to shuffle the chunks
102
+ # and associate to each rank. This is done there.
103
+ # When you are running epoch 2 or more, we need to keep the same shuffling
104
+ # as in epoch 1 because we shuffle a second time within the node.
105
+ # This is done slightly further down in this function.
106
+ seed_shift = 1 if distributed_env .num_nodes > 1 else current_epoch
107
+ shuffled_indexes = np .random .RandomState (seed = self .seed + seed_shift ).permutation (indexes )
108
+ shuffled_chunk_intervals = np .asarray (chunk_intervals )[shuffled_indexes ].tolist ()
104
109
105
110
# 3. Compute the items budget of each rank
106
- num_items = sum ([(interval [- 1 ] - interval [0 ]) for interval in chunk_intervals ])
107
- num_items_per_ranks : List [int ] = [
108
- num_items // distributed_env .world_size + num_items % distributed_env .world_size
109
- if rank == distributed_env .world_size - 1 and not self .drop_last
110
- else num_items // distributed_env .world_size
111
- for rank in range (distributed_env .world_size )
112
- ]
113
- chunks_per_ranks : List [List [int ]] = [[] for _ in range (distributed_env .world_size )]
114
- intervals_per_ranks : List [List [List [int ]]] = [[] for _ in range (distributed_env .world_size )]
115
-
116
- # 4. Assign the chunk & intervals to each rank
117
- for chunk_index , chunk_interval in zip (shuffled_indexes , shuffled_chunk_intervals ):
118
- rank = 0
119
-
120
- while True :
121
- if rank == len (num_items_per_ranks ):
122
- break
123
-
124
- items_left_to_assign = num_items_per_ranks [rank ]
111
+ chunks_per_ranks , intervals_per_ranks = _associate_chunks_and_internals_to_ranks (
112
+ distributed_env , shuffled_indexes , shuffled_chunk_intervals , self .drop_last
113
+ )
125
114
126
- if items_left_to_assign == 0 :
127
- rank += 1
128
- continue
115
+ # For the first epoch, no need of further shuffling
116
+ if current_epoch == 1 or distributed_env . num_nodes == 1 :
117
+ return chunks_per_ranks , intervals_per_ranks
129
118
130
- items_in_chunk = chunk_interval [- 1 ] - chunk_interval [0 ]
119
+ # Perform shuffle within the nodes to avoid cache miss.
120
+ # Note: It is possible for the overlapping chunks to change due to the changing order.
121
+ shuffled_indexes = _intra_node_chunk_shuffle (distributed_env , chunks_per_ranks , self .seed , current_epoch )
122
+ shuffled_chunk_intervals = np .asarray (chunk_intervals )[shuffled_indexes ].tolist ()
131
123
132
- if items_in_chunk == 0 :
133
- break
134
-
135
- if items_in_chunk > items_left_to_assign :
136
- chunks_per_ranks [rank ].append (chunk_index )
137
- begin , end = chunk_interval
138
- intervals_per_ranks [rank ].append ([begin , begin + items_left_to_assign ])
139
- chunk_interval = (begin + items_left_to_assign , end )
140
- num_items_per_ranks [rank ] = 0
141
- rank += 1
142
- else :
143
- chunks_per_ranks [rank ].append (chunk_index )
144
- intervals_per_ranks [rank ].append (chunk_interval )
145
- num_items_per_ranks [rank ] -= items_in_chunk
146
- break
124
+ chunks_per_ranks , intervals_per_ranks = _associate_chunks_and_internals_to_ranks (
125
+ distributed_env , shuffled_indexes , shuffled_chunk_intervals , self .drop_last
126
+ )
147
127
148
128
return chunks_per_ranks , intervals_per_ranks
149
129
150
130
def __call__(self, array: np.ndarray, num_chunks: int, current_epoch: int, chunk_index: int) -> List[int]:
    """Return a seeded permutation of ``array`` as a plain Python list.

    The random state is seeded with the triple
    ``[self.seed, num_chunks * current_epoch, chunk_index]``, so identical
    arguments always produce the same ordering.
    """
    rng = np.random.RandomState([self.seed, num_chunks * current_epoch, chunk_index])
    permuted = rng.permutation(array)
    return permuted.tolist()
132
+
133
+
134
def _intra_node_chunk_shuffle(
    distributed_env: _DistributedEnv,
    chunks_per_ranks: List[List[int]],
    seed: int,
    current_epoch: int,
) -> List[int]:
    """Shuffle chunk indexes within each node and return them flattened.

    The chunks assigned to the ranks of a node are pooled, permuted with a
    random state seeded by ``seed + current_epoch``, and the per-node results
    are concatenated in node order. Chunks therefore never move to another
    node between epochs.

    Args:
        distributed_env: Provides ``world_size`` and ``num_nodes``.
        chunks_per_ranks: Chunk indexes assigned to each rank, indexed by rank.
        seed: Base seed of the shuffle.
        current_epoch: Epoch number, added to ``seed``.

    Returns:
        The chunk indexes of node 0, then node 1, and so on, each group
        permuted, as plain Python ints.
    """
    num_nodes = distributed_env.num_nodes
    # Ranks are grouped contiguously per node, so the node of a rank is
    # ``rank // ranks_per_node``. Dividing by ``num_nodes`` (as was done
    # before) is only right when world_size == num_nodes ** 2 and otherwise
    # mis-groups ranks or indexes past the last node.
    ranks_per_node = max(1, distributed_env.world_size // num_nodes)

    chunk_indexes_per_nodes: List[List[int]] = [[] for _ in range(num_nodes)]
    for rank, chunks_per_rank in enumerate(chunks_per_ranks):
        # Clamp so that leftover ranks of an uneven split land on the last node.
        node = min(rank // ranks_per_node, num_nodes - 1)
        chunk_indexes_per_nodes[node].extend(chunks_per_rank)

    # Every node uses the same seed, as before; only the grouping changed.
    node_seed = seed + current_epoch
    shuffled_indexes: List[int] = []
    for node_chunks in chunk_indexes_per_nodes:
        permuted = np.random.RandomState(seed=node_seed).permutation(node_chunks)
        # ``tolist`` yields Python ints, matching the declared return type.
        shuffled_indexes.extend(permuted.tolist())
    return shuffled_indexes
154
+
155
+
156
def _associate_chunks_and_internals_to_ranks(
    distributed_env: _DistributedEnv,
    indexes: Any,
    chunk_intervals: Any,
    drop_last: bool,
) -> Tuple[List[List[int]], List[Any]]:
    """Distribute chunks, in the given order, across the ranks.

    Every rank gets a budget of ``num_items // world_size`` items; unless
    ``drop_last`` is set, the last rank also receives the remainder. Chunks
    are consumed in order and a chunk that exceeds the remaining budget of a
    rank is split, with the rest going to the following rank(s). Items that
    fit in no budget are dropped.

    Args:
        distributed_env: Provides ``world_size``.
        indexes: Chunk indexes, aligned with ``chunk_intervals``.
        chunk_intervals: ``[begin, end]`` pair for each chunk.
        drop_last: When true, the remainder is not given to the last rank.

    Returns:
        ``(chunks_per_ranks, intervals_per_ranks)``: for each rank, the chunk
        indexes it reads and the matching ``[begin, end]`` intervals. All
        intervals are lists, including the tail of a split chunk.
    """
    world_size = distributed_env.world_size

    # 3. Compute the items budget of each rank
    num_items = sum(interval[-1] - interval[0] for interval in chunk_intervals)
    items_per_rank, remainder = divmod(num_items, world_size)
    num_items_per_ranks: List[int] = [items_per_rank] * world_size
    if not drop_last:
        num_items_per_ranks[-1] += remainder

    chunks_per_ranks: List[List[int]] = [[] for _ in range(world_size)]
    intervals_per_ranks: List[List[List[int]]] = [[] for _ in range(world_size)]

    # 4. Assign the chunk & intervals to each rank
    # Budgets only shrink and ranks are filled in order, so a rank that hit
    # zero never becomes assignable again: the cursor can persist across
    # chunks instead of rescanning from rank 0 each time.
    rank = 0
    for chunk_index, chunk_interval in zip(indexes, chunk_intervals):
        begin, end = chunk_interval[0], chunk_interval[-1]

        while rank < world_size:
            items_left_to_assign = num_items_per_ranks[rank]
            if items_left_to_assign == 0:
                rank += 1
                continue

            items_in_chunk = end - begin
            if items_in_chunk == 0:
                break

            chunks_per_ranks[rank].append(chunk_index)
            if items_in_chunk > items_left_to_assign:
                # Fill this rank and carry the rest of the chunk to the next one.
                intervals_per_ranks[rank].append([begin, begin + items_left_to_assign])
                begin += items_left_to_assign
                num_items_per_ranks[rank] = 0
                rank += 1
            else:
                # Always store a list so a split tail matches the other intervals.
                intervals_per_ranks[rank].append([begin, end])
                num_items_per_ranks[rank] -= items_in_chunk
                break

    return chunks_per_ranks, intervals_per_ranks
0 commit comments