Skip to content

Commit 60e96b3

Browse files
committed
use class variable
1 parent ac62fee commit 60e96b3

File tree

1 file changed

+18
-21
lines changed

1 file changed

+18
-21
lines changed

packages/service-library/src/servicelib/redis/_semaphore.py

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import logging
33
import uuid
44
from types import TracebackType
5-
from typing import Annotated
5+
from typing import Annotated, ClassVar
66

77
from common_library.basic_types import DEFAULT_FACTORY
88
from pydantic import (
@@ -101,25 +101,22 @@ class DistributedSemaphore(BaseModel):
101101
] = DEFAULT_FACTORY
102102

103103
# Private state attributes (not part of the model)
104-
_acquire_script: AsyncScript
105-
_count_script: AsyncScript
106-
_release_script: AsyncScript
107-
_renew_script: AsyncScript
104+
_acquire_script: ClassVar[AsyncScript]
105+
_count_script: ClassVar[AsyncScript]
106+
_release_script: ClassVar[AsyncScript]
107+
_renew_script: ClassVar[AsyncScript]
108+
109+
@classmethod
110+
def _register_scripts(cls, redis_client) -> None:
111+
if cls._acquire_script is None:
112+
cls._acquire_script = redis_client.register_script(ACQUIRE_SEMAPHORE_SCRIPT)
113+
cls._count_script = redis_client.register_script(COUNT_SEMAPHORE_SCRIPT)
114+
cls._release_script = redis_client.register_script(RELEASE_SEMAPHORE_SCRIPT)
115+
cls._renew_script = redis_client.register_script(RENEW_SEMAPHORE_SCRIPT)
108116

109117
def __init__(self, **data) -> None:
110118
super().__init__(**data)
111-
self._acquire_script = self.redis_client.redis.register_script(
112-
ACQUIRE_SEMAPHORE_SCRIPT
113-
)
114-
self._count_script = self.redis_client.redis.register_script(
115-
COUNT_SEMAPHORE_SCRIPT
116-
)
117-
self._release_script = self.redis_client.redis.register_script(
118-
RELEASE_SEMAPHORE_SCRIPT
119-
)
120-
self._renew_script = self.redis_client.redis.register_script(
121-
RENEW_SEMAPHORE_SCRIPT
122-
)
119+
type(self)._register_scripts(self.redis_client) # noqa: SLF001
123120

124121
@computed_field # type: ignore[prop-decorator]
125122
@property
@@ -199,7 +196,7 @@ async def release(self) -> None:
199196
ttl_seconds = int(self.ttl.total_seconds())
200197

201198
# Execute the release Lua script atomically
202-
result = await self._release_script(
199+
result = await type(self)._release_script( # noqa: SLF001
203200
keys=(
204201
self.semaphore_key,
205202
self.holder_key,
@@ -234,7 +231,7 @@ async def _try_acquire(self) -> bool:
234231
ttl_seconds = int(self.ttl.total_seconds())
235232

236233
# Execute the Lua script atomically
237-
result = await self._acquire_script(
234+
result = await type(self)._acquire_script( # noqa: SLF001
238235
keys=(self.semaphore_key, self.holder_key),
239236
args=(self.instance_id, str(self.capacity), str(ttl_seconds)),
240237
client=self.redis_client.redis,
@@ -277,7 +274,7 @@ async def reacquire(self) -> None:
277274
ttl_seconds = int(self.ttl.total_seconds())
278275

279276
# Execute the renewal Lua script atomically
280-
result = await self._renew_script(
277+
result = await type(self)._renew_script( # noqa: SLF001
281278
keys=(self.semaphore_key, self.holder_key),
282279
args=(
283280
self.instance_id,
@@ -325,7 +322,7 @@ async def get_current_count(self) -> int:
325322
ttl_seconds = int(self.ttl.total_seconds())
326323

327324
# Execute the count Lua script atomically
328-
result = await self._count_script(
325+
result = await type(self)._count_script( # noqa: SLF001
329326
keys=(self.semaphore_key,),
330327
args=(str(ttl_seconds),),
331328
client=self.redis_client.redis,

0 commit comments

Comments
 (0)