|
58 | 58 |
|
59 | 59 | import sniffio |
60 | 60 |
|
61 | | -from .. import CapacityLimiterStatistics, EventStatistics, TaskInfo, abc |
| 61 | +from .. import ( |
| 62 | + CapacityLimiterStatistics, |
| 63 | + EventStatistics, |
| 64 | + LockStatistics, |
| 65 | + TaskInfo, |
| 66 | + abc, |
| 67 | +) |
62 | 68 | from .._core._eventloop import claim_worker_thread, threadlocals |
63 | 69 | from .._core._exceptions import ( |
64 | 70 | BrokenResourceError, |
|
70 | 76 | ) |
71 | 77 | from .._core._sockets import convert_ipv6_sockaddr |
72 | 78 | from .._core._streams import create_memory_object_stream |
73 | | -from .._core._synchronization import CapacityLimiter as BaseCapacityLimiter |
| 79 | +from .._core._synchronization import ( |
| 80 | + CapacityLimiter as BaseCapacityLimiter, |
| 81 | +) |
74 | 82 | from .._core._synchronization import Event as BaseEvent |
75 | | -from .._core._synchronization import ResourceGuard |
| 83 | +from .._core._synchronization import Lock as BaseLock |
| 84 | +from .._core._synchronization import ( |
| 85 | + ResourceGuard, |
| 86 | + SemaphoreStatistics, |
| 87 | +) |
| 88 | +from .._core._synchronization import Semaphore as BaseSemaphore |
76 | 89 | from .._core._tasks import CancelScope as BaseCancelScope |
77 | 90 | from ..abc import ( |
78 | 91 | AsyncBackend, |
@@ -1658,6 +1671,154 @@ def statistics(self) -> EventStatistics: |
1658 | 1671 | return EventStatistics(len(self._event._waiters)) |
1659 | 1672 |
|
1660 | 1673 |
|
| 1674 | +class Lock(BaseLock): |
| 1675 | + def __new__(cls, *, fast_acquire: bool = False) -> Lock: |
| 1676 | + return object.__new__(cls) |
| 1677 | + |
| 1678 | + def __init__(self, *, fast_acquire: bool = False) -> None: |
| 1679 | + self._fast_acquire = fast_acquire |
| 1680 | + self._owner_task: asyncio.Task | None = None |
| 1681 | + self._waiters: deque[tuple[asyncio.Task, asyncio.Future]] = deque() |
| 1682 | + |
| 1683 | + async def acquire(self) -> None: |
| 1684 | + if self._owner_task is None and not self._waiters: |
| 1685 | + await AsyncIOBackend.checkpoint_if_cancelled() |
| 1686 | + self._owner_task = current_task() |
| 1687 | + |
| 1688 | + # Unless on the "fast path", yield control of the event loop so that other |
| 1689 | + # tasks can run too |
| 1690 | + if not self._fast_acquire: |
| 1691 | + try: |
| 1692 | + await AsyncIOBackend.cancel_shielded_checkpoint() |
| 1693 | + except CancelledError: |
| 1694 | + self.release() |
| 1695 | + raise |
| 1696 | + |
| 1697 | + return |
| 1698 | + |
| 1699 | + task = cast(asyncio.Task, current_task()) |
| 1700 | + fut: asyncio.Future[None] = asyncio.Future() |
| 1701 | + item = task, fut |
| 1702 | + self._waiters.append(item) |
| 1703 | + try: |
| 1704 | + await fut |
| 1705 | + except CancelledError: |
| 1706 | + self._waiters.remove(item) |
| 1707 | + if self._owner_task is task: |
| 1708 | + self.release() |
| 1709 | + |
| 1710 | + raise |
| 1711 | + |
| 1712 | + self._waiters.remove(item) |
| 1713 | + |
| 1714 | + def acquire_nowait(self) -> None: |
| 1715 | + if self._owner_task is None and not self._waiters: |
| 1716 | + self._owner_task = current_task() |
| 1717 | + return |
| 1718 | + |
| 1719 | + raise WouldBlock |
| 1720 | + |
| 1721 | + def locked(self) -> bool: |
| 1722 | + return self._owner_task is not None |
| 1723 | + |
| 1724 | + def release(self) -> None: |
| 1725 | + if self._owner_task != current_task(): |
| 1726 | + raise RuntimeError("The current task is not holding this lock") |
| 1727 | + |
| 1728 | + for task, fut in self._waiters: |
| 1729 | + if not fut.cancelled(): |
| 1730 | + self._owner_task = task |
| 1731 | + fut.set_result(None) |
| 1732 | + return |
| 1733 | + |
| 1734 | + self._owner_task = None |
| 1735 | + |
| 1736 | + def statistics(self) -> LockStatistics: |
| 1737 | + task_info = AsyncIOTaskInfo(self._owner_task) if self._owner_task else None |
| 1738 | + return LockStatistics(self.locked(), task_info, len(self._waiters)) |
| 1739 | + |
| 1740 | + |
| 1741 | +class Semaphore(BaseSemaphore): |
| 1742 | + def __new__( |
| 1743 | + cls, |
| 1744 | + initial_value: int, |
| 1745 | + *, |
| 1746 | + max_value: int | None = None, |
| 1747 | + fast_acquire: bool = False, |
| 1748 | + ) -> Semaphore: |
| 1749 | + return object.__new__(cls) |
| 1750 | + |
| 1751 | + def __init__( |
| 1752 | + self, |
| 1753 | + initial_value: int, |
| 1754 | + *, |
| 1755 | + max_value: int | None = None, |
| 1756 | + fast_acquire: bool = False, |
| 1757 | + ): |
| 1758 | + super().__init__(initial_value, max_value=max_value) |
| 1759 | + self._value = initial_value |
| 1760 | + self._max_value = max_value |
| 1761 | + self._fast_acquire = fast_acquire |
| 1762 | + self._waiters: deque[asyncio.Future[None]] = deque() |
| 1763 | + |
| 1764 | + async def acquire(self) -> None: |
| 1765 | + if self._value > 0 and not self._waiters: |
| 1766 | + await AsyncIOBackend.checkpoint_if_cancelled() |
| 1767 | + self._value -= 1 |
| 1768 | + |
| 1769 | + # Unless on the "fast path", yield control of the event loop so that other |
| 1770 | + # tasks can run too |
| 1771 | + if not self._fast_acquire: |
| 1772 | + try: |
| 1773 | + await AsyncIOBackend.cancel_shielded_checkpoint() |
| 1774 | + except CancelledError: |
| 1775 | + self.release() |
| 1776 | + raise |
| 1777 | + |
| 1778 | + return |
| 1779 | + |
| 1780 | + fut: asyncio.Future[None] = asyncio.Future() |
| 1781 | + self._waiters.append(fut) |
| 1782 | + try: |
| 1783 | + await fut |
| 1784 | + except CancelledError: |
| 1785 | + try: |
| 1786 | + self._waiters.remove(fut) |
| 1787 | + except ValueError: |
| 1788 | + self.release() |
| 1789 | + |
| 1790 | + raise |
| 1791 | + |
| 1792 | + def acquire_nowait(self) -> None: |
| 1793 | + if self._value == 0: |
| 1794 | + raise WouldBlock |
| 1795 | + |
| 1796 | + self._value -= 1 |
| 1797 | + |
| 1798 | + def release(self) -> None: |
| 1799 | + if self._max_value is not None and self._value == self._max_value: |
| 1800 | + raise ValueError("semaphore released too many times") |
| 1801 | + |
| 1802 | + for fut in self._waiters: |
| 1803 | + if not fut.cancelled(): |
| 1804 | + fut.set_result(None) |
| 1805 | + self._waiters.remove(fut) |
| 1806 | + return |
| 1807 | + |
| 1808 | + self._value += 1 |
| 1809 | + |
| 1810 | + @property |
| 1811 | + def value(self) -> int: |
| 1812 | + return self._value |
| 1813 | + |
| 1814 | + @property |
| 1815 | + def max_value(self) -> int | None: |
| 1816 | + return self._max_value |
| 1817 | + |
| 1818 | + def statistics(self) -> SemaphoreStatistics: |
| 1819 | + return SemaphoreStatistics(len(self._waiters)) |
| 1820 | + |
| 1821 | + |
1661 | 1822 | class CapacityLimiter(BaseCapacityLimiter): |
1662 | 1823 | _total_tokens: float = 0 |
1663 | 1824 |
|
@@ -2108,6 +2269,20 @@ def create_task_group(cls) -> abc.TaskGroup: |
2108 | 2269 | def create_event(cls) -> abc.Event: |
2109 | 2270 | return Event() |
2110 | 2271 |
|
| 2272 | + @classmethod |
| 2273 | + def create_lock(cls, *, fast_acquire: bool) -> abc.Lock: |
| 2274 | + return Lock(fast_acquire=fast_acquire) |
| 2275 | + |
| 2276 | + @classmethod |
| 2277 | + def create_semaphore( |
| 2278 | + cls, |
| 2279 | + initial_value: int, |
| 2280 | + *, |
| 2281 | + max_value: int | None = None, |
| 2282 | + fast_acquire: bool = False, |
| 2283 | + ) -> abc.Semaphore: |
| 2284 | + return Semaphore(initial_value, max_value=max_value, fast_acquire=fast_acquire) |
| 2285 | + |
2111 | 2286 | @classmethod |
2112 | 2287 | def create_capacity_limiter(cls, total_tokens: float) -> abc.CapacityLimiter: |
2113 | 2288 | return CapacityLimiter(total_tokens) |
|
0 commit comments