24
24
class Shuffle(ABC):
    """Shuffle describes how to distribute chunked datasets across processes and workers.

    Concrete strategies implement ``get_chunks_and_intervals_per_ranks`` (which chunks and
    item intervals each rank receives) and ``__call__``; ``get_len`` is derived from the
    per-rank intervals.
    """

    def __init__(self, cache: Cache, seed: int, drop_last: bool):
        """Store the configuration shared by every shuffling strategy.

        Args:
            cache: Object whose ``get_chunk_intervals()`` is queried by the concrete strategies.
            seed: Base seed; the concrete strategies seed their random state with ``seed + current_epoch``.
            drop_last: When True, ``get_len`` reports the smallest per-rank item count so that
                every rank reports the same length.
        """
        self.cache = cache
        self.seed = seed
        self.drop_last = drop_last
        # Assigned by the concrete strategies inside get_chunks_and_intervals_per_ranks.
        self.random_state = None

    @staticmethod
    def _count_items(intervals: Any) -> int:
        """Return the total number of items covered by ``intervals`` (each one is ``[begin, end]``)."""
        return sum(interval[-1] - interval[0] for interval in intervals)

    # NOTE: lru_cache keys on (self, distributed_env, current_epoch), so all three must be hashable.
    @lru_cache(maxsize=10)
    def get_len(self, distributed_env: _DistributedEnv, current_epoch: int) -> int:
        """Return the number of items reported for ``distributed_env`` at ``current_epoch``.

        Args:
            distributed_env: Provides ``global_rank``, used to pick this rank's intervals.
            current_epoch: Epoch forwarded to ``get_chunks_and_intervals_per_ranks``.

        Returns:
            With ``drop_last`` set, the minimum item count over all ranks; otherwise the item
            count of ``distributed_env.global_rank``. Always a built-in ``int``.
        """
        _, intervals_per_ranks = self.get_chunks_and_intervals_per_ranks(distributed_env, current_epoch)

        if self.drop_last:
            # Every rank reports the size of the smallest share.
            return int(min(self._count_items(intervals) for intervals in intervals_per_ranks))

        # int() because numpy-backed intervals would otherwise yield a numpy integer.
        return int(self._count_items(intervals_per_ranks[distributed_env.global_rank]))

    @abstractmethod
    def get_chunks_and_intervals_per_ranks(self, distributed_env: _DistributedEnv, current_epoch: int) -> Any:
        """Return ``(chunks_per_ranks, intervals_per_ranks)`` for the given environment and epoch.

        ``get_len`` indexes the second element by rank and counts ``end - begin`` items for
        every ``[begin, end]`` interval it contains.
        """

    @abstractmethod
    def __call__(self, array: np.ndarray) -> List[int]:
        """Return the list of values to iterate over for ``array`` (strategy-specific ordering)."""
43
53
44
54
45
55
class NoShuffle(Shuffle):
    """NoShuffle keeps the chunks in storage order and deals them to the ranks round-robin.

    When ``drop_last`` is True, the inherited ``get_len`` reports the smallest per-rank item
    count, which ensures all the processes report the same number of items.
    """

    # NOTE: lru_cache keys on (self, distributed_env, current_epoch), so all three must be hashable,
    # and the cached lists are returned as-is on a cache hit.
    @lru_cache(maxsize=10)
    def get_chunks_and_intervals_per_ranks(self, distributed_env: _DistributedEnv, current_epoch: int) -> Any:
        """Assign chunk ``i`` (and its interval) to rank ``i % world_size``.

        Args:
            distributed_env: Provides ``world_size``, the number of ranks to deal chunks to.
            current_epoch: Added to ``self.seed`` to seed ``self.random_state``.

        Returns:
            ``(chunks_per_ranks, intervals_per_ranks)``: two lists with one entry per rank, holding
            that rank's chunk indexes and the matching chunk intervals.
        """
        # Not consumed by NoShuffle.__call__; seeded the same way FullShuffle seeds it.
        self.random_state = np.random.RandomState(seed=self.seed + current_epoch)  # type: ignore
        world_size = distributed_env.world_size
        # np.array always copies, so the returned intervals never alias the cache's own data.
        chunk_intervals = np.array(self.cache.get_chunk_intervals())

        chunks_per_ranks: List[List[int]] = [[] for _ in range(world_size)]
        intervals_per_ranks: List[List[List[int]]] = [[] for _ in range(world_size)]
        for chunk_index, chunk_interval in enumerate(chunk_intervals):
            rank = chunk_index % world_size
            chunks_per_ranks[rank].append(chunk_index)
            intervals_per_ranks[rank].append(chunk_interval)

        return chunks_per_ranks, intervals_per_ranks

    def __call__(self, array: np.ndarray) -> List[int]:
        """Return ``array`` as a list, leaving the order untouched."""
        return array.tolist()
74
77
75
78
76
- class TruncatedShuffle (Shuffle ):
77
- """TruncatedShuffle shuffles the chunks and associates them to the ranks.
78
-
79
- As the number of items in a chunk varies, it is possible for a rank to end up with more or less items.
80
-
81
- To ensure the same fixed dataset length for all ranks, we compute the minimum number of items across all ranks.
82
-
83
- For the ranks with more items than the minimum, the remaining items are dropped.
84
-
85
- Note: This is the fastest sampling strategy but at the cost of losing items.
86
-
87
- """
88
-
89
- @lru_cache (maxsize = 10 )
90
- def get_len (self , distributed_env : _DistributedEnv , current_epoch : int ) -> int :
91
- _ , intervals_per_process = self .get_chunks_and_intervals_per_process (distributed_env , current_epoch )
92
- min_items_per_process = min (
93
- [sum ([(interval [- 1 ] - interval [0 ]) for interval in intervals ]) for intervals in intervals_per_process ]
94
- )
95
- return min_items_per_process
96
-
97
- @lru_cache (maxsize = 10 )
98
- def get_chunks_and_intervals_per_process (self , distributed_env : _DistributedEnv , current_epoch : int ) -> Any :
99
- self .random_state = np .random .RandomState (seed = self .seed + current_epoch ) # type: ignore
100
- chunk_intervals = self .cache .get_chunk_intervals ()
101
- indexes = range (len (chunk_intervals ))
102
- shuffled_indexes = self .random_state .permutation (indexes )
103
- shuffled_chunk_intervals = np .asarray (chunk_intervals )[shuffled_indexes ]
104
-
105
- chunks_per_process : List [List [int ]] = [[] for _ in range (distributed_env .world_size )]
106
- intervals_per_process : List [List [List [int ]]] = [[] for _ in range (distributed_env .world_size )]
107
- for index , (chunk_index , chunk_interval ) in enumerate (zip (shuffled_indexes , shuffled_chunk_intervals )):
108
- replica_index = index % distributed_env .world_size
109
- chunks_per_process [replica_index ].append (chunk_index )
110
- intervals_per_process [replica_index ].append (chunk_interval )
111
-
112
- return chunks_per_process , intervals_per_process
113
-
114
- def __call__ (self , array : np .ndarray ) -> List [int ]:
115
- assert self .random_state
116
- return self .random_state .permutation (array ).tolist ()
117
-
118
-
119
79
class FullShuffle (Shuffle ):
120
80
"""FullShuffle shuffles the chunks and associates them to the ranks.
121
81
@@ -135,36 +95,40 @@ class FullShuffle(Shuffle):
135
95
"""
136
96
137
97
@lru_cache (maxsize = 10 )
138
- def get_len (self , distributed_env : _DistributedEnv , current_epoch : int ) -> int :
139
- _ , intervals_per_process = self .get_chunks_and_intervals_per_process (distributed_env , current_epoch )
140
- min_items_per_process = min ([sum ([(i [- 1 ] - i [0 ]) for i in intervals ]) for intervals in intervals_per_process ])
141
- return min_items_per_process
142
-
143
- @lru_cache (maxsize = 10 )
144
- def get_chunks_and_intervals_per_process (self , distributed_env : _DistributedEnv , current_epoch : int ) -> Any :
98
+ def get_chunks_and_intervals_per_ranks (self , distributed_env : _DistributedEnv , current_epoch : int ) -> Any :
145
99
self .random_state = np .random .RandomState (seed = self .seed + current_epoch ) # type: ignore
100
+
101
+ # 1. Get the intervals
146
102
chunk_intervals = self .cache .get_chunk_intervals ()
103
+
104
+ # 2. Shuffle them
147
105
indexes = range (len (chunk_intervals ))
148
106
shuffled_indexes = self .random_state .permutation (indexes )
149
107
shuffled_chunk_intervals = np .asarray (chunk_intervals )[shuffled_indexes ]
150
108
109
+ # 3. Compute the items budget of each rank
151
110
num_items = sum ([(interval [- 1 ] - interval [0 ]) for interval in chunk_intervals ])
152
- num_items_per_process : List [int ] = [
153
- num_items // distributed_env .world_size for _ in range (distributed_env .world_size )
111
+ num_items_per_ranks : List [int ] = [
112
+ num_items // distributed_env .world_size + num_items % distributed_env .world_size
113
+ if rank == distributed_env .world_size - 1 and not self .drop_last
114
+ else num_items // distributed_env .world_size
115
+ for rank in range (distributed_env .world_size )
154
116
]
155
- chunks_per_process : List [List [int ]] = [[] for _ in range (distributed_env .world_size )]
156
- intervals_per_process : List [List [List [int ]]] = [[] for _ in range (distributed_env .world_size )]
117
+ chunks_per_ranks : List [List [int ]] = [[] for _ in range (distributed_env .world_size )]
118
+ intervals_per_ranks : List [List [List [int ]]] = [[] for _ in range (distributed_env .world_size )]
119
+
120
+ # 4. Assign the chunk & intervals to each rank
157
121
for chunk_index , chunk_interval in zip (shuffled_indexes , shuffled_chunk_intervals ):
158
- process_index = 0
122
+ rank = 0
159
123
160
124
while True :
161
- if process_index == len (num_items_per_process ):
125
+ if rank == len (num_items_per_ranks ):
162
126
break
163
127
164
- items_left_to_assign = num_items_per_process [ process_index ]
128
+ items_left_to_assign = num_items_per_ranks [ rank ]
165
129
166
130
if items_left_to_assign == 0 :
167
- process_index += 1
131
+ rank += 1
168
132
continue
169
133
170
134
items_in_chunk = chunk_interval [- 1 ] - chunk_interval [0 ]
@@ -173,19 +137,19 @@ def get_chunks_and_intervals_per_process(self, distributed_env: _DistributedEnv,
173
137
break
174
138
175
139
if items_in_chunk > items_left_to_assign :
176
- chunks_per_process [ process_index ].append (chunk_index )
140
+ chunks_per_ranks [ rank ].append (chunk_index )
177
141
begin , end = chunk_interval
178
- intervals_per_process [ process_index ].append ([begin , begin + items_left_to_assign ])
179
- chunk_interval = (begin + items_left_to_assign + 1 , end )
180
- num_items_per_process [ process_index ] = 0
181
- process_index += 1
142
+ intervals_per_ranks [ rank ].append ([begin , begin + items_left_to_assign ])
143
+ chunk_interval = (begin + items_left_to_assign , end )
144
+ num_items_per_ranks [ rank ] = 0
145
+ rank += 1
182
146
else :
183
- chunks_per_process [ process_index ].append (chunk_index )
184
- intervals_per_process [ process_index ].append (chunk_interval )
185
- num_items_per_process [ process_index ] -= items_in_chunk
147
+ chunks_per_ranks [ rank ].append (chunk_index )
148
+ intervals_per_ranks [ rank ].append (chunk_interval )
149
+ num_items_per_ranks [ rank ] -= items_in_chunk
186
150
break
187
151
188
- return chunks_per_process , intervals_per_process
152
+ return chunks_per_ranks , intervals_per_ranks
189
153
190
154
def __call__ (self , array : np .ndarray ) -> List [int ]:
191
155
assert self .random_state
0 commit comments