11#!/usr/bin/env python
22# -*- coding: utf-8 -*-
33
4+ import os
45from concurrent .futures import Future as ThreadFuture
56from concurrent .futures import ThreadPoolExecutor as ThreadClient
67from typing import Any , Iterable , List , Optional , Union
1011
1112#######################################################################################
1213
# Default batch size when no Dask cluster is available.
# Mirrors ThreadPoolExecutor's historical default of cpu_count * 5.
# NOTE: os.cpu_count() is documented to return None when the CPU count
# is indeterminate, so fall back to 1 to avoid a TypeError at import time.
DEFAULT_MAX_THREADS = (os.cpu_count() or 1) * 5

#######################################################################################

1319
1420class DistributedHandler :
1521 """
@@ -66,8 +72,26 @@ def client(self):
6672 """
6773 return self ._client
6874
75+ @staticmethod
76+ def _get_batch_size (client : Union [DaskClient , ThreadClient ]) -> int :
77+ """
78+ Returns an integer that matches either the number of Dask workers or number of
79+ threads available.
80+ """
81+ # Handle dask
82+ if isinstance (client , DaskClient ):
83+ # Using a LocalCluster with processes = False
84+ if client .cluster is None :
85+ return DEFAULT_MAX_THREADS
86+
87+ # In all other cases, there will be a cluster attached
88+ return len (client .cluster .workers )
89+
90+ # Return default number of max threads
91+ return DEFAULT_MAX_THREADS
92+
6993 def batched_map (
70- self , func , * iterables , batch_size : int = 10 , ** kwargs ,
94+ self , func , * iterables , batch_size : Optional [ int ] = None , ** kwargs ,
7195 ) -> List [Any ]:
7296 """
7397 Map a function across iterables in a batched fashion.
@@ -90,8 +114,9 @@ def batched_map(
90114 A serializable callable function to run across each iterable set.
91115 iterables: Iterables
92116 List-like objects to map over. They should have the same length.
93- batch_size: int
117+ batch_size: Optional[ int]
94118 Number of items to process and _complete_ in a single batch.
119+ Default: number of available workers or threads.
95120 **kwargs: dict
96121 Other keyword arguments to pass down to this handler's client.
97122
@@ -101,6 +126,11 @@ def batched_map(
101126 The complete results of all items after they have been fully processed
102127 and gathered.
103128 """
129+ # If no batch size was provided, get batch size based off client
130+ if batch_size is None :
131+ batch_size = self ._get_batch_size (self .client )
132+
133+ # Batch process iterables
104134 results = []
105135 for i in range (0 , len (iterables [0 ]), batch_size ):
106136 this_batch_iterables = []