Skip to content

Commit 2db33e8

Browse files
committed
fix: Ensured transactional safety when writing to redis
1 parent 3b2cf1b commit 2db33e8

File tree

1 file changed

+29
-37
lines changed

1 file changed

+29
-37
lines changed

scheduler/redis_models/base.py

Lines changed: 29 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -115,20 +115,9 @@ def deserialize(cls, data: Dict[str, Any]) -> Self:
115115
class HashModel(BaseModel):
116116
created_at: Optional[datetime] = None
117117
parent: Optional[str] = None
118-
_dirty_fields: Set[str] = dataclasses.field(default_factory=set) # fields that were changed
119-
_save_all: bool = True # Save all fields to broker, after init, or after delete
120118
_list_key: ClassVar[str] = ":list_all:"
121119
_children_key_template: ClassVar[str] = ":children:{}:"
122120

123-
def __post_init__(self):
124-
self._dirty_fields = set()
125-
self._save_all = True
126-
127-
def __setattr__(self, key, value):
128-
if key != "_dirty_fields" and hasattr(self, "_dirty_fields"):
129-
self._dirty_fields.add(key)
130-
super(HashModel, self).__setattr__(key, value)
131-
132121
@property
133122
def _parent_key(self) -> Optional[str]:
134123
if self.parent is None:
@@ -155,8 +144,10 @@ def exists(cls, name: str, connection: ConnectionType) -> bool:
155144

156145
@classmethod
157146
def delete_many(cls, names: List[str], connection: ConnectionType) -> None:
158-
for name in names:
159-
connection.delete(cls._element_key_template.format(name))
147+
with connection.pipeline() as pipeline:
148+
for name in names:
149+
pipeline.delete(cls._element_key_template.format(name))
150+
pipeline.execute()
160151

161152
@classmethod
162153
def get(cls, name: str, connection: ConnectionType) -> Optional[Self]:
@@ -171,34 +162,35 @@ def get(cls, name: str, connection: ConnectionType) -> Optional[Self]:
171162

172163
@classmethod
173164
def get_many(cls, names: Sequence[str], connection: ConnectionType) -> List[Optional[Self]]:
174-
pipeline = connection.pipeline()
175-
for name in names:
176-
pipeline.hgetall(cls._element_key_template.format(name))
177-
values = pipeline.execute()
178-
return [(cls.deserialize(decode_dict(v, set())) if v else None) for v in values]
165+
with connection.pipeline() as pipeline:
166+
for name in names:
167+
pipeline.hgetall(cls._element_key_template.format(name))
168+
values = pipeline.execute()
169+
return [(cls.deserialize(decode_dict(v, set())) if v else None) for v in values]
179170

180171
def save(self, connection: ConnectionType) -> None:
181-
connection.sadd(self._list_key, self.name)
182-
if self._parent_key is not None:
183-
connection.sadd(self._parent_key, self.name)
184-
mapping = self.serialize(with_nones=True)
185-
if not self._save_all and len(self._dirty_fields) > 0:
186-
mapping = {k: v for k, v in mapping.items() if k in self._dirty_fields}
187-
none_values = {k for k, v in mapping.items() if v is None}
188-
if none_values:
189-
connection.hdel(self._key, *none_values)
190-
mapping = {k: v for k, v in mapping.items() if v is not None}
191-
if mapping:
192-
connection.hset(self._key, mapping=mapping)
193-
self._dirty_fields = set()
194-
self._save_all = False
172+
with connection.pipeline() as pipeline:
173+
pipeline.sadd(self._list_key, self.name)
174+
if self._parent_key is not None:
175+
pipeline.sadd(self._parent_key, self.name)
176+
mapping = self.serialize(with_nones=True)
177+
none_values = {k for k, v in mapping.items() if v is None}
178+
if none_values:
179+
pipeline.hdel(self._key, *none_values)
180+
mapping = {k: v for k, v in mapping.items() if v is not None}
181+
if mapping:
182+
pipeline.hset(self._key, mapping=mapping)
183+
184+
pipeline.execute()
195185

196186
def delete(self, connection: ConnectionType) -> None:
197-
connection.srem(self._list_key, self._key)
198-
if self._parent_key is not None:
199-
connection.srem(self._parent_key, 0, self._key)
200-
connection.delete(self._key)
201-
self._save_all = True
187+
with connection.pipeline() as pipeline:
188+
pipeline.srem(self._list_key, self._key)
189+
if self._parent_key is not None:
190+
pipeline.srem(self._parent_key, 0, self._key)
191+
pipeline.delete(self._key)
192+
193+
pipeline.execute()
202194

203195
@classmethod
204196
def count(cls, connection: ConnectionType, parent: Optional[str] = None) -> int:

0 commit comments

Comments
 (0)