Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/mypy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,4 @@ jobs:
SECRET_KEY: ${{ secrets.TEST_SECRET_KEY }}
run: |
cd backend
uv run mypy --config-file pyproject.toml .
uv run mypy --config-file pyproject.toml --strict .
12 changes: 9 additions & 3 deletions backend/app/core/database_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,21 @@
from motor.motor_asyncio import (
AsyncIOMotorClient,
AsyncIOMotorClientSession,
AsyncIOMotorCollection,
AsyncIOMotorCursor,
AsyncIOMotorDatabase,
)
from pymongo.errors import ServerSelectionTimeoutError

from app.core.logging import logger

# Python 3.12 type aliases using the new 'type' statement
type DBClient = AsyncIOMotorClient[Any]
type Database = AsyncIOMotorDatabase[Any]
# MongoDocument represents the raw document type returned by Motor operations
type MongoDocument = dict[str, Any]
type DBClient = AsyncIOMotorClient[MongoDocument]
type Database = AsyncIOMotorDatabase[MongoDocument]
type Collection = AsyncIOMotorCollection[MongoDocument]
type Cursor = AsyncIOMotorCursor[MongoDocument]
type DBSession = AsyncIOMotorClientSession

# Type variable for generic database provider
Expand Down Expand Up @@ -102,7 +108,7 @@ async def connect(self) -> None:
# Always explicitly bind to current event loop for consistency
import asyncio

client: AsyncIOMotorClient = AsyncIOMotorClient(
client: DBClient = AsyncIOMotorClient(
self._config.mongodb_url,
serverSelectionTimeoutMS=self._config.server_selection_timeout_ms,
connectTimeoutMS=self._config.connect_timeout_ms,
Expand Down
4 changes: 2 additions & 2 deletions backend/app/core/dishka_lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import redis.asyncio as redis
from dishka import AsyncContainer
from fastapi import FastAPI
from motor.motor_asyncio import AsyncIOMotorDatabase

from app.core.database_context import Database
from app.core.logging import logger
from app.core.startup import initialize_metrics_context, initialize_rate_limits
from app.core.tracing import init_tracing
Expand Down Expand Up @@ -64,7 +64,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
await initialize_event_schemas(schema_registry)

# Initialize database schema at application scope using app-scoped DB
database = await container.get(AsyncIOMotorDatabase)
database = await container.get(Database)
schema_manager = SchemaManager(database)
await schema_manager.apply_all()
logger.info("Database schema ensured by SchemaManager")
Expand Down
38 changes: 19 additions & 19 deletions backend/app/core/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

import redis.asyncio as redis
from dishka import Provider, Scope, provide
from motor.motor_asyncio import AsyncIOMotorDatabase

from app.core.database_context import (
AsyncDatabaseConnection,
Database,
DatabaseConfig,
create_database_connection,
)
Expand Down Expand Up @@ -105,7 +105,7 @@ async def get_database_connection(self, settings: Settings) -> AsyncIterator[Asy
await db_connection.disconnect()

@provide
def get_database(self, db_connection: AsyncDatabaseConnection) -> AsyncIOMotorDatabase:
def get_database(self, db_connection: AsyncDatabaseConnection) -> Database:
return db_connection.database


Expand Down Expand Up @@ -174,7 +174,7 @@ async def get_kafka_producer(
await producer.stop()

@provide
async def get_dlq_manager(self, database: AsyncIOMotorDatabase) -> AsyncIterator[DLQManager]:
async def get_dlq_manager(self, database: Database) -> AsyncIterator[DLQManager]:
manager = create_dlq_manager(database)
await manager.start()
try:
Expand Down Expand Up @@ -210,7 +210,7 @@ def get_schema_registry(self) -> SchemaRegistryManager:
@provide
async def get_event_store(
self,
database: AsyncIOMotorDatabase,
database: Database,
schema_registry: SchemaRegistryManager
) -> EventStore:
store = create_event_store(
Expand Down Expand Up @@ -332,7 +332,7 @@ async def get_sse_kafka_redis_bridge(
@provide
def get_sse_repository(
self,
database: AsyncIOMotorDatabase
database: Database
) -> SSERepository:
return SSERepository(database)

Expand Down Expand Up @@ -365,7 +365,7 @@ class AuthProvider(Provider):
scope = Scope.APP

@provide
def get_user_repository(self, database: AsyncIOMotorDatabase) -> UserRepository:
def get_user_repository(self, database: Database) -> UserRepository:
return UserRepository(database)

@provide
Expand All @@ -377,11 +377,11 @@ class UserServicesProvider(Provider):
scope = Scope.APP

@provide
def get_user_settings_repository(self, database: AsyncIOMotorDatabase) -> UserSettingsRepository:
def get_user_settings_repository(self, database: Database) -> UserSettingsRepository:
return UserSettingsRepository(database)

@provide
def get_event_repository(self, database: AsyncIOMotorDatabase) -> EventRepository:
def get_event_repository(self, database: Database) -> EventRepository:
return EventRepository(database)

@provide
Expand Down Expand Up @@ -415,7 +415,7 @@ class AdminServicesProvider(Provider):
scope = Scope.APP

@provide
def get_admin_events_repository(self, database: AsyncIOMotorDatabase) -> AdminEventsRepository:
def get_admin_events_repository(self, database: Database) -> AdminEventsRepository:
return AdminEventsRepository(database)

@provide(scope=Scope.REQUEST)
Expand All @@ -427,7 +427,7 @@ def get_admin_events_service(
return AdminEventsService(admin_events_repository, replay_service)

@provide
def get_admin_settings_repository(self, database: AsyncIOMotorDatabase) -> AdminSettingsRepository:
def get_admin_settings_repository(self, database: Database) -> AdminSettingsRepository:
return AdminSettingsRepository(database)

@provide
Expand All @@ -438,15 +438,15 @@ def get_admin_settings_service(
return AdminSettingsService(admin_settings_repository)

@provide
def get_admin_user_repository(self, database: AsyncIOMotorDatabase) -> AdminUserRepository:
def get_admin_user_repository(self, database: Database) -> AdminUserRepository:
return AdminUserRepository(database)

@provide
def get_saga_repository(self, database: AsyncIOMotorDatabase) -> SagaRepository:
def get_saga_repository(self, database: Database) -> SagaRepository:
return SagaRepository(database)

@provide
def get_notification_repository(self, database: AsyncIOMotorDatabase) -> NotificationRepository:
def get_notification_repository(self, database: Database) -> NotificationRepository:
return NotificationRepository(database)

@provide
Expand Down Expand Up @@ -482,23 +482,23 @@ class BusinessServicesProvider(Provider):
scope = Scope.REQUEST

@provide
def get_execution_repository(self, database: AsyncIOMotorDatabase) -> ExecutionRepository:
def get_execution_repository(self, database: Database) -> ExecutionRepository:
return ExecutionRepository(database)

@provide
def get_resource_allocation_repository(self, database: AsyncIOMotorDatabase) -> ResourceAllocationRepository:
def get_resource_allocation_repository(self, database: Database) -> ResourceAllocationRepository:
return ResourceAllocationRepository(database)

@provide
def get_saved_script_repository(self, database: AsyncIOMotorDatabase) -> SavedScriptRepository:
def get_saved_script_repository(self, database: Database) -> SavedScriptRepository:
return SavedScriptRepository(database)

@provide
def get_dlq_repository(self, database: AsyncIOMotorDatabase) -> DLQRepository:
def get_dlq_repository(self, database: Database) -> DLQRepository:
return DLQRepository(database)

@provide
def get_replay_repository(self, database: AsyncIOMotorDatabase) -> ReplayRepository:
def get_replay_repository(self, database: Database) -> ReplayRepository:
return ReplayRepository(database)

@provide
Expand Down Expand Up @@ -623,5 +623,5 @@ class ResultProcessorProvider(Provider):
scope = Scope.APP

@provide
def get_execution_repository(self, database: AsyncIOMotorDatabase) -> ExecutionRepository:
def get_execution_repository(self, database: Database) -> ExecutionRepository:
return ExecutionRepository(database)
34 changes: 20 additions & 14 deletions backend/app/core/tracing/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import asyncio
import functools
from collections.abc import Callable, Generator
from contextlib import contextmanager
from typing import Any, Callable, Dict, Generator
from typing import Any, ParamSpec, TypeVar

from opentelemetry import context, propagate, trace
from opentelemetry.trace import SpanKind, Status, StatusCode

P = ParamSpec("P")
R = TypeVar("R")


def get_tracer() -> trace.Tracer:
"""Get a tracer for the current module."""
Expand All @@ -16,7 +20,7 @@ def get_tracer() -> trace.Tracer:
def trace_span(
name: str,
kind: SpanKind = SpanKind.INTERNAL,
attributes: Dict[str, Any] | None = None,
attributes: dict[str, Any] | None = None,
set_status_on_exception: bool = True,
tracer: trace.Tracer | None = None
) -> Generator[trace.Span, None, None]:
Expand Down Expand Up @@ -53,38 +57,40 @@ def trace_span(
def trace_method(
name: str | None = None,
kind: SpanKind = SpanKind.INTERNAL,
attributes: Dict[str, Any] | None = None
) -> Callable:
attributes: dict[str, Any] | None = None
) -> Callable[[Callable[P, R]], Callable[P, R]]:
"""
Decorator for tracing method calls.

Args:
name: Custom span name, defaults to module.method_name
kind: Kind of span (INTERNAL, CLIENT, SERVER, etc.)
attributes: Additional attributes to set on the span

Returns:
Decorated function with tracing
"""
def decorator(func: Callable) -> Callable:
def decorator(func: Callable[P, R]) -> Callable[P, R]:
span_name = name or f"{func.__module__}.{func.__name__}"

@functools.wraps(func)
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
with trace_span(span_name, kind=kind, attributes=attributes):
return await func(*args, **kwargs)
return await func(*args, **kwargs) # type: ignore[misc, no-any-return]

@functools.wraps(func)
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
with trace_span(span_name, kind=kind, attributes=attributes):
return func(*args, **kwargs)

return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
if asyncio.iscoroutinefunction(func):
return async_wrapper # type: ignore[return-value]
return sync_wrapper

return decorator


def inject_trace_context(headers: Dict[str, str]) -> Dict[str, str]:
def inject_trace_context(headers: dict[str, str]) -> dict[str, str]:
"""
Inject current trace context into headers for propagation.

Expand All @@ -99,7 +105,7 @@ def inject_trace_context(headers: Dict[str, str]) -> Dict[str, str]:
return propagation_headers


def extract_trace_context(headers: Dict[str, str]) -> context.Context:
def extract_trace_context(headers: dict[str, str]) -> context.Context:
"""
Extract trace context from headers.

Expand All @@ -126,7 +132,7 @@ def add_span_attributes(**attributes: Any) -> None:
span.set_attribute(key, value)


def add_span_event(name: str, attributes: Dict[str, Any] | None = None) -> None:
def add_span_event(name: str, attributes: dict[str, Any] | None = None) -> None:
"""
Add an event to the current span.

Expand Down
14 changes: 7 additions & 7 deletions backend/app/db/repositories/admin/admin_events_repository.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List

from motor.motor_asyncio import AsyncIOMotorCollection, AsyncIOMotorDatabase
from pymongo import ReturnDocument

from app.core.database_context import Collection, Database
from app.domain.admin import (
ReplayQuery,
ReplaySession,
Expand Down Expand Up @@ -43,18 +43,18 @@
class AdminEventsRepository:
"""Repository for admin event operations using domain models."""

def __init__(self, db: AsyncIOMotorDatabase):
def __init__(self, db: Database):
self.db = db
self.events_collection: AsyncIOMotorCollection = self.db.get_collection(CollectionNames.EVENTS)
self.event_store_collection: AsyncIOMotorCollection = self.db.get_collection(CollectionNames.EVENT_STORE)
self.events_collection: Collection = self.db.get_collection(CollectionNames.EVENTS)
self.event_store_collection: Collection = self.db.get_collection(CollectionNames.EVENT_STORE)
# Bind related collections used by this repository
self.executions_collection: AsyncIOMotorCollection = self.db.get_collection(CollectionNames.EXECUTIONS)
self.events_archive_collection: AsyncIOMotorCollection = self.db.get_collection(
self.executions_collection: Collection = self.db.get_collection(CollectionNames.EXECUTIONS)
self.events_archive_collection: Collection = self.db.get_collection(
CollectionNames.EVENTS_ARCHIVE
)
self.replay_mapper = ReplaySessionMapper()
self.replay_query_mapper = ReplayQueryMapper()
self.replay_sessions_collection: AsyncIOMotorCollection = self.db.get_collection(
self.replay_sessions_collection: Collection = self.db.get_collection(
CollectionNames.REPLAY_SESSIONS)
self.mapper = EventMapper()
self.summary_mapper = EventSummaryMapper()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from datetime import datetime, timezone

from motor.motor_asyncio import AsyncIOMotorCollection, AsyncIOMotorDatabase

from app.core.database_context import Collection, Database
from app.core.logging import logger
from app.domain.admin import (
AuditAction,
Expand All @@ -12,10 +11,10 @@


class AdminSettingsRepository:
def __init__(self, db: AsyncIOMotorDatabase):
def __init__(self, db: Database):
self.db = db
self.settings_collection: AsyncIOMotorCollection = self.db.get_collection("system_settings")
self.audit_log_collection: AsyncIOMotorCollection = self.db.get_collection("audit_log")
self.settings_collection: Collection = self.db.get_collection("system_settings")
self.audit_log_collection: Collection = self.db.get_collection("audit_log")
self.settings_mapper = SettingsMapper()
self.audit_mapper = AuditLogMapper()

Expand Down
19 changes: 9 additions & 10 deletions backend/app/db/repositories/admin/admin_user_repository.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from datetime import datetime, timezone

from motor.motor_asyncio import AsyncIOMotorCollection, AsyncIOMotorDatabase

from app.core.database_context import Collection, Database
from app.core.security import SecurityService
from app.domain.enums import UserRole
from app.domain.events.event_models import CollectionNames
Expand All @@ -17,17 +16,17 @@


class AdminUserRepository:
def __init__(self, db: AsyncIOMotorDatabase):
def __init__(self, db: Database):
self.db = db
self.users_collection: AsyncIOMotorCollection = self.db.get_collection(CollectionNames.USERS)
self.users_collection: Collection = self.db.get_collection(CollectionNames.USERS)

# Related collections used by this repository (e.g., cascade deletes)
self.executions_collection: AsyncIOMotorCollection = self.db.get_collection(CollectionNames.EXECUTIONS)
self.saved_scripts_collection: AsyncIOMotorCollection = self.db.get_collection(CollectionNames.SAVED_SCRIPTS)
self.notifications_collection: AsyncIOMotorCollection = self.db.get_collection(CollectionNames.NOTIFICATIONS)
self.user_settings_collection: AsyncIOMotorCollection = self.db.get_collection(CollectionNames.USER_SETTINGS)
self.events_collection: AsyncIOMotorCollection = self.db.get_collection(CollectionNames.EVENTS)
self.sagas_collection: AsyncIOMotorCollection = self.db.get_collection(CollectionNames.SAGAS)
self.executions_collection: Collection = self.db.get_collection(CollectionNames.EXECUTIONS)
self.saved_scripts_collection: Collection = self.db.get_collection(CollectionNames.SAVED_SCRIPTS)
self.notifications_collection: Collection = self.db.get_collection(CollectionNames.NOTIFICATIONS)
self.user_settings_collection: Collection = self.db.get_collection(CollectionNames.USER_SETTINGS)
self.events_collection: Collection = self.db.get_collection(CollectionNames.EVENTS)
self.sagas_collection: Collection = self.db.get_collection(CollectionNames.SAGAS)
self.security_service = SecurityService()
self.mapper = UserMapper()

Expand Down
Loading
Loading