Skip to content

Commit 759a186

Browse files
committed
ruff fixes
1 parent 31b09d8 commit 759a186

File tree

1 file changed

+17
-28
lines changed

1 file changed

+17
-28
lines changed

backend/app/events/event_store.py

Lines changed: 17 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,13 @@
1717

1818
# Base fields stored at document level (everything else goes into payload)
1919
_BASE_FIELDS = {"event_id", "event_type", "event_version", "timestamp", "aggregate_id", "metadata"}
20+
_EXCLUDE_FIELDS = {"id", "revision_id", "stored_at", "ttl_expires_at"}
21+
22+
23+
def _flatten_doc(doc: "EventDocument") -> dict[str, Any]:
24+
"""Flatten EventDocument payload to top level for schema registry deserialization."""
25+
d = doc.model_dump(exclude=_EXCLUDE_FIELDS)
26+
return {**{k: v for k, v in d.items() if k != "payload"}, **d.get("payload", {})}
2027

2128

2229
class EventStore:
@@ -52,7 +59,8 @@ async def store_event(self, event: BaseEvent) -> bool:
5259
now = datetime.now(timezone.utc)
5360
data = event.model_dump(exclude={"topic"})
5461
payload = {k: data.pop(k) for k in list(data) if k not in _BASE_FIELDS}
55-
doc = EventDocument(**data, payload=payload, stored_at=now, ttl_expires_at=now + timedelta(days=self.ttl_days))
62+
ttl = now + timedelta(days=self.ttl_days)
63+
doc = EventDocument(**data, payload=payload, stored_at=now, ttl_expires_at=ttl)
5664
await doc.insert()
5765

5866
add_span_attributes(
@@ -122,8 +130,7 @@ async def get_event(self, event_id: str) -> BaseEvent | None:
122130
if not doc:
123131
return None
124132

125-
data = doc.model_dump(exclude={"id", "revision_id", "stored_at", "ttl_expires_at"})
126-
event = self.schema_registry.deserialize_json({**{k: v for k, v in data.items() if k != "payload"}, **data.get("payload", {})})
133+
event = self.schema_registry.deserialize_json(_flatten_doc(doc))
127134

128135
duration = asyncio.get_event_loop().time() - start
129136
self.metrics.record_event_query_duration(duration, "get_by_id", "event_store")
@@ -149,10 +156,7 @@ async def get_events_by_type(
149156
.limit(limit)
150157
.to_list()
151158
)
152-
events = []
153-
for doc in docs:
154-
d = doc.model_dump(exclude={"id", "revision_id", "stored_at", "ttl_expires_at"})
155-
events.append(self.schema_registry.deserialize_json({**{k: v for k, v in d.items() if k != "payload"}, **d.get("payload", {})}))
159+
events = [self.schema_registry.deserialize_json(_flatten_doc(doc)) for doc in docs]
156160

157161
duration = asyncio.get_event_loop().time() - start
158162
self.metrics.record_event_query_duration(duration, "get_by_type", "event_store")
@@ -164,17 +168,12 @@ async def get_execution_events(
164168
event_types: list[EventType] | None = None,
165169
) -> list[BaseEvent]:
166170
start = asyncio.get_event_loop().time()
167-
query: dict[str, Any] = {
168-
"$or": [{"payload.execution_id": execution_id}, {"aggregate_id": execution_id}]
169-
}
171+
query: dict[str, Any] = {"$or": [{"payload.execution_id": execution_id}, {"aggregate_id": execution_id}]}
170172
if event_types:
171173
query["event_type"] = {"$in": event_types}
172174

173175
docs = await EventDocument.find(query).sort([("timestamp", SortDirection.ASCENDING)]).to_list()
174-
events = []
175-
for doc in docs:
176-
d = doc.model_dump(exclude={"id", "revision_id", "stored_at", "ttl_expires_at"})
177-
events.append(self.schema_registry.deserialize_json({**{k: v for k, v in d.items() if k != "payload"}, **d.get("payload", {})}))
176+
events = [self.schema_registry.deserialize_json(_flatten_doc(doc)) for doc in docs]
178177

179178
duration = asyncio.get_event_loop().time() - start
180179
self.metrics.record_event_query_duration(duration, "get_execution_events", "event_store")
@@ -196,10 +195,7 @@ async def get_user_events(
196195
query["timestamp"] = tr
197196

198197
docs = await EventDocument.find(query).sort([("timestamp", SortDirection.DESCENDING)]).limit(limit).to_list()
199-
events = []
200-
for doc in docs:
201-
d = doc.model_dump(exclude={"id", "revision_id", "stored_at", "ttl_expires_at"})
202-
events.append(self.schema_registry.deserialize_json({**{k: v for k, v in d.items() if k != "payload"}, **d.get("payload", {})}))
198+
events = [self.schema_registry.deserialize_json(_flatten_doc(doc)) for doc in docs]
203199

204200
duration = asyncio.get_event_loop().time() - start
205201
self.metrics.record_event_query_duration(duration, "get_user_events", "event_store")
@@ -220,10 +216,7 @@ async def get_security_events(
220216
query["timestamp"] = tr
221217

222218
docs = await EventDocument.find(query).sort([("timestamp", SortDirection.DESCENDING)]).limit(limit).to_list()
223-
events = []
224-
for doc in docs:
225-
d = doc.model_dump(exclude={"id", "revision_id", "stored_at", "ttl_expires_at"})
226-
events.append(self.schema_registry.deserialize_json({**{k: v for k, v in d.items() if k != "payload"}, **d.get("payload", {})}))
219+
events = [self.schema_registry.deserialize_json(_flatten_doc(doc)) for doc in docs]
227220

228221
duration = asyncio.get_event_loop().time() - start
229222
self.metrics.record_event_query_duration(duration, "get_security_events", "event_store")
@@ -236,10 +229,7 @@ async def get_correlation_chain(self, correlation_id: str) -> list[BaseEvent]:
236229
.sort([("timestamp", SortDirection.ASCENDING)])
237230
.to_list()
238231
)
239-
events = []
240-
for doc in docs:
241-
d = doc.model_dump(exclude={"id", "revision_id", "stored_at", "ttl_expires_at"})
242-
events.append(self.schema_registry.deserialize_json({**{k: v for k, v in d.items() if k != "payload"}, **d.get("payload", {})}))
232+
events = [self.schema_registry.deserialize_json(_flatten_doc(doc)) for doc in docs]
243233

244234
duration = asyncio.get_event_loop().time() - start
245235
self.metrics.record_event_query_duration(duration, "get_correlation_chain", "event_store")
@@ -263,8 +253,7 @@ async def replay_events(
263253
query["event_type"] = {"$in": event_types}
264254

265255
async for doc in EventDocument.find(query).sort([("timestamp", SortDirection.ASCENDING)]):
266-
d = doc.model_dump(exclude={"id", "revision_id", "stored_at", "ttl_expires_at"})
267-
event = self.schema_registry.deserialize_json({**{k: v for k, v in d.items() if k != "payload"}, **d.get("payload", {})})
256+
event = self.schema_registry.deserialize_json(_flatten_doc(doc))
268257
if callback:
269258
await callback(event)
270259
count += 1

0 commit comments

Comments
 (0)