Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions newsfragments/3321.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Allow `trio.CapacityLimiter` to have zero total_tokens.
4 changes: 2 additions & 2 deletions src/trio/_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,8 @@ def total_tokens(self) -> int | float:
def total_tokens(self, new_total_tokens: int | float) -> None: # noqa: PYI041
if not isinstance(new_total_tokens, int) and new_total_tokens != math.inf:
raise TypeError("total_tokens must be an int or math.inf")
if new_total_tokens < 1:
raise ValueError("total_tokens must be >= 1")
if new_total_tokens < 0:
raise ValueError("total_tokens must be >= 0")
self._total_tokens = new_total_tokens
self._wake_waiters()

Expand Down
83 changes: 79 additions & 4 deletions src/trio/_tests/test_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,10 @@ async def child() -> None:


async def test_CapacityLimiter() -> None:
assert CapacityLimiter(0).total_tokens == 0
with pytest.raises(TypeError):
CapacityLimiter(1.0)
with pytest.raises(ValueError, match=r"^total_tokens must be >= 1$"):
with pytest.raises(ValueError, match=r"^total_tokens must be >= 0$"):
CapacityLimiter(-1)
c = CapacityLimiter(2)
repr(c) # smoke test
Expand Down Expand Up @@ -145,10 +146,10 @@ async def test_CapacityLimiter_change_total_tokens() -> None:
with pytest.raises(TypeError):
c.total_tokens = 1.0

with pytest.raises(ValueError, match=r"^total_tokens must be >= 1$"):
c.total_tokens = 0
with pytest.raises(ValueError, match=r"^total_tokens must be >= 0$"):
c.total_tokens = -1

with pytest.raises(ValueError, match=r"^total_tokens must be >= 1$"):
with pytest.raises(ValueError, match=r"^total_tokens must be >= 0$"):
c.total_tokens = -10

assert c.total_tokens == 2
Expand Down Expand Up @@ -190,6 +191,80 @@ async def test_CapacityLimiter_memleak_548() -> None:
assert len(limiter._pending_borrowers) == 0


async def test_CapacityLimiter_zero_limit_tokens() -> None:
c = CapacityLimiter(5)

assert c.total_tokens == 5

async with _core.open_nursery() as nursery:
c.total_tokens = 0

for i in range(5):
nursery.start_soon(c.acquire_on_behalf_of, i)
await wait_all_tasks_blocked()

assert set(c.statistics().borrowers) == set()
assert c.statistics().tasks_waiting == 5

c.total_tokens = 5

assert set(c.statistics().borrowers) == {0, 1, 2, 3, 4}

nursery.start_soon(c.acquire_on_behalf_of, 5)
await wait_all_tasks_blocked()

assert c.statistics().tasks_waiting == 1

for i in range(5):
c.release_on_behalf_of(i)

assert c.statistics().tasks_waiting == 0
c.release_on_behalf_of(5)

# making sure that zero limit capacity limiter doesn't let any tasks through

c.total_tokens = 0

with pytest.raises(_core.WouldBlock):
c.acquire_nowait()

nursery.start_soon(c.acquire_on_behalf_of, 6)
await wait_all_tasks_blocked()

assert c.statistics().tasks_waiting == 1
assert c.statistics().borrowers == []

c.total_tokens = 1
assert c.statistics().tasks_waiting == 0
assert c.statistics().borrowers == [6]
c.release_on_behalf_of(6)

await c.acquire_on_behalf_of(0) # total_tokens is 1

nursery.start_soon(c.acquire_on_behalf_of, 1)
c.total_tokens = 0

assert c.statistics().borrowers == [0]

c.release_on_behalf_of(0)
assert c.statistics().borrowers == []
assert c.statistics().tasks_waiting == 0

c.total_tokens = 1
await wait_all_tasks_blocked()
assert c.statistics().borrowers == [1]

c.release_on_behalf_of(1)

c.total_tokens = 0

nursery.cancel_scope.cancel()

assert c.total_tokens == 0
assert c.statistics().borrowers == []
assert c._pending_borrowers == {}


async def test_Semaphore() -> None:
with pytest.raises(TypeError):
Semaphore(1.0) # type: ignore[arg-type]
Expand Down
Loading