Skip to content

Commit af8da2a

Browse files
committed
asynciomotor/db fixes (preciser types)
1 parent 3bc5f31 commit af8da2a

25 files changed

+120
-114
lines changed

backend/app/core/database_context.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,21 @@
77
from motor.motor_asyncio import (
88
AsyncIOMotorClient,
99
AsyncIOMotorClientSession,
10+
AsyncIOMotorCollection,
11+
AsyncIOMotorCursor,
1012
AsyncIOMotorDatabase,
1113
)
1214
from pymongo.errors import ServerSelectionTimeoutError
1315

1416
from app.core.logging import logger
1517

1618
# Python 3.12 type aliases using the new 'type' statement
17-
type DBClient = AsyncIOMotorClient[Any]
18-
type Database = AsyncIOMotorDatabase[Any]
19+
# MongoDocument represents the raw document type returned by Motor operations
20+
type MongoDocument = dict[str, Any]
21+
type DBClient = AsyncIOMotorClient[MongoDocument]
22+
type Database = AsyncIOMotorDatabase[MongoDocument]
23+
type Collection = AsyncIOMotorCollection[MongoDocument]
24+
type Cursor = AsyncIOMotorCursor[MongoDocument]
1925
type DBSession = AsyncIOMotorClientSession
2026

2127
# Type variable for generic database provider
@@ -102,7 +108,7 @@ async def connect(self) -> None:
102108
# Always explicitly bind to current event loop for consistency
103109
import asyncio
104110

105-
client: AsyncIOMotorClient = AsyncIOMotorClient(
111+
client: DBClient = AsyncIOMotorClient(
106112
self._config.mongodb_url,
107113
serverSelectionTimeoutMS=self._config.server_selection_timeout_ms,
108114
connectTimeoutMS=self._config.connect_timeout_ms,

backend/app/core/dishka_lifespan.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
import redis.asyncio as redis
55
from dishka import AsyncContainer
66
from fastapi import FastAPI
7-
from motor.motor_asyncio import AsyncIOMotorDatabase
87

8+
from app.core.database_context import Database
99
from app.core.logging import logger
1010
from app.core.startup import initialize_metrics_context, initialize_rate_limits
1111
from app.core.tracing import init_tracing
@@ -64,7 +64,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
6464
await initialize_event_schemas(schema_registry)
6565

6666
# Initialize database schema at application scope using app-scoped DB
67-
database = await container.get(AsyncIOMotorDatabase)
67+
database = await container.get(Database)
6868
schema_manager = SchemaManager(database)
6969
await schema_manager.apply_all()
7070
logger.info("Database schema ensured by SchemaManager")

backend/app/core/providers.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22

33
import redis.asyncio as redis
44
from dishka import Provider, Scope, provide
5-
from motor.motor_asyncio import AsyncIOMotorDatabase
65

76
from app.core.database_context import (
87
AsyncDatabaseConnection,
8+
Database,
99
DatabaseConfig,
1010
create_database_connection,
1111
)
@@ -105,7 +105,7 @@ async def get_database_connection(self, settings: Settings) -> AsyncIterator[Asy
105105
await db_connection.disconnect()
106106

107107
@provide
108-
def get_database(self, db_connection: AsyncDatabaseConnection) -> AsyncIOMotorDatabase:
108+
def get_database(self, db_connection: AsyncDatabaseConnection) -> Database:
109109
return db_connection.database
110110

111111

@@ -174,7 +174,7 @@ async def get_kafka_producer(
174174
await producer.stop()
175175

176176
@provide
177-
async def get_dlq_manager(self, database: AsyncIOMotorDatabase) -> AsyncIterator[DLQManager]:
177+
async def get_dlq_manager(self, database: Database) -> AsyncIterator[DLQManager]:
178178
manager = create_dlq_manager(database)
179179
await manager.start()
180180
try:
@@ -210,7 +210,7 @@ def get_schema_registry(self) -> SchemaRegistryManager:
210210
@provide
211211
async def get_event_store(
212212
self,
213-
database: AsyncIOMotorDatabase,
213+
database: Database,
214214
schema_registry: SchemaRegistryManager
215215
) -> EventStore:
216216
store = create_event_store(
@@ -332,7 +332,7 @@ async def get_sse_kafka_redis_bridge(
332332
@provide
333333
def get_sse_repository(
334334
self,
335-
database: AsyncIOMotorDatabase
335+
database: Database
336336
) -> SSERepository:
337337
return SSERepository(database)
338338

@@ -365,7 +365,7 @@ class AuthProvider(Provider):
365365
scope = Scope.APP
366366

367367
@provide
368-
def get_user_repository(self, database: AsyncIOMotorDatabase) -> UserRepository:
368+
def get_user_repository(self, database: Database) -> UserRepository:
369369
return UserRepository(database)
370370

371371
@provide
@@ -377,11 +377,11 @@ class UserServicesProvider(Provider):
377377
scope = Scope.APP
378378

379379
@provide
380-
def get_user_settings_repository(self, database: AsyncIOMotorDatabase) -> UserSettingsRepository:
380+
def get_user_settings_repository(self, database: Database) -> UserSettingsRepository:
381381
return UserSettingsRepository(database)
382382

383383
@provide
384-
def get_event_repository(self, database: AsyncIOMotorDatabase) -> EventRepository:
384+
def get_event_repository(self, database: Database) -> EventRepository:
385385
return EventRepository(database)
386386

387387
@provide
@@ -415,7 +415,7 @@ class AdminServicesProvider(Provider):
415415
scope = Scope.APP
416416

417417
@provide
418-
def get_admin_events_repository(self, database: AsyncIOMotorDatabase) -> AdminEventsRepository:
418+
def get_admin_events_repository(self, database: Database) -> AdminEventsRepository:
419419
return AdminEventsRepository(database)
420420

421421
@provide(scope=Scope.REQUEST)
@@ -427,7 +427,7 @@ def get_admin_events_service(
427427
return AdminEventsService(admin_events_repository, replay_service)
428428

429429
@provide
430-
def get_admin_settings_repository(self, database: AsyncIOMotorDatabase) -> AdminSettingsRepository:
430+
def get_admin_settings_repository(self, database: Database) -> AdminSettingsRepository:
431431
return AdminSettingsRepository(database)
432432

433433
@provide
@@ -438,15 +438,15 @@ def get_admin_settings_service(
438438
return AdminSettingsService(admin_settings_repository)
439439

440440
@provide
441-
def get_admin_user_repository(self, database: AsyncIOMotorDatabase) -> AdminUserRepository:
441+
def get_admin_user_repository(self, database: Database) -> AdminUserRepository:
442442
return AdminUserRepository(database)
443443

444444
@provide
445-
def get_saga_repository(self, database: AsyncIOMotorDatabase) -> SagaRepository:
445+
def get_saga_repository(self, database: Database) -> SagaRepository:
446446
return SagaRepository(database)
447447

448448
@provide
449-
def get_notification_repository(self, database: AsyncIOMotorDatabase) -> NotificationRepository:
449+
def get_notification_repository(self, database: Database) -> NotificationRepository:
450450
return NotificationRepository(database)
451451

452452
@provide
@@ -482,23 +482,23 @@ class BusinessServicesProvider(Provider):
482482
scope = Scope.REQUEST
483483

484484
@provide
485-
def get_execution_repository(self, database: AsyncIOMotorDatabase) -> ExecutionRepository:
485+
def get_execution_repository(self, database: Database) -> ExecutionRepository:
486486
return ExecutionRepository(database)
487487

488488
@provide
489-
def get_resource_allocation_repository(self, database: AsyncIOMotorDatabase) -> ResourceAllocationRepository:
489+
def get_resource_allocation_repository(self, database: Database) -> ResourceAllocationRepository:
490490
return ResourceAllocationRepository(database)
491491

492492
@provide
493-
def get_saved_script_repository(self, database: AsyncIOMotorDatabase) -> SavedScriptRepository:
493+
def get_saved_script_repository(self, database: Database) -> SavedScriptRepository:
494494
return SavedScriptRepository(database)
495495

496496
@provide
497-
def get_dlq_repository(self, database: AsyncIOMotorDatabase) -> DLQRepository:
497+
def get_dlq_repository(self, database: Database) -> DLQRepository:
498498
return DLQRepository(database)
499499

500500
@provide
501-
def get_replay_repository(self, database: AsyncIOMotorDatabase) -> ReplayRepository:
501+
def get_replay_repository(self, database: Database) -> ReplayRepository:
502502
return ReplayRepository(database)
503503

504504
@provide
@@ -623,5 +623,5 @@ class ResultProcessorProvider(Provider):
623623
scope = Scope.APP
624624

625625
@provide
626-
def get_execution_repository(self, database: AsyncIOMotorDatabase) -> ExecutionRepository:
626+
def get_execution_repository(self, database: Database) -> ExecutionRepository:
627627
return ExecutionRepository(database)

backend/app/db/repositories/admin/admin_events_repository.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from datetime import datetime, timedelta, timezone
22
from typing import Any, Dict, List
33

4-
from motor.motor_asyncio import AsyncIOMotorCollection, AsyncIOMotorDatabase
54
from pymongo import ReturnDocument
65

6+
from app.core.database_context import Collection, Database
77
from app.domain.admin import (
88
ReplayQuery,
99
ReplaySession,
@@ -43,18 +43,18 @@
4343
class AdminEventsRepository:
4444
"""Repository for admin event operations using domain models."""
4545

46-
def __init__(self, db: AsyncIOMotorDatabase):
46+
def __init__(self, db: Database):
4747
self.db = db
48-
self.events_collection: AsyncIOMotorCollection = self.db.get_collection(CollectionNames.EVENTS)
49-
self.event_store_collection: AsyncIOMotorCollection = self.db.get_collection(CollectionNames.EVENT_STORE)
48+
self.events_collection: Collection = self.db.get_collection(CollectionNames.EVENTS)
49+
self.event_store_collection: Collection = self.db.get_collection(CollectionNames.EVENT_STORE)
5050
# Bind related collections used by this repository
51-
self.executions_collection: AsyncIOMotorCollection = self.db.get_collection(CollectionNames.EXECUTIONS)
52-
self.events_archive_collection: AsyncIOMotorCollection = self.db.get_collection(
51+
self.executions_collection: Collection = self.db.get_collection(CollectionNames.EXECUTIONS)
52+
self.events_archive_collection: Collection = self.db.get_collection(
5353
CollectionNames.EVENTS_ARCHIVE
5454
)
5555
self.replay_mapper = ReplaySessionMapper()
5656
self.replay_query_mapper = ReplayQueryMapper()
57-
self.replay_sessions_collection: AsyncIOMotorCollection = self.db.get_collection(
57+
self.replay_sessions_collection: Collection = self.db.get_collection(
5858
CollectionNames.REPLAY_SESSIONS)
5959
self.mapper = EventMapper()
6060
self.summary_mapper = EventSummaryMapper()

backend/app/db/repositories/admin/admin_settings_repository.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from datetime import datetime, timezone
22

3-
from motor.motor_asyncio import AsyncIOMotorCollection, AsyncIOMotorDatabase
4-
3+
from app.core.database_context import Collection, Database
54
from app.core.logging import logger
65
from app.domain.admin import (
76
AuditAction,
@@ -12,10 +11,10 @@
1211

1312

1413
class AdminSettingsRepository:
15-
def __init__(self, db: AsyncIOMotorDatabase):
14+
def __init__(self, db: Database):
1615
self.db = db
17-
self.settings_collection: AsyncIOMotorCollection = self.db.get_collection("system_settings")
18-
self.audit_log_collection: AsyncIOMotorCollection = self.db.get_collection("audit_log")
16+
self.settings_collection: Collection = self.db.get_collection("system_settings")
17+
self.audit_log_collection: Collection = self.db.get_collection("audit_log")
1918
self.settings_mapper = SettingsMapper()
2019
self.audit_mapper = AuditLogMapper()
2120

backend/app/db/repositories/admin/admin_user_repository.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from datetime import datetime, timezone
22

3-
from motor.motor_asyncio import AsyncIOMotorCollection, AsyncIOMotorDatabase
4-
3+
from app.core.database_context import Collection, Database
54
from app.core.security import SecurityService
65
from app.domain.enums import UserRole
76
from app.domain.events.event_models import CollectionNames
@@ -17,17 +16,17 @@
1716

1817

1918
class AdminUserRepository:
20-
def __init__(self, db: AsyncIOMotorDatabase):
19+
def __init__(self, db: Database):
2120
self.db = db
22-
self.users_collection: AsyncIOMotorCollection = self.db.get_collection(CollectionNames.USERS)
21+
self.users_collection: Collection = self.db.get_collection(CollectionNames.USERS)
2322

2423
# Related collections used by this repository (e.g., cascade deletes)
25-
self.executions_collection: AsyncIOMotorCollection = self.db.get_collection(CollectionNames.EXECUTIONS)
26-
self.saved_scripts_collection: AsyncIOMotorCollection = self.db.get_collection(CollectionNames.SAVED_SCRIPTS)
27-
self.notifications_collection: AsyncIOMotorCollection = self.db.get_collection(CollectionNames.NOTIFICATIONS)
28-
self.user_settings_collection: AsyncIOMotorCollection = self.db.get_collection(CollectionNames.USER_SETTINGS)
29-
self.events_collection: AsyncIOMotorCollection = self.db.get_collection(CollectionNames.EVENTS)
30-
self.sagas_collection: AsyncIOMotorCollection = self.db.get_collection(CollectionNames.SAGAS)
24+
self.executions_collection: Collection = self.db.get_collection(CollectionNames.EXECUTIONS)
25+
self.saved_scripts_collection: Collection = self.db.get_collection(CollectionNames.SAVED_SCRIPTS)
26+
self.notifications_collection: Collection = self.db.get_collection(CollectionNames.NOTIFICATIONS)
27+
self.user_settings_collection: Collection = self.db.get_collection(CollectionNames.USER_SETTINGS)
28+
self.events_collection: Collection = self.db.get_collection(CollectionNames.EVENTS)
29+
self.sagas_collection: Collection = self.db.get_collection(CollectionNames.SAGAS)
3130
self.security_service = SecurityService()
3231
self.mapper = UserMapper()
3332

backend/app/db/repositories/dlq_repository.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
from datetime import datetime, timezone
22
from typing import Dict, List, Mapping
33

4-
from motor.motor_asyncio import AsyncIOMotorCollection, AsyncIOMotorDatabase
5-
4+
from app.core.database_context import Collection, Database
65
from app.core.logging import logger
76
from app.dlq import (
87
AgeStatistics,
@@ -24,9 +23,9 @@
2423

2524

2625
class DLQRepository:
27-
def __init__(self, db: AsyncIOMotorDatabase):
26+
def __init__(self, db: Database):
2827
self.db = db
29-
self.dlq_collection: AsyncIOMotorCollection = self.db.get_collection(CollectionNames.DLQ_MESSAGES)
28+
self.dlq_collection: Collection = self.db.get_collection(CollectionNames.DLQ_MESSAGES)
3029

3130
async def get_dlq_stats(self) -> DLQStatistics:
3231
# Get counts by status

backend/app/db/repositories/event_repository.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
from types import MappingProxyType
44
from typing import Any, AsyncIterator, Mapping
55

6-
from motor.motor_asyncio import AsyncIOMotorCollection, AsyncIOMotorDatabase
76
from pymongo import ASCENDING, DESCENDING
87

8+
from app.core.database_context import Collection, Database
99
from app.core.logging import logger
1010
from app.core.tracing import EventAttributes
1111
from app.core.tracing.utils import add_span_attributes
@@ -25,10 +25,10 @@
2525

2626

2727
class EventRepository:
28-
def __init__(self, database: AsyncIOMotorDatabase) -> None:
28+
def __init__(self, database: Database) -> None:
2929
self.database = database
3030
self.mapper = EventMapper()
31-
self._collection: AsyncIOMotorCollection = self.database.get_collection(CollectionNames.EVENTS)
31+
self._collection: Collection = self.database.get_collection(CollectionNames.EVENTS)
3232

3333
def _build_time_filter(
3434
self,

backend/app/db/repositories/execution_repository.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,17 @@
11
from datetime import datetime, timezone
22

3-
from motor.motor_asyncio import AsyncIOMotorCollection, AsyncIOMotorDatabase
4-
3+
from app.core.database_context import Collection, Database
54
from app.core.logging import logger
65
from app.domain.enums.execution import ExecutionStatus
76
from app.domain.events.event_models import CollectionNames
87
from app.domain.execution import DomainExecution, ExecutionResultDomain, ResourceUsageDomain
98

109

1110
class ExecutionRepository:
12-
def __init__(self, db: AsyncIOMotorDatabase):
11+
def __init__(self, db: Database):
1312
self.db = db
14-
self.collection: AsyncIOMotorCollection = self.db.get_collection(CollectionNames.EXECUTIONS)
15-
self.results_collection: AsyncIOMotorCollection = self.db.get_collection(CollectionNames.EXECUTION_RESULTS)
13+
self.collection: Collection = self.db.get_collection(CollectionNames.EXECUTIONS)
14+
self.results_collection: Collection = self.db.get_collection(CollectionNames.EXECUTION_RESULTS)
1615

1716
async def create_execution(self, execution: DomainExecution) -> DomainExecution:
1817
execution_dict = {

backend/app/db/repositories/notification_repository.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from datetime import UTC, datetime, timedelta
22

3-
from motor.motor_asyncio import AsyncIOMotorCollection, AsyncIOMotorDatabase
43
from pymongo import ASCENDING, DESCENDING, IndexModel
54

5+
from app.core.database_context import Collection, Database
66
from app.core.logging import logger
77
from app.domain.enums.notification import (
88
NotificationChannel,
@@ -16,11 +16,11 @@
1616

1717

1818
class NotificationRepository:
19-
def __init__(self, database: AsyncIOMotorDatabase):
20-
self.db: AsyncIOMotorDatabase = database
19+
def __init__(self, database: Database):
20+
self.db: Database = database
2121

22-
self.notifications_collection: AsyncIOMotorCollection = self.db.get_collection(CollectionNames.NOTIFICATIONS)
23-
self.subscriptions_collection: AsyncIOMotorCollection = self.db.get_collection(
22+
self.notifications_collection: Collection = self.db.get_collection(CollectionNames.NOTIFICATIONS)
23+
self.subscriptions_collection: Collection = self.db.get_collection(
2424
CollectionNames.NOTIFICATION_SUBSCRIPTIONS)
2525
self.mapper = NotificationMapper()
2626

0 commit comments

Comments
 (0)