Skip to content

Commit 0abf712

Browse files
Merge pull request #478 from reef-technologies/fix_upload_threads
allow set_thread_pool_size to be set after pool has been once used already
2 parents bf573f3 + ed48c42 commit 0abf712

File tree

5 files changed

+146
-45
lines changed

5 files changed

+146
-45
lines changed

b2sdk/utils/thread_pool.py

Lines changed: 85 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,27 +9,106 @@
99
######################################################################
1010
from __future__ import annotations
1111

12-
from concurrent.futures import ThreadPoolExecutor
12+
import os
13+
from concurrent.futures import Future, ThreadPoolExecutor
14+
from typing import Callable
15+
16+
try:
17+
from typing_extensions import Protocol
18+
except ImportError:
19+
from typing import Protocol
1320

1421
from b2sdk.utils import B2TraceMetaAbstract
1522

1623

24+
class DynamicThreadPoolExecutorProtocol(Protocol):
25+
def submit(self, fn: Callable, *args, **kwargs) -> Future:
26+
...
27+
28+
def set_size(self, max_workers: int) -> None:
29+
"""Set the size of the thread pool."""
30+
31+
def get_size(self) -> int:
32+
"""Return the current size of the thread pool."""
33+
34+
35+
class LazyThreadPool:
36+
"""
37+
Lazily initialized thread pool.
38+
"""
39+
40+
_THREAD_POOL_FACTORY = ThreadPoolExecutor
41+
42+
def __init__(self, max_workers: int | None = None, **kwargs):
43+
if max_workers is None:
44+
max_workers = min(
45+
32, (os.cpu_count() or 1) + 4
46+
) # same default as in ThreadPoolExecutor
47+
self._max_workers = max_workers
48+
self._thread_pool: ThreadPoolExecutor | None = None
49+
super().__init__(**kwargs)
50+
51+
def submit(self, fn: Callable, *args, **kwargs) -> Future:
52+
if self._thread_pool is None:
53+
self._thread_pool = self._THREAD_POOL_FACTORY(self._max_workers)
54+
return self._thread_pool.submit(fn, *args, **kwargs)
55+
56+
def set_size(self, max_workers: int) -> None:
57+
"""
58+
Set the size of the thread pool.
59+
60+
This operation will block until all tasks in the current thread pool are completed.
61+
62+
:param max_workers: New size of the thread pool
63+
:return: None
64+
"""
65+
if self._max_workers == max_workers:
66+
return
67+
old_thread_pool = self._thread_pool
68+
self._thread_pool = self._THREAD_POOL_FACTORY(max_workers=max_workers)
69+
if old_thread_pool is not None:
70+
old_thread_pool.shutdown(wait=True)
71+
self._max_workers = max_workers
72+
73+
def get_size(self) -> int:
74+
"""Return the current size of the thread pool."""
75+
return self._max_workers
76+
77+
1778
class ThreadPoolMixin(metaclass=B2TraceMetaAbstract):
1879
"""
1980
Mixin class with ThreadPoolExecutor.
2081
"""
21-
DEFAULT_THREAD_POOL_CLASS = staticmethod(ThreadPoolExecutor)
82+
83+
DEFAULT_THREAD_POOL_CLASS = LazyThreadPool
2284

2385
def __init__(
2486
self,
25-
thread_pool: ThreadPoolExecutor | None = None,
87+
thread_pool: DynamicThreadPoolExecutorProtocol | None = None,
2688
max_workers: int | None = None,
27-
**kwargs
89+
**kwargs,
2890
):
2991
"""
3092
:param thread_pool: thread pool to be used
3193
:param max_workers: maximum number of worker threads (ignored if thread_pool is not None)
3294
"""
33-
self._thread_pool = thread_pool if thread_pool is not None \
34-
else self.DEFAULT_THREAD_POOL_CLASS(max_workers=max_workers)
95+
self._thread_pool = (
96+
thread_pool
97+
if thread_pool is not None else self.DEFAULT_THREAD_POOL_CLASS(max_workers=max_workers)
98+
)
99+
self._max_workers = max_workers
35100
super().__init__(**kwargs)
101+
102+
def set_thread_pool_size(self, max_workers: int) -> None:
103+
"""
104+
Set the size of the thread pool.
105+
106+
This operation will block until all tasks in the current thread pool are completed.
107+
108+
:param max_workers: New size of the thread pool
109+
:return: None
110+
"""
111+
return self._thread_pool.set_size(max_workers)
112+
113+
def get_thread_pool_size(self) -> int:
114+
return self._thread_pool.get_size()

b2sdk/v2/transfer.py

Lines changed: 4 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -9,46 +9,17 @@
99
######################################################################
1010
from __future__ import annotations
1111

12-
from concurrent.futures import Future, ThreadPoolExecutor
13-
from typing import Callable
14-
1512
from b2sdk import _v3 as v3
16-
17-
18-
class LazyThreadPool:
19-
"""
20-
Lazily initialized thread pool.
21-
"""
22-
23-
def __init__(self, max_workers: int | None = None, **kwargs):
24-
self._max_workers = max_workers
25-
self._thread_pool = None # type: 'Optional[ThreadPoolExecutor]'
26-
super().__init__(**kwargs)
27-
28-
def submit(self, fn: Callable, *args, **kwargs) -> Future:
29-
if self._thread_pool is None:
30-
self._thread_pool = ThreadPoolExecutor(self._max_workers)
31-
return self._thread_pool.submit(fn, *args, **kwargs)
32-
33-
def set_size(self, max_workers: int) -> None:
34-
if self._max_workers == max_workers:
35-
return
36-
if self._thread_pool is not None:
37-
raise RuntimeError('Thread pool already created')
38-
self._max_workers = max_workers
13+
from b2sdk.utils.thread_pool import LazyThreadPool # noqa: F401
3914

4015

4116
class ThreadPoolMixin(v3.ThreadPoolMixin):
42-
DEFAULT_THREAD_POOL_CLASS = staticmethod(LazyThreadPool)
43-
44-
# This method is used in CLI even though it doesn't belong to the public API
45-
def set_thread_pool_size(self, max_workers: int) -> None:
46-
self._thread_pool.set_size(max_workers)
17+
pass
4718

4819

49-
class DownloadManager(v3.DownloadManager, ThreadPoolMixin):
20+
class DownloadManager(v3.DownloadManager):
5021
pass
5122

5223

53-
class UploadManager(v3.UploadManager, ThreadPoolMixin):
24+
class UploadManager(v3.UploadManager):
5425
pass

changelog.d/+set_threads.added.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add `set_thread_pool_size`, `get_thread_pool_size` to *Manger classes.
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
######################################################################
2+
#
3+
# File: test/unit/utils/test_thread_pool.py
4+
#
5+
# Copyright 2024 Backblaze Inc. All Rights Reserved.
6+
#
7+
# License https://www.backblaze.com/using_b2_code.html
8+
#
9+
######################################################################
10+
from concurrent.futures import Future
11+
12+
import pytest
13+
14+
from b2sdk.utils.thread_pool import LazyThreadPool
15+
16+
17+
class TestLazyThreadPool:
18+
@pytest.fixture
19+
def thread_pool(self):
20+
return LazyThreadPool()
21+
22+
def test_submit(self, thread_pool):
23+
24+
future = thread_pool.submit(sum, (1, 2))
25+
assert isinstance(future, Future)
26+
assert future.result() == 3
27+
28+
def test_set_size(self, thread_pool):
29+
thread_pool.set_size(10)
30+
assert thread_pool.get_size() == 10
31+
32+
def test_get_size(self, thread_pool):
33+
assert thread_pool.get_size() > 0
34+
35+
def test_set_size__after_submit(self, thread_pool):
36+
future = thread_pool.submit(sum, (1, 2))
37+
38+
thread_pool.set_size(7)
39+
assert thread_pool.get_size() == 7
40+
41+
assert future.result() == 3
42+
43+
assert thread_pool.submit(sum, (1,)).result() == 1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
######################################################################
22
#
3-
# File: test/unit/v2/test_transfer.py
3+
# File: test/unit/v_all/test_transfer.py
44
#
55
# Copyright 2022 Backblaze Inc. All Rights Reserved.
66
#
@@ -11,19 +11,26 @@
1111

1212
from unittest.mock import Mock
1313

14+
from apiver_deps import DownloadManager, UploadManager
15+
1416
from ..test_base import TestBase
15-
from .apiver.apiver_deps import DownloadManager, UploadManager
1617

1718

1819
class TestDownloadManager(TestBase):
1920
def test_set_thread_pool_size(self) -> None:
2021
download_manager = DownloadManager(services=Mock())
21-
download_manager.set_thread_pool_size(21)
22-
self.assertEqual(download_manager._thread_pool._max_workers, 21)
22+
assert download_manager.get_thread_pool_size() > 0
23+
24+
pool_size = 21
25+
download_manager.set_thread_pool_size(pool_size)
26+
assert download_manager.get_thread_pool_size() == pool_size
2327

2428

2529
class TestUploadManager(TestBase):
2630
def test_set_thread_pool_size(self) -> None:
2731
upload_manager = UploadManager(services=Mock())
28-
upload_manager.set_thread_pool_size(37)
29-
self.assertEqual(upload_manager._thread_pool._max_workers, 37)
32+
assert upload_manager.get_thread_pool_size() > 0
33+
34+
pool_size = 37
35+
upload_manager.set_thread_pool_size(pool_size)
36+
assert upload_manager.get_thread_pool_size() == pool_size

0 commit comments

Comments
 (0)