|
2 | 2 | import logging |
3 | 3 | import uuid |
4 | 4 | from types import TracebackType |
5 | | -from typing import Annotated |
| 5 | +from typing import Annotated, ClassVar |
6 | 6 |
|
7 | 7 | from common_library.basic_types import DEFAULT_FACTORY |
8 | 8 | from pydantic import ( |
@@ -101,25 +101,22 @@ class DistributedSemaphore(BaseModel): |
101 | 101 | ] = DEFAULT_FACTORY |
102 | 102 |
|
103 | 103 | # Private state attributes (not part of the model) |
104 | | - _acquire_script: AsyncScript |
105 | | - _count_script: AsyncScript |
106 | | - _release_script: AsyncScript |
107 | | - _renew_script: AsyncScript |
| 104 | + _acquire_script: ClassVar[AsyncScript] |
| 105 | + _count_script: ClassVar[AsyncScript] |
| 106 | + _release_script: ClassVar[AsyncScript] |
| 107 | + _renew_script: ClassVar[AsyncScript] |
| 108 | + |
| 109 | + @classmethod |
| 110 | + def _register_scripts(cls, redis_client) -> None: |
| 111 | + if cls._acquire_script is None: |
| 112 | + cls._acquire_script = redis_client.register_script(ACQUIRE_SEMAPHORE_SCRIPT) |
| 113 | + cls._count_script = redis_client.register_script(COUNT_SEMAPHORE_SCRIPT) |
| 114 | + cls._release_script = redis_client.register_script(RELEASE_SEMAPHORE_SCRIPT) |
| 115 | + cls._renew_script = redis_client.register_script(RENEW_SEMAPHORE_SCRIPT) |
108 | 116 |
|
109 | 117 | def __init__(self, **data) -> None: |
110 | 118 | super().__init__(**data) |
111 | | - self._acquire_script = self.redis_client.redis.register_script( |
112 | | - ACQUIRE_SEMAPHORE_SCRIPT |
113 | | - ) |
114 | | - self._count_script = self.redis_client.redis.register_script( |
115 | | - COUNT_SEMAPHORE_SCRIPT |
116 | | - ) |
117 | | - self._release_script = self.redis_client.redis.register_script( |
118 | | - RELEASE_SEMAPHORE_SCRIPT |
119 | | - ) |
120 | | - self._renew_script = self.redis_client.redis.register_script( |
121 | | - RENEW_SEMAPHORE_SCRIPT |
122 | | - ) |
| 119 | + type(self)._register_scripts(self.redis_client) # noqa: SLF001 |
123 | 120 |
|
124 | 121 | @computed_field # type: ignore[prop-decorator] |
125 | 122 | @property |
@@ -199,7 +196,7 @@ async def release(self) -> None: |
199 | 196 | ttl_seconds = int(self.ttl.total_seconds()) |
200 | 197 |
|
201 | 198 | # Execute the release Lua script atomically |
202 | | - result = await self._release_script( |
| 199 | + result = await type(self)._release_script( # noqa: SLF001 |
203 | 200 | keys=( |
204 | 201 | self.semaphore_key, |
205 | 202 | self.holder_key, |
@@ -234,7 +231,7 @@ async def _try_acquire(self) -> bool: |
234 | 231 | ttl_seconds = int(self.ttl.total_seconds()) |
235 | 232 |
|
236 | 233 | # Execute the Lua script atomically |
237 | | - result = await self._acquire_script( |
| 234 | + result = await type(self)._acquire_script( # noqa: SLF001 |
238 | 235 | keys=(self.semaphore_key, self.holder_key), |
239 | 236 | args=(self.instance_id, str(self.capacity), str(ttl_seconds)), |
240 | 237 | client=self.redis_client.redis, |
@@ -277,7 +274,7 @@ async def reacquire(self) -> None: |
277 | 274 | ttl_seconds = int(self.ttl.total_seconds()) |
278 | 275 |
|
279 | 276 | # Execute the renewal Lua script atomically |
280 | | - result = await self._renew_script( |
| 277 | + result = await type(self)._renew_script( # noqa: SLF001 |
281 | 278 | keys=(self.semaphore_key, self.holder_key), |
282 | 279 | args=( |
283 | 280 | self.instance_id, |
@@ -325,7 +322,7 @@ async def get_current_count(self) -> int: |
325 | 322 | ttl_seconds = int(self.ttl.total_seconds()) |
326 | 323 |
|
327 | 324 | # Execute the count Lua script atomically |
328 | | - result = await self._count_script( |
| 325 | + result = await type(self)._count_script( # noqa: SLF001 |
329 | 326 | keys=(self.semaphore_key,), |
330 | 327 | args=(str(ttl_seconds),), |
331 | 328 | client=self.redis_client.redis, |
|
0 commit comments