@@ -239,15 +239,14 @@ def _backpressure(
         """
         if self.max_parallel_tasks is None:
             return 0
-        else:
-            while (n_in_flight := n_dispatched - n_finished) > self.max_parallel_tasks:
-                wait_for_num_jobs = n_in_flight - self.max_parallel_tasks
-                finished_jobs, _ = self.parallel_backend.wait(
-                    jobs,
-                    num_returns=wait_for_num_jobs,
-                    timeout=10,  # FIXME make parameter?
-                )
-                n_finished += len(finished_jobs)
+        while (n_in_flight := n_dispatched - n_finished) > self.max_parallel_tasks:
+            wait_for_num_jobs = n_in_flight - self.max_parallel_tasks
+            finished_jobs, _ = self.parallel_backend.wait(
+                jobs,
+                num_returns=wait_for_num_jobs,
+                timeout=10,  # FIXME make parameter?
+            )
+            n_finished += len(finished_jobs)
         return n_finished
 
     def _chunkify(self, data: ChunkifyInputType, n_chunks: int) -> List["ObjectRef[T]"]:
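The hunk above only removes a redundant `else` after `return 0` and de-indents the throttling loop; behavior is unchanged. The loop blocks the dispatcher whenever more results are outstanding than `max_parallel_tasks` allows. A rough standalone sketch of the same backpressure idea, with `concurrent.futures.wait` standing in for `parallel_backend.wait` (whose `num_returns`/`timeout` parameters appear to mirror `ray.wait`); all names below are illustrative, not this project's API:

import concurrent.futures as cf
import time


def backpressure(jobs, n_dispatched, n_finished, max_parallel_tasks):
    # Block until the number of in-flight futures drops back to the limit.
    if max_parallel_tasks is None:
        return 0
    while (n_in_flight := n_dispatched - n_finished) > max_parallel_tasks:
        # Unlike ray.wait, cf.wait has no num_returns; FIRST_COMPLETED may
        # hand back more or fewer futures per call, so the loop re-checks.
        done, jobs = cf.wait(jobs, timeout=10, return_when=cf.FIRST_COMPLETED)
        n_finished += len(done)
    return n_finished


with cf.ThreadPoolExecutor(max_workers=8) as ex:
    jobs = {ex.submit(time.sleep, 0.1) for _ in range(8)}
    # 8 dispatched, limit 4: blocks until at least 4 have finished.
    assert backpressure(jobs, 8, 0, max_parallel_tasks=4) >= 4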
@@ -257,41 +256,39 @@ def _chunkify(self, data: ChunkifyInputType, n_chunks: int) -> List["ObjectRef[T
         if n_chunks <= 0:
             raise ValueError("Number of chunks should be greater than 0")
 
-        elif n_chunks == 1:
+        if n_chunks == 1:
             data_id = self.parallel_backend.put(data)
             return [data_id]
+
+        try:
+            # This is used as a check to determine whether data is iterable or not
+            # if it's the former, then the value will be used to determine the chunk indices.
+            n = len(data)
+        except TypeError:
+            data_id = self.parallel_backend.put(data)
+            return list(repeat(data_id, times=n_chunks))
         else:
-            try:
-                # This is used as a check to determine whether data is iterable or not
-                # if it's the former, then the value will be used to determine the chunk indices.
-                n = len(data)
-            except TypeError:
-                data_id = self.parallel_backend.put(data)
-                return list(repeat(data_id, times=n_chunks))
-            else:
-                # This is very much inspired by numpy's array_split function
-                # The difference is that it only uses built-in functions
-                # and does not convert the input data to an array
-                chunk_size, remainder = divmod(n, n_chunks)
-                chunk_indices = tuple(
-                    accumulate(
-                        [0]
-                        + remainder * [chunk_size + 1]
-                        + (n_chunks - remainder) * [chunk_size]
-                    )
+            # This is very much inspired by numpy's array_split function
+            # The difference is that it only uses built-in functions
+            # and does not convert the input data to an array
+            chunk_size, remainder = divmod(n, n_chunks)
+            chunk_indices = tuple(
+                accumulate(
+                    [0]
+                    + remainder * [chunk_size + 1]
+                    + (n_chunks - remainder) * [chunk_size]
                 )
+            )
 
-            chunks = []
+            chunks = []
 
-            for start_index, end_index in zip(
-                chunk_indices[:-1], chunk_indices[1:]
-            ):
-                if start_index >= end_index:
-                    break
-                chunk_id = self.parallel_backend.put(data[start_index:end_index])
-                chunks.append(chunk_id)
+            for start_index, end_index in zip(chunk_indices[:-1], chunk_indices[1:]):
+                if start_index >= end_index:
+                    break
+                chunk_id = self.parallel_backend.put(data[start_index:end_index])
+                chunks.append(chunk_id)
 
-            return chunks
+            return chunks
 
     @property
     def n_jobs(self) -> int:
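This hunk drops one level of nesting by returning early for the `n_chunks == 1` case and for data without a `len()`; the index arithmetic itself is untouched. That arithmetic can be checked in isolation; a minimal sketch (the `chunk_indices` helper name is made up here, the body copies the `divmod`/`accumulate` logic from the patch):

from itertools import accumulate
from typing import Tuple


def chunk_indices(n: int, n_chunks: int) -> Tuple[int, ...]:
    # Boundaries for splitting n items into n_chunks contiguous chunks;
    # as with numpy.array_split, the first `remainder` chunks get one extra item.
    chunk_size, remainder = divmod(n, n_chunks)
    return tuple(
        accumulate(
            [0] + remainder * [chunk_size + 1] + (n_chunks - remainder) * [chunk_size]
        )
    )


assert chunk_indices(10, 3) == (0, 4, 7, 10)   # chunk sizes 4, 3, 3
assert chunk_indices(2, 4) == (0, 1, 2, 2, 2)  # n < n_chunks: boundaries repeat

The second case shows why the loop guards with `if start_index >= end_index: break`: once boundaries repeat, every remaining chunk would be empty, so chunking stops early.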