Commit cdfe9eb

Apply suggestions from code review
Co-authored-by: Miguel de Benito Delgado <[email protected]>
Parent: cdabd2f

1 file changed: +34, -37 lines

src/pydvl/utils/parallel/map_reduce.py

@@ -239,15 +239,14 @@ def _backpressure(
         """
         if self.max_parallel_tasks is None:
             return 0
-        else:
-            while (n_in_flight := n_dispatched - n_finished) > self.max_parallel_tasks:
-                wait_for_num_jobs = n_in_flight - self.max_parallel_tasks
-                finished_jobs, _ = self.parallel_backend.wait(
-                    jobs,
-                    num_returns=wait_for_num_jobs,
-                    timeout=10,  # FIXME make parameter?
-                )
-                n_finished += len(finished_jobs)
+        while (n_in_flight := n_dispatched - n_finished) > self.max_parallel_tasks:
+            wait_for_num_jobs = n_in_flight - self.max_parallel_tasks
+            finished_jobs, _ = self.parallel_backend.wait(
+                jobs,
+                num_returns=wait_for_num_jobs,
+                timeout=10,  # FIXME make parameter?
+            )
+            n_finished += len(finished_jobs)
         return n_finished
 
     def _chunkify(self, data: ChunkifyInputType, n_chunks: int) -> List["ObjectRef[T]"]:
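Aside for readers of this hunk: dropping the else after return 0 flattens the backpressure loop without changing behaviour. Below is a minimal standalone sketch of the same pattern, using concurrent.futures as a stand-in for pydvl's parallel backend (whose wait supports a ray-style num_returns that the standard library lacks); the function name and signature are illustrative only, not pydvl's API.

# Hedged sketch: throttle dispatch while too many jobs are in flight.
# concurrent.futures stands in for self.parallel_backend; since its wait()
# has no num_returns, we wait for the first completion and recount instead.
from concurrent.futures import FIRST_COMPLETED, Future, wait
from typing import List, Optional

def backpressure(
    jobs: List[Future],
    n_dispatched: int,
    n_finished: int,
    max_parallel_tasks: Optional[int],
) -> int:
    # No cap configured: nothing to throttle.
    if max_parallel_tasks is None:
        return 0
    pending = set(jobs)
    # Block while more jobs are in flight than the configured cap.
    while (n_in_flight := n_dispatched - n_finished) > max_parallel_tasks:
        if not pending:  # bookkeeping mismatch; nothing left to wait on
            break
        done, pending = wait(pending, timeout=10, return_when=FIRST_COMPLETED)
        n_finished += len(done)
    return n_finished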
@@ -257,41 +256,39 @@ def _chunkify(self, data: ChunkifyInputType, n_chunks: int) -> List["ObjectRef[T
         if n_chunks <= 0:
             raise ValueError("Number of chunks should be greater than 0")
 
-        elif n_chunks == 1:
+        if n_chunks == 1:
             data_id = self.parallel_backend.put(data)
             return [data_id]
+
+        try:
+            # This is used as a check to determine whether data is iterable or not
+            # if it's the former, then the value will be used to determine the chunk indices.
+            n = len(data)
+        except TypeError:
+            data_id = self.parallel_backend.put(data)
+            return list(repeat(data_id, times=n_chunks))
         else:
-            try:
-                # This is used as a check to determine whether data is iterable or not
-                # if it's the former, then the value will be used to determine the chunk indices.
-                n = len(data)
-            except TypeError:
-                data_id = self.parallel_backend.put(data)
-                return list(repeat(data_id, times=n_chunks))
-            else:
-                # This is very much inspired by numpy's array_split function
-                # The difference is that it only uses built-in functions
-                # and does not convert the input data to an array
-                chunk_size, remainder = divmod(n, n_chunks)
-                chunk_indices = tuple(
-                    accumulate(
-                        [0]
-                        + remainder * [chunk_size + 1]
-                        + (n_chunks - remainder) * [chunk_size]
-                    )
+            # This is very much inspired by numpy's array_split function
+            # The difference is that it only uses built-in functions
+            # and does not convert the input data to an array
+            chunk_size, remainder = divmod(n, n_chunks)
+            chunk_indices = tuple(
+                accumulate(
+                    [0]
+                    + remainder * [chunk_size + 1]
+                    + (n_chunks - remainder) * [chunk_size]
                 )
+            )
 
-                chunks = []
+            chunks = []
 
-                for start_index, end_index in zip(
-                    chunk_indices[:-1], chunk_indices[1:]
-                ):
-                    if start_index >= end_index:
-                        break
-                    chunk_id = self.parallel_backend.put(data[start_index:end_index])
-                    chunks.append(chunk_id)
+            for start_index, end_index in zip(chunk_indices[:-1], chunk_indices[1:]):
+                if start_index >= end_index:
+                    break
+                chunk_id = self.parallel_backend.put(data[start_index:end_index])
+                chunks.append(chunk_id)
 
-                return chunks
+            return chunks
 
     @property
     def n_jobs(self) -> int:
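The chunk-index arithmetic in the hunk above is compact enough to verify in isolation. Here is a self-contained sketch of the same divmod/accumulate computation; chunk_bounds is a hypothetical helper name for illustration, not part of pydvl.

# Hedged sketch of the boundary computation used by _chunkify.
from itertools import accumulate
from typing import Tuple

def chunk_bounds(n: int, n_chunks: int) -> Tuple[int, ...]:
    # As in numpy's array_split: the first `remainder` chunks get one extra element.
    chunk_size, remainder = divmod(n, n_chunks)
    sizes = remainder * [chunk_size + 1] + (n_chunks - remainder) * [chunk_size]
    # Prefix sums turn chunk sizes into slice boundaries.
    return tuple(accumulate([0] + sizes))

assert chunk_bounds(10, 3) == (0, 4, 7, 10)   # chunks of sizes 4, 3, 3
assert chunk_bounds(2, 4) == (0, 1, 2, 2, 2)  # repeated bounds mark empty chunks

When n_chunks exceeds n, trailing bounds repeat, and the start_index >= end_index check in the loop above breaks out before any empty chunk is put on the backend.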
