Skip to content
This repository was archived by the owner on Dec 1, 2025. It is now read-only.

Commit afd6186

Browse files
author
JacksonMaxfield
authored
feature/compute-default-batch-size (#1)
1 parent d02b94c commit afd6186

File tree

3 files changed

+53
-6
lines changed

3 files changed

+53
-6
lines changed

aics_dask_utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,4 @@ def get_module_version():
1313
return __version__
1414

1515

16-
from .distributed_handler import DistributedHandler # noqa: F401
16+
from .distributed_handler import DEFAULT_MAX_THREADS, DistributedHandler # noqa: F401

aics_dask_utils/distributed_handler.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#!/usr/bin/env python
22
# -*- coding: utf-8 -*-
33

4+
import os
45
from concurrent.futures import Future as ThreadFuture
56
from concurrent.futures import ThreadPoolExecutor as ThreadClient
67
from typing import Any, Iterable, List, Optional, Union
@@ -10,6 +11,11 @@
1011

1112
#######################################################################################
1213

14+
# Default thread count / batch size: five threads per detected CPU. This matches the
# ThreadPoolExecutor default of Python 3.5 - 3.7 (3.8+ uses min(32, cpu_count + 4)).
# os.cpu_count() returns None when the count cannot be determined, so assume one CPU.
DEFAULT_MAX_THREADS = (os.cpu_count() or 1) * 5
16+
17+
#######################################################################################
18+
1319

1420
class DistributedHandler:
1521
"""
@@ -66,8 +72,26 @@ def client(self):
6672
"""
6773
return self._client
6874

75+
@staticmethod
def _get_batch_size(client: Union[DaskClient, ThreadClient]) -> int:
    """
    Pick a default batch size for the provided client.

    Parameters
    ----------
    client: Union[DaskClient, ThreadClient]
        The client that work will be submitted to.

    Returns
    -------
    batch_size: int
        The number of workers listed on the cluster attached to a Dask client
        (never less than one), or DEFAULT_MAX_THREADS when the client is a thread
        pool or a Dask client with no cluster object attached.
    """
    # Thread pools, and Dask clients with no cluster object attached, have no
    # worker listing to count, so fall back to the thread default.
    if not isinstance(client, DaskClient) or client.cluster is None:
        return DEFAULT_MAX_THREADS

    # A cluster that currently lists no workers would otherwise give a batch size
    # of zero, which batched_map cannot use as its range() step.
    return max(1, len(client.cluster.workers))
92+
6993
def batched_map(
70-
self, func, *iterables, batch_size: int = 10, **kwargs,
94+
self, func, *iterables, batch_size: Optional[int] = None, **kwargs,
7195
) -> List[Any]:
7296
"""
7397
Map a function across iterables in a batched fashion.
@@ -90,8 +114,9 @@ def batched_map(
90114
A serializable callable function to run across each iterable set.
91115
iterables: Iterables
92116
List-like objects to map over. They should have the same length.
93-
batch_size: int
117+
batch_size: Optional[int]
94118
Number of items to process and _complete_ in a single batch.
119+
Default: number of available workers or threads.
95120
**kwargs: dict
96121
Other keyword arguments to pass down to this handler's client.
97122
@@ -101,6 +126,11 @@ def batched_map(
101126
The complete results of all items after they have been fully processed
102127
and gathered.
103128
"""
129+
# If no batch size was provided, get batch size based off client
130+
if batch_size is None:
131+
batch_size = self._get_batch_size(self.client)
132+
133+
# Batch process iterables
104134
results = []
105135
for i in range(0, len(iterables[0]), batch_size):
106136
this_batch_iterables = []

aics_dask_utils/tests/test_distributed_handler.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
#!/usr/bin/env python
22
# -*- coding: utf-8 -*-
33

4-
import pytest
5-
from aics_dask_utils import DistributedHandler
6-
74
from concurrent.futures import ThreadPoolExecutor
5+
6+
import pytest
87
from distributed import Client, LocalCluster
98

9+
from aics_dask_utils import DEFAULT_MAX_THREADS, DistributedHandler
10+
1011

1112
@pytest.mark.parametrize(
1213
"values, expected_values",
@@ -69,3 +70,19 @@ def test_distributed_handler_distributed(values, expected_values):
6970
handler_map_results == handler_batched_results
7071
and handler_map_results == distributed_results
7172
)
73+
74+
cluster.close()
75+
76+
77+
def test_get_batch_size_threadpool():
    """
    A handler created with no arguments should report DEFAULT_MAX_THREADS.
    """
    with DistributedHandler() as handler:
        batch_size = handler._get_batch_size(handler.client)
        assert batch_size == DEFAULT_MAX_THREADS
80+
81+
82+
def test_get_batch_size_distributed():
    """
    A handler created from a LocalCluster scheduler address should report
    DEFAULT_MAX_THREADS.
    """
    cluster = LocalCluster(processes=False)

    # Always shut the cluster down, even when the assertion fails, so a failing run
    # does not leave the cluster running for later tests.
    try:
        with DistributedHandler(cluster.scheduler_address) as handler:
            assert handler._get_batch_size(handler.client) == DEFAULT_MAX_THREADS
    finally:
        cluster.close()

0 commit comments

Comments
 (0)