diff --git a/.github/workflows/backend-ci.yml b/.github/workflows/backend-ci.yml index f3c56abd..ece0f874 100644 --- a/.github/workflows/backend-ci.yml +++ b/.github/workflows/backend-ci.yml @@ -47,6 +47,7 @@ jobs: run: | cd backend uv run pytest tests/unit -v -rs \ + --durations=0 \ --cov=app \ --cov-report=xml --cov-report=term @@ -90,32 +91,12 @@ jobs: docker compose -f docker-compose.ci.yaml up -d --wait --wait-timeout 120 docker compose -f docker-compose.ci.yaml ps - - name: Create Kafka topics - timeout-minutes: 2 - env: - KAFKA_BOOTSTRAP_SERVERS: localhost:9092 - KAFKA_TOPIC_PREFIX: "ci.${{ github.run_id }}." - run: | - cd backend - uv run python -m scripts.create_topics - - name: Run integration tests timeout-minutes: 10 - env: - MONGO_ROOT_USER: root - MONGO_ROOT_PASSWORD: rootpassword - MONGODB_HOST: 127.0.0.1 - MONGODB_PORT: 27017 - MONGODB_URL: mongodb://root:rootpassword@127.0.0.1:27017/?authSource=admin - KAFKA_BOOTSTRAP_SERVERS: localhost:9092 - KAFKA_TOPIC_PREFIX: "ci.${{ github.run_id }}." - SCHEMA_REGISTRY_URL: http://localhost:8081 - REDIS_HOST: localhost - REDIS_PORT: 6379 - SCHEMA_SUBJECT_PREFIX: "ci.${{ github.run_id }}." run: | cd backend uv run pytest tests/integration -v -rs \ + --durations=0 \ --cov=app \ --cov-report=xml --cov-report=term @@ -184,32 +165,15 @@ jobs: timeout 90 bash -c 'until sudo k3s kubectl cluster-info; do sleep 5; done' kubectl create namespace integr8scode --dry-run=client -o yaml | kubectl apply -f - - - name: Create Kafka topics - timeout-minutes: 2 - env: - KAFKA_BOOTSTRAP_SERVERS: localhost:9092 - KAFKA_TOPIC_PREFIX: "ci.${{ github.run_id }}." - run: | - cd backend - uv run python -m scripts.create_topics - - name: Run E2E tests timeout-minutes: 10 env: - MONGO_ROOT_USER: root - MONGO_ROOT_PASSWORD: rootpassword - MONGODB_URL: mongodb://root:rootpassword@127.0.0.1:27017/?authSource=admin - KAFKA_BOOTSTRAP_SERVERS: localhost:9092 - KAFKA_TOPIC_PREFIX: "ci.${{ github.run_id }}." - SCHEMA_REGISTRY_URL: http://localhost:8081 - REDIS_HOST: localhost - REDIS_PORT: 6379 - SCHEMA_SUBJECT_PREFIX: "ci.${{ github.run_id }}." KUBECONFIG: /home/runner/.kube/config K8S_NAMESPACE: integr8scode run: | cd backend uv run pytest tests/e2e -v -rs \ + --durations=0 \ --cov=app \ --cov-report=xml --cov-report=term diff --git a/backend/.env b/backend/.env index 0c5e5982..cec960d1 100644 --- a/backend/.env +++ b/backend/.env @@ -76,3 +76,6 @@ WEB_BACKLOG=2048 # When running uvicorn locally (outside Docker), bind to IPv4 loopback to avoid # IPv6-only localhost resolution on some Linux distros. SERVER_HOST=127.0.0.1 + +# Security +BCRYPT_ROUNDS=12 diff --git a/backend/.env.test b/backend/.env.test index efc653de..5f984770 100644 --- a/backend/.env.test +++ b/backend/.env.test @@ -3,7 +3,6 @@ PROJECT_NAME=integr8scode DATABASE_NAME=integr8scode_test API_V1_STR=/api/v1 SECRET_KEY=test-secret-key-for-testing-only-32chars!! -ENVIRONMENT=testing TESTING=true # MongoDB - use localhost for tests @@ -23,22 +22,21 @@ REDIS_DECODE_RESPONSES=true # Kafka - use localhost for tests KAFKA_BOOTSTRAP_SERVERS=localhost:9092 KAFKA_TOPIC_PREFIX=test. +SCHEMA_SUBJECT_PREFIX=test. SCHEMA_REGISTRY_URL=http://localhost:8081 # Security SECURE_COOKIES=true -CORS_ALLOWED_ORIGINS=["http://localhost:3000","https://localhost:3000"] +BCRYPT_ROUNDS=4 # Features RATE_LIMIT_ENABLED=true ENABLE_TRACING=false -# OpenTelemetry - explicitly disabled for tests (no endpoint = NoOp meter) +# OpenTelemetry - explicitly disabled for tests OTEL_EXPORTER_OTLP_ENDPOINT= -OTEL_METRICS_EXPORTER=none -OTEL_TRACES_EXPORTER=none -OTEL_LOGS_EXPORTER=none # Development DEVELOPMENT_MODE=false LOG_LEVEL=INFO +ENVIRONMENT=test diff --git a/backend/Dockerfile b/backend/Dockerfile index b9897fac..97ab3a74 100644 --- a/backend/Dockerfile +++ b/backend/Dockerfile @@ -27,7 +27,7 @@ done echo "Starting application..." [ -f /app/kubeconfig.yaml ] && export KUBECONFIG=/app/kubeconfig.yaml -exec gunicorn app.main:app \ +exec gunicorn 'app.main:create_app()' \ -k uvicorn.workers.UvicornWorker \ --bind 0.0.0.0:443 \ --workers ${WEB_CONCURRENCY:-4} \ diff --git a/backend/app/api/routes/auth.py b/backend/app/api/routes/auth.py index 606d57f2..55c9e4a3 100644 --- a/backend/app/api/routes/auth.py +++ b/backend/app/api/routes/auth.py @@ -7,7 +7,7 @@ from fastapi.security import OAuth2PasswordRequestForm from pymongo.errors import DuplicateKeyError -from app.core.security import security_service +from app.core.security import SecurityService from app.core.utils import get_client_ip from app.db.repositories import UserRepository from app.domain.user import DomainUserCreate @@ -19,7 +19,7 @@ UserResponse, ) from app.services.auth_service import AuthService -from app.settings import get_settings +from app.settings import Settings router = APIRouter(prefix="/auth", tags=["authentication"], route_class=DishkaRoute) @@ -29,6 +29,8 @@ async def login( request: Request, response: Response, user_repo: FromDishka[UserRepository], + security_service: FromDishka[SecurityService], + settings: FromDishka[Settings], logger: FromDishka[logging.Logger], form_data: OAuth2PasswordRequestForm = Depends(), ) -> LoginResponse: @@ -74,8 +76,6 @@ async def login( headers={"WWW-Authenticate": "Bearer"}, ) - settings = get_settings() - logger.info( "Login successful", extra={ @@ -127,6 +127,7 @@ async def register( request: Request, user: UserCreate, user_repo: FromDishka[UserRepository], + security_service: FromDishka[SecurityService], logger: FromDishka[logging.Logger], ) -> UserResponse: logger.info( diff --git a/backend/app/api/routes/dlq.py b/backend/app/api/routes/dlq.py index 1ab6136d..df88504a 100644 --- a/backend/app/api/routes/dlq.py +++ b/backend/app/api/routes/dlq.py @@ -95,9 +95,9 @@ async def get_dlq_message(event_id: str, repository: FromDishka[DLQRepository]) @router.post("/retry", response_model=DLQBatchRetryResponse) async def retry_dlq_messages( - retry_request: ManualRetryRequest, repository: FromDishka[DLQRepository], dlq_manager: FromDishka[DLQManager] + retry_request: ManualRetryRequest, dlq_manager: FromDishka[DLQManager] ) -> DLQBatchRetryResponse: - result = await repository.retry_messages_batch(retry_request.event_ids, dlq_manager) + result = await dlq_manager.retry_messages_batch(retry_request.event_ids) return DLQBatchRetryResponse( total=result.total, successful=result.successful, diff --git a/backend/app/api/routes/events.py b/backend/app/api/routes/events.py index bc15166f..7cfea823 100644 --- a/backend/app/api/routes/events.py +++ b/backend/app/api/routes/events.py @@ -27,7 +27,7 @@ from app.schemas_pydantic.user import UserResponse from app.services.event_service import EventService from app.services.kafka_event_service import KafkaEventService -from app.settings import get_settings +from app.settings import Settings router = APIRouter(prefix="/events", tags=["events"], route_class=DishkaRoute) @@ -229,8 +229,8 @@ async def publish_custom_event( event_request: PublishEventRequest, request: Request, event_service: FromDishka[KafkaEventService], + settings: FromDishka[Settings], ) -> PublishEventResponse: - settings = get_settings() base_meta = EventMetadata( service_name=settings.SERVICE_NAME, service_version=settings.SERVICE_VERSION, @@ -311,6 +311,7 @@ async def replay_aggregate_events( admin: Annotated[UserResponse, Depends(admin_user)], event_service: FromDishka[EventService], kafka_event_service: FromDishka[KafkaEventService], + settings: FromDishka[Settings], logger: FromDishka[logging.Logger], target_service: str | None = Query(None, description="Service to replay events to"), dry_run: bool = Query(True, description="If true, only show what would be replayed"), @@ -339,7 +340,6 @@ async def replay_aggregate_events( await asyncio.sleep(0.1) try: - settings = get_settings() meta = EventMetadata( service_name=settings.SERVICE_NAME, service_version=settings.SERVICE_VERSION, diff --git a/backend/app/api/routes/execution.py b/backend/app/api/routes/execution.py index 37723a01..85762a2f 100644 --- a/backend/app/api/routes/execution.py +++ b/backend/app/api/routes/execution.py @@ -34,7 +34,7 @@ from app.services.execution_service import ExecutionService from app.services.idempotency import IdempotencyManager from app.services.kafka_event_service import KafkaEventService -from app.settings import get_settings +from app.settings import Settings router = APIRouter(route_class=DishkaRoute) @@ -162,6 +162,7 @@ async def cancel_execution( current_user: Annotated[UserResponse, Depends(current_user)], cancel_request: CancelExecutionRequest, event_service: FromDishka[KafkaEventService], + settings: FromDishka[Settings], ) -> CancelResponse: # Handle terminal states terminal_states = [ExecutionStatus.COMPLETED, ExecutionStatus.FAILED, ExecutionStatus.TIMEOUT] @@ -178,7 +179,6 @@ async def cancel_execution( event_id="-1", # exact event_id unknown ) - settings = get_settings() payload = { "execution_id": execution.execution_id, "status": str(ExecutionStatus.CANCELLED), diff --git a/backend/app/core/adaptive_sampling.py b/backend/app/core/adaptive_sampling.py index 26e27883..5855018c 100644 --- a/backend/app/core/adaptive_sampling.py +++ b/backend/app/core/adaptive_sampling.py @@ -2,14 +2,14 @@ import threading import time from collections import deque -from typing import Any, Sequence, Tuple +from typing import Sequence, Tuple from opentelemetry.context import Context from opentelemetry.sdk.trace.sampling import Decision, Sampler, SamplingResult from opentelemetry.trace import Link, SpanKind, TraceState, get_current_span from opentelemetry.util.types import Attributes -from app.settings import get_settings +from app.settings import Settings class AdaptiveSampler(Sampler): @@ -239,11 +239,8 @@ def shutdown(self) -> None: self._adjustment_thread.join(timeout=5.0) -def create_adaptive_sampler(settings: Any | None = None) -> AdaptiveSampler: +def create_adaptive_sampler(settings: Settings) -> AdaptiveSampler: """Create adaptive sampler with settings""" - if settings is None: - settings = get_settings() - return AdaptiveSampler( base_rate=settings.TRACING_SAMPLING_RATE, min_rate=max(0.001, settings.TRACING_SAMPLING_RATE / 100), # 1/100th of base diff --git a/backend/app/core/container.py b/backend/app/core/container.py index 97411a49..97e0c48f 100644 --- a/backend/app/core/container.py +++ b/backend/app/core/container.py @@ -45,6 +45,7 @@ def create_app_container(settings: Settings) -> AsyncContainer: RepositoryProvider(), MessagingProvider(), EventProvider(), + SagaOrchestratorProvider(), KafkaServicesProvider(), SSEProvider(), AuthProvider(), diff --git a/backend/app/core/dishka_lifespan.py b/backend/app/core/dishka_lifespan.py index 038fd18d..d419bf54 100644 --- a/backend/app/core/dishka_lifespan.py +++ b/backend/app/core/dishka_lifespan.py @@ -14,7 +14,7 @@ from app.events.event_store_consumer import EventStoreConsumer from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas from app.services.sse.kafka_redis_bridge import SSEKafkaRedisBridge -from app.settings import get_settings +from app.settings import Settings @asynccontextmanager @@ -27,10 +27,9 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: - No manual service management - Dishka handles all lifecycle automatically """ - settings = get_settings() - - # Get logger from DI container + # Get settings and logger from DI container (uses test settings in tests) container: AsyncContainer = app.state.dishka_container + settings = await container.get(Settings) logger = await container.get(logging.Logger) logger.info( @@ -44,25 +43,32 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # Metrics setup moved to app creation to allow middleware registration logger.info("Lifespan start: tracing and services initialization") - # Initialize tracing - instrumentation_report = init_tracing( - service_name=settings.TRACING_SERVICE_NAME, - logger=logger, - service_version=settings.TRACING_SERVICE_VERSION, - sampling_rate=settings.TRACING_SAMPLING_RATE, - enable_console_exporter=settings.TESTING, - adaptive_sampling=settings.TRACING_ADAPTIVE_SAMPLING, - ) - - if instrumentation_report.has_failures(): - logger.warning( - "Some instrumentation libraries failed to initialize", - extra={"instrumentation_summary": instrumentation_report.get_summary()}, + # Initialize tracing only when enabled (avoid exporter retries in tests) + if settings.ENABLE_TRACING and not settings.TESTING: + instrumentation_report = init_tracing( + service_name=settings.TRACING_SERVICE_NAME, + settings=settings, + logger=logger, + service_version=settings.TRACING_SERVICE_VERSION, + sampling_rate=settings.TRACING_SAMPLING_RATE, + enable_console_exporter=settings.TESTING, + adaptive_sampling=settings.TRACING_ADAPTIVE_SAMPLING, ) + + if instrumentation_report.has_failures(): + logger.warning( + "Some instrumentation libraries failed to initialize", + extra={"instrumentation_summary": instrumentation_report.get_summary()}, + ) + else: + logger.info( + "Distributed tracing initialized successfully", + extra={"instrumentation_summary": instrumentation_report.get_summary()}, + ) else: logger.info( - "Distributed tracing initialized successfully", - extra={"instrumentation_summary": instrumentation_report.get_summary()}, + "Distributed tracing disabled", + extra={"testing": settings.TESTING, "enable_tracing": settings.ENABLE_TRACING}, ) # Initialize schema registry once at startup diff --git a/backend/app/core/k8s_clients.py b/backend/app/core/k8s_clients.py index 2a475df3..0aedd5c7 100644 --- a/backend/app/core/k8s_clients.py +++ b/backend/app/core/k8s_clients.py @@ -3,14 +3,18 @@ from kubernetes import client as k8s_client from kubernetes import config as k8s_config +from kubernetes import watch as k8s_watch @dataclass(frozen=True) class K8sClients: + """Kubernetes API clients bundle for dependency injection.""" + api_client: k8s_client.ApiClient v1: k8s_client.CoreV1Api apps_v1: k8s_client.AppsV1Api networking_v1: k8s_client.NetworkingV1Api + watch: k8s_watch.Watch def create_k8s_clients( @@ -33,6 +37,7 @@ def create_k8s_clients( v1=k8s_client.CoreV1Api(api_client), apps_v1=k8s_client.AppsV1Api(api_client), networking_v1=k8s_client.NetworkingV1Api(api_client), + watch=k8s_watch.Watch(), ) diff --git a/backend/app/core/metrics/base.py b/backend/app/core/metrics/base.py index 03899642..911ed583 100644 --- a/backend/app/core/metrics/base.py +++ b/backend/app/core/metrics/base.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from typing import Optional from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter from opentelemetry.metrics import Meter, NoOpMeterProvider @@ -7,51 +6,49 @@ from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader from opentelemetry.sdk.resources import Resource -from app.settings import get_settings +from app.settings import Settings @dataclass class MetricsConfig: service_name: str = "integr8scode-backend" service_version: str = "1.0.0" - otlp_endpoint: Optional[str] = None + otlp_endpoint: str | None = None export_interval_millis: int = 10000 console_export_interval_millis: int = 60000 class BaseMetrics: - def __init__(self, meter_name: str | None = None): + def __init__(self, settings: Settings, meter_name: str | None = None): """Initialize base metrics with its own meter. Args: + settings: Application settings. meter_name: Optional name for the meter. Defaults to class name. """ - # Get settings and create config - settings = get_settings() config = MetricsConfig( service_name=settings.TRACING_SERVICE_NAME or "integr8scode-backend", service_version="1.0.0", otlp_endpoint=settings.OTEL_EXPORTER_OTLP_ENDPOINT, ) - # Each collector creates its own independent meter meter_name = meter_name or self.__class__.__name__ - self._meter = self._create_meter(config, meter_name) + self._meter = self._create_meter(settings, config, meter_name) self._create_instruments() - def _create_meter(self, config: MetricsConfig, meter_name: str) -> Meter: + def _create_meter(self, settings: Settings, config: MetricsConfig, meter_name: str) -> Meter: """Create a new meter instance for this collector. Args: + settings: Application settings config: Metrics configuration meter_name: Name for this meter Returns: A new meter instance """ - # If tracing/metrics disabled or no OTLP endpoint configured, use NoOp meter to avoid threads/network - settings = get_settings() - if not settings.ENABLE_TRACING or not config.otlp_endpoint: + # If tracing/metrics disabled or no OTLP endpoint configured, use NoOp meter + if not config.otlp_endpoint: return NoOpMeterProvider().get_meter(meter_name) resource = Resource.create( diff --git a/backend/app/core/metrics/context.py b/backend/app/core/metrics/context.py index 54a88e60..dd87c3b2 100644 --- a/backend/app/core/metrics/context.py +++ b/backend/app/core/metrics/context.py @@ -45,21 +45,20 @@ def __init__(self, name: str, metric_class: Type[T], logger: logging.Logger) -> def get(self) -> T: """ - Get the metric from context, creating it if necessary. - - This method implements lazy initialization - if no metric exists - in the current context, it creates one. This is useful for testing - and standalone scripts where the context might not be initialized. + Get the metric from context. Returns: The metric instance for the current context + + Raises: + RuntimeError: If metrics not initialized via DI """ metric = self._context_var.get() if metric is None: - # Lazy initialization with logging - self.logger.debug(f"Lazy initializing {self._name} metrics in context") - metric = self._metric_class() - self._context_var.set(metric) + raise RuntimeError( + f"{self._name} metrics not initialized. " + "Ensure MetricsContext.initialize_all() is called during app startup." + ) return metric def set(self, metric: T) -> contextvars.Token[Optional[T]]: diff --git a/backend/app/core/middlewares/__init__.py b/backend/app/core/middlewares/__init__.py index a1a2441d..0ea15f3a 100644 --- a/backend/app/core/middlewares/__init__.py +++ b/backend/app/core/middlewares/__init__.py @@ -1,10 +1,12 @@ from .cache import CacheControlMiddleware +from .csrf import CSRFMiddleware from .metrics import MetricsMiddleware, create_system_metrics, setup_metrics from .rate_limit import RateLimitMiddleware from .request_size_limit import RequestSizeLimitMiddleware __all__ = [ "CacheControlMiddleware", + "CSRFMiddleware", "MetricsMiddleware", "setup_metrics", "create_system_metrics", diff --git a/backend/app/core/middlewares/csrf.py b/backend/app/core/middlewares/csrf.py new file mode 100644 index 00000000..1c644ee9 --- /dev/null +++ b/backend/app/core/middlewares/csrf.py @@ -0,0 +1,56 @@ +import logging + +from dishka import AsyncContainer +from starlette.requests import Request +from starlette.responses import JSONResponse +from starlette.types import ASGIApp, Receive, Scope, Send + +from app.core.security import SecurityService +from app.domain.user import CSRFValidationError + +logger = logging.getLogger(__name__) + + +class CSRFMiddleware: + """ + Middleware for CSRF protection using double-submit cookie pattern. + + This middleware validates that state-changing requests (POST, PUT, DELETE, PATCH) + include a valid CSRF token in the X-CSRF-Token header that matches the csrf_token cookie. + + Requests are skipped if: + - Method is safe (GET, HEAD, OPTIONS) + - Path is an auth endpoint (login, register, logout) + - Path is not under /api/ + - User is not authenticated (no access_token cookie) + """ + + def __init__(self, app: ASGIApp, container: AsyncContainer) -> None: + self.app = app + self.container = container + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + security_service: SecurityService = await self.container.get(SecurityService) + + request = Request(scope, receive=receive) + + try: + # validate_csrf_from_request returns "skip" or the token if valid + # raises CSRFValidationError if invalid + security_service.validate_csrf_from_request(request) + await self.app(scope, receive, send) + + except CSRFValidationError as e: + logger.warning( + "CSRF validation failed", + extra={"path": request.url.path, "method": request.method, "reason": str(e)}, + ) + response = JSONResponse( + status_code=403, + content={"detail": "CSRF validation failed"}, + ) + await response(scope, receive, send) diff --git a/backend/app/core/middlewares/metrics.py b/backend/app/core/middlewares/metrics.py index 3513bee4..784dc174 100644 --- a/backend/app/core/middlewares/metrics.py +++ b/backend/app/core/middlewares/metrics.py @@ -13,7 +13,7 @@ from opentelemetry.sdk.resources import SERVICE_NAME, SERVICE_VERSION, Resource from starlette.types import ASGIApp, Message, Receive, Scope, Send -from app.settings import get_settings +from app.settings import Settings class MetricsMiddleware: @@ -118,29 +118,24 @@ def _get_path_template(path: str) -> str: return path -def setup_metrics(app: FastAPI, logger: logging.Logger) -> None: +def setup_metrics(app: FastAPI, settings: Settings, logger: logging.Logger) -> None: """Set up OpenTelemetry metrics with OTLP exporter.""" - settings = get_settings() - # Fast opt-out for tests or when explicitly disabled - if settings.TESTING or os.getenv("OTEL_SDK_DISABLED", "").lower() in {"1", "true", "yes"}: - logger.info("OpenTelemetry metrics disabled (TESTING/OTEL_SDK_DISABLED)") + if not settings.OTEL_EXPORTER_OTLP_ENDPOINT: + logger.warning("OTEL_EXPORTER_OTLP_ENDPOINT not configured, skipping metrics setup") return # Configure OpenTelemetry resource resource = Resource.create( { - SERVICE_NAME: settings.PROJECT_NAME, - SERVICE_VERSION: "1.0.0", - "service.environment": "test" if settings.TESTING else "production", + SERVICE_NAME: settings.SERVICE_NAME, + SERVICE_VERSION: settings.SERVICE_VERSION, + "service.environment": settings.ENVIRONMENT, } ) # Configure OTLP exporter (sends to OpenTelemetry Collector or compatible backend) # Default endpoint is localhost:4317 for gRPC - otlp_exporter = OTLPMetricExporter( - endpoint=os.getenv("OTEL_EXPORTER_OTLP_ENDPOINT", "localhost:4317"), - insecure=True, # Use insecure for local development - ) + otlp_exporter = OTLPMetricExporter(endpoint=settings.OTEL_EXPORTER_OTLP_ENDPOINT, insecure=True) # Create metric reader with 60 second export interval metric_reader = PeriodicExportingMetricReader( diff --git a/backend/app/core/providers.py b/backend/app/core/providers.py index c1f29693..1daf530d 100644 --- a/backend/app/core/providers.py +++ b/backend/app/core/providers.py @@ -22,6 +22,7 @@ from app.core.metrics.connections import ConnectionMetrics from app.core.metrics.events import EventMetrics from app.core.metrics.rate_limit import RateLimitMetrics +from app.core.security import SecurityService from app.core.tracing import TracerManager from app.db.repositories import ( EventRepository, @@ -63,6 +64,7 @@ from app.services.kafka_event_service import KafkaEventService from app.services.notification_service import NotificationService from app.services.pod_monitor.config import PodMonitorConfig +from app.services.pod_monitor.event_mapper import PodEventMapper from app.services.pod_monitor.monitor import PodMonitor from app.services.rate_limit_service import RateLimitService from app.services.replay_service import ReplayService @@ -142,6 +144,10 @@ async def get_database(self, settings: Settings, logger: logging.Logger) -> Asyn class CoreServicesProvider(Provider): scope = Scope.APP + @provide + def get_security_service(self, settings: Settings) -> SecurityService: + return SecurityService(settings) + @provide def get_tracer_manager(self, settings: Settings) -> TracerManager: return TracerManager(tracer_name=settings.TRACING_SERVICE_NAME) @@ -155,7 +161,7 @@ async def get_kafka_producer( self, settings: Settings, schema_registry: SchemaRegistryManager, logger: logging.Logger ) -> AsyncIterator[UnifiedProducer]: config = ProducerConfig(bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS) - async with UnifiedProducer(config, schema_registry, logger) as producer: + async with UnifiedProducer(config, schema_registry, logger, settings=settings) as producer: yield producer @provide @@ -214,8 +220,8 @@ async def get_event_store_consumer( yield consumer @provide - async def get_event_bus_manager(self, logger: logging.Logger) -> AsyncIterator[EventBusManager]: - manager = EventBusManager(logger) + async def get_event_bus_manager(self, settings: Settings, logger: logging.Logger) -> AsyncIterator[EventBusManager]: + manager = EventBusManager(settings, logger) try: yield manager finally: @@ -240,52 +246,52 @@ class MetricsProvider(Provider): scope = Scope.APP @provide - def get_event_metrics(self) -> EventMetrics: - return EventMetrics() + def get_event_metrics(self, settings: Settings) -> EventMetrics: + return EventMetrics(settings) @provide - def get_connection_metrics(self) -> ConnectionMetrics: - return ConnectionMetrics() + def get_connection_metrics(self, settings: Settings) -> ConnectionMetrics: + return ConnectionMetrics(settings) @provide - def get_rate_limit_metrics(self) -> RateLimitMetrics: - return RateLimitMetrics() + def get_rate_limit_metrics(self, settings: Settings) -> RateLimitMetrics: + return RateLimitMetrics(settings) @provide - def get_execution_metrics(self) -> ExecutionMetrics: - return ExecutionMetrics() + def get_execution_metrics(self, settings: Settings) -> ExecutionMetrics: + return ExecutionMetrics(settings) @provide - def get_database_metrics(self) -> DatabaseMetrics: - return DatabaseMetrics() + def get_database_metrics(self, settings: Settings) -> DatabaseMetrics: + return DatabaseMetrics(settings) @provide - def get_health_metrics(self) -> HealthMetrics: - return HealthMetrics() + def get_health_metrics(self, settings: Settings) -> HealthMetrics: + return HealthMetrics(settings) @provide - def get_kubernetes_metrics(self) -> KubernetesMetrics: - return KubernetesMetrics() + def get_kubernetes_metrics(self, settings: Settings) -> KubernetesMetrics: + return KubernetesMetrics(settings) @provide - def get_coordinator_metrics(self) -> CoordinatorMetrics: - return CoordinatorMetrics() + def get_coordinator_metrics(self, settings: Settings) -> CoordinatorMetrics: + return CoordinatorMetrics(settings) @provide - def get_dlq_metrics(self) -> DLQMetrics: - return DLQMetrics() + def get_dlq_metrics(self, settings: Settings) -> DLQMetrics: + return DLQMetrics(settings) @provide - def get_notification_metrics(self) -> NotificationMetrics: - return NotificationMetrics() + def get_notification_metrics(self, settings: Settings) -> NotificationMetrics: + return NotificationMetrics(settings) @provide - def get_replay_metrics(self) -> ReplayMetrics: - return ReplayMetrics() + def get_replay_metrics(self, settings: Settings) -> ReplayMetrics: + return ReplayMetrics(settings) @provide - def get_security_metrics(self) -> SecurityMetrics: - return SecurityMetrics() + def get_security_metrics(self, settings: Settings) -> SecurityMetrics: + return SecurityMetrics(settings) class RepositoryProvider(Provider): @@ -334,8 +340,8 @@ def get_admin_settings_repository(self, logger: logging.Logger) -> AdminSettings return AdminSettingsRepository(logger) @provide - def get_admin_user_repository(self) -> AdminUserRepository: - return AdminUserRepository() + def get_admin_user_repository(self, security_service: SecurityService) -> AdminUserRepository: + return AdminUserRepository(security_service) @provide def get_notification_repository(self, logger: logging.Logger) -> NotificationRepository: @@ -407,8 +413,10 @@ class AuthProvider(Provider): scope = Scope.APP @provide - def get_auth_service(self, user_repository: UserRepository, logger: logging.Logger) -> AuthService: - return AuthService(user_repository, logger) + def get_auth_service( + self, user_repository: UserRepository, security_service: SecurityService, logger: logging.Logger + ) -> AuthService: + return AuthService(user_repository, security_service, logger) class KafkaServicesProvider(Provider): @@ -422,9 +430,18 @@ def get_event_service(self, event_repository: EventRepository) -> EventService: @provide def get_kafka_event_service( - self, event_repository: EventRepository, kafka_producer: UnifiedProducer, logger: logging.Logger + self, + event_repository: EventRepository, + kafka_producer: UnifiedProducer, + settings: Settings, + logger: logging.Logger, ) -> KafkaEventService: - return KafkaEventService(event_repository=event_repository, kafka_producer=kafka_producer, logger=logger) + return KafkaEventService( + event_repository=event_repository, + kafka_producer=kafka_producer, + settings=settings, + logger=logger, + ) class UserServicesProvider(Provider): @@ -562,7 +579,6 @@ class BusinessServicesProvider(Provider): def __init__(self) -> None: super().__init__() # Register shared factory functions on instance (avoids warning about missing self) - self.provide(_provide_saga_orchestrator) self.provide(_provide_execution_coordinator) @provide @@ -609,10 +625,15 @@ async def get_replay_service( replay_repository: ReplayRepository, kafka_producer: UnifiedProducer, event_store: EventStore, + settings: Settings, logger: logging.Logger, ) -> ReplayService: event_replay_service = EventReplayService( - repository=replay_repository, producer=kafka_producer, event_store=event_store, logger=logger + repository=replay_repository, + producer=kafka_producer, + event_store=event_store, + settings=settings, + logger=logger, ) return ReplayService(replay_repository, event_replay_service, logger) @@ -623,6 +644,7 @@ def get_admin_user_service( event_service: EventService, execution_service: ExecutionService, rate_limit_service: RateLimitService, + security_service: SecurityService, logger: logging.Logger, ) -> AdminUserService: return AdminUserService( @@ -630,6 +652,7 @@ def get_admin_user_service( event_service=event_service, execution_service=execution_service, rate_limit_service=rate_limit_service, + security_service=security_service, logger=logger, ) @@ -671,12 +694,21 @@ async def get_kubernetes_worker( class PodMonitorProvider(Provider): scope = Scope.APP + @provide + def get_event_mapper( + self, + logger: logging.Logger, + k8s_clients: K8sClients, + ) -> PodEventMapper: + return PodEventMapper(logger=logger, k8s_api=k8s_clients.v1) + @provide async def get_pod_monitor( self, kafka_event_service: KafkaEventService, k8s_clients: K8sClients, logger: logging.Logger, + event_mapper: PodEventMapper, ) -> AsyncIterator[PodMonitor]: config = PodMonitorConfig() async with PodMonitor( @@ -684,6 +716,7 @@ async def get_pod_monitor( kafka_event_service=kafka_event_service, logger=logger, k8s_clients=k8s_clients, + event_mapper=event_mapper, ) as monitor: yield monitor @@ -705,12 +738,14 @@ def get_event_replay_service( replay_repository: ReplayRepository, kafka_producer: UnifiedProducer, event_store: EventStore, + settings: Settings, logger: logging.Logger, ) -> EventReplayService: return EventReplayService( repository=replay_repository, producer=kafka_producer, event_store=event_store, + settings=settings, logger=logger, ) diff --git a/backend/app/core/security.py b/backend/app/core/security.py index 1e6e0277..01b86d78 100644 --- a/backend/app/core/security.py +++ b/backend/app/core/security.py @@ -8,7 +8,7 @@ from app.domain.user import AuthenticationRequiredError, CSRFValidationError, InvalidCredentialsError from app.domain.user import User as DomainAdminUser -from app.settings import get_settings +from app.settings import Settings oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/login") @@ -21,9 +21,13 @@ def get_token_from_cookie(request: Request) -> str: class SecurityService: - def __init__(self) -> None: - self.pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") - self.settings = get_settings() + def __init__(self, settings: Settings) -> None: + self.settings = settings + self.pwd_context = CryptContext( + schemes=["bcrypt"], + deprecated="auto", + bcrypt__rounds=self.settings.BCRYPT_ROUNDS, + ) def verify_password(self, plain_password: str, hashed_password: str) -> bool: return self.pwd_context.verify(plain_password, hashed_password) # type: ignore @@ -70,38 +74,49 @@ def validate_csrf_token(self, header_token: str, cookie_token: str) -> bool: return hmac.compare_digest(header_token, cookie_token) + # Paths exempt from CSRF validation (auth handles its own security) + CSRF_EXEMPT_PATHS: frozenset[str] = frozenset({ + "/api/v1/auth/login", + "/api/v1/auth/register", + "/api/v1/auth/logout", + }) -security_service = SecurityService() + def validate_csrf_from_request(self, request: Request) -> str: + """Validate CSRF token from HTTP request using double-submit cookie pattern. + Returns: + "skip" if validation was skipped (safe method, exempt path, or unauthenticated) + The validated token string if validation passed -def validate_csrf_token(request: Request) -> str: - """FastAPI dependency to validate CSRF token using double-submit cookie pattern""" - # Skip CSRF validation for safe methods - if request.method in ["GET", "HEAD", "OPTIONS"]: - return "skip" + Raises: + CSRFValidationError: If token is missing or invalid + """ + # Skip CSRF validation for safe methods + if request.method in ("GET", "HEAD", "OPTIONS"): + return "skip" - # Skip CSRF validation for auth endpoints - if request.url.path in ["/api/v1/login", "/api/v1/register", "/api/v1/logout"]: - return "skip" + # Skip CSRF validation for auth endpoints + if request.url.path in self.CSRF_EXEMPT_PATHS: + return "skip" - # Skip CSRF validation for non-API endpoints - if not request.url.path.startswith("/api/"): - return "skip" + # Skip CSRF validation for non-API endpoints (health, metrics, etc.) + if not request.url.path.startswith("/api/"): + return "skip" - # Check if user is authenticated first (has access_token cookie) - access_token = request.cookies.get("access_token") - if not access_token: - # If not authenticated, skip CSRF validation (auth will be handled by other dependencies) - return "skip" + # Check if user is authenticated first (has access_token cookie) + access_token = request.cookies.get("access_token") + if not access_token: + # If not authenticated, skip CSRF validation (auth will be handled by other dependencies) + return "skip" - # Get CSRF token from header and cookie - header_token = request.headers.get("X-CSRF-Token") - cookie_token = request.cookies.get("csrf_token", "") + # Get CSRF token from header and cookie + header_token = request.headers.get("X-CSRF-Token") + cookie_token = request.cookies.get("csrf_token", "") - if not header_token: - raise CSRFValidationError("CSRF token missing") + if not header_token: + raise CSRFValidationError("CSRF token missing from X-CSRF-Token header") - if not security_service.validate_csrf_token(header_token, cookie_token): - raise CSRFValidationError("CSRF token invalid") + if not self.validate_csrf_token(header_token, cookie_token): + raise CSRFValidationError("CSRF token invalid or does not match cookie") - return header_token + return header_token diff --git a/backend/app/core/tracing/config.py b/backend/app/core/tracing/config.py index 4eeb74c2..5aa61fa3 100644 --- a/backend/app/core/tracing/config.py +++ b/backend/app/core/tracing/config.py @@ -21,7 +21,7 @@ InstrumentationStatus, LibraryInstrumentation, ) -from app.settings import get_settings +from app.settings import Settings class TracingConfiguration: @@ -30,6 +30,7 @@ class TracingConfiguration: def __init__( self, service_name: str, + settings: Settings, service_version: str = "1.0.0", otlp_endpoint: str | None = None, enable_console_exporter: bool = False, @@ -42,7 +43,7 @@ def __init__( self.enable_console_exporter = enable_console_exporter self.sampling_rate = sampling_rate self.adaptive_sampling = adaptive_sampling - self._settings = get_settings() + self._settings = settings def create_resource(self) -> Resource: """Create OpenTelemetry resource with service metadata.""" @@ -59,7 +60,7 @@ def create_resource(self) -> Resource: def create_sampler(self) -> Sampler: """Create appropriate sampler based on configuration.""" if self.adaptive_sampling: - return create_adaptive_sampler() + return create_adaptive_sampler(self._settings) if self.sampling_rate <= 0: return ALWAYS_OFF @@ -176,6 +177,7 @@ def _instrument_library(self, lib: LibraryInstrumentation) -> InstrumentationRes def init_tracing( service_name: str, + settings: Settings, logger: logging.Logger, service_version: str = "1.0.0", otlp_endpoint: str | None = None, @@ -186,6 +188,7 @@ def init_tracing( """Initialize OpenTelemetry tracing with the given configuration.""" config = TracingConfiguration( service_name=service_name, + settings=settings, service_version=service_version, otlp_endpoint=otlp_endpoint, enable_console_exporter=enable_console_exporter, diff --git a/backend/app/db/repositories/admin/admin_user_repository.py b/backend/app/db/repositories/admin/admin_user_repository.py index f7aed21a..400d84a6 100644 --- a/backend/app/db/repositories/admin/admin_user_repository.py +++ b/backend/app/db/repositories/admin/admin_user_repository.py @@ -20,8 +20,8 @@ class AdminUserRepository: - def __init__(self) -> None: - self.security_service = SecurityService() + def __init__(self, security_service: SecurityService) -> None: + self.security_service = security_service async def create_user(self, create_data: DomainUserCreate) -> User: doc = UserDocument(**asdict(create_data)) @@ -29,7 +29,7 @@ async def create_user(self, create_data: DomainUserCreate) -> User: return User(**doc.model_dump(exclude={"id", "revision_id"})) async def list_users( - self, limit: int = 100, offset: int = 0, search: str | None = None, role: UserRole | None = None + self, limit: int = 100, offset: int = 0, search: str | None = None, role: UserRole | None = None ) -> UserListResult: conditions: list[BaseFindOperator] = [] diff --git a/backend/app/db/repositories/dlq_repository.py b/backend/app/db/repositories/dlq_repository.py index 5ab12674..af30ee45 100644 --- a/backend/app/db/repositories/dlq_repository.py +++ b/backend/app/db/repositories/dlq_repository.py @@ -7,17 +7,14 @@ from app.db.docs import DLQMessageDocument from app.dlq import ( AgeStatistics, - DLQBatchRetryResult, DLQMessage, DLQMessageListResult, DLQMessageStatus, - DLQRetryResult, DLQStatistics, DLQTopicSummary, EventTypeStatistic, TopicStatistic, ) -from app.dlq.manager import DLQManager from app.domain.enums.events import EventType from app.infrastructure.kafka.mappings import get_event_class_for_type @@ -186,32 +183,3 @@ async def mark_message_discarded(self, event_id: str, reason: str) -> bool: doc.last_updated = now await doc.save() return True - - async def retry_messages_batch(self, event_ids: list[str], dlq_manager: DLQManager) -> DLQBatchRetryResult: - details = [] - successful = 0 - failed = 0 - - for event_id in event_ids: - try: - doc = await DLQMessageDocument.find_one({"event_id": event_id}) - if not doc: - failed += 1 - details.append(DLQRetryResult(event_id=event_id, status="failed", error="Message not found")) - continue - - success = await dlq_manager.retry_message_manually(event_id) - if success: - await self.mark_message_retried(event_id) - successful += 1 - details.append(DLQRetryResult(event_id=event_id, status="success")) - else: - failed += 1 - details.append(DLQRetryResult(event_id=event_id, status="failed", error="Retry failed")) - - except Exception as e: - self.logger.error(f"Error retrying message {event_id}: {e}") - failed += 1 - details.append(DLQRetryResult(event_id=event_id, status="failed", error=str(e))) - - return DLQBatchRetryResult(total=len(event_ids), successful=successful, failed=failed, details=details) diff --git a/backend/app/db/repositories/notification_repository.py b/backend/app/db/repositories/notification_repository.py index ffed3a1a..e7c261f9 100644 --- a/backend/app/db/repositories/notification_repository.py +++ b/backend/app/db/repositories/notification_repository.py @@ -154,10 +154,12 @@ async def cleanup_old_notifications(self, days: int = 30) -> int: # Subscriptions async def get_subscription( self, user_id: str, channel: NotificationChannel - ) -> DomainNotificationSubscription | None: + ) -> DomainNotificationSubscription: + """Get subscription for user/channel, returning default enabled subscription if none exists.""" doc = await NotificationSubscriptionDocument.find_one({"user_id": user_id, "channel": channel}) if not doc: - return None + # Default: enabled=True for new users (consistent with get_all_subscriptions) + return DomainNotificationSubscription(user_id=user_id, channel=channel, enabled=True) return DomainNotificationSubscription(**doc.model_dump(exclude={"id"})) async def upsert_subscription( diff --git a/backend/app/dlq/manager.py b/backend/app/dlq/manager.py index f64ca00f..fdbb729c 100644 --- a/backend/app/dlq/manager.py +++ b/backend/app/dlq/manager.py @@ -13,9 +13,11 @@ from app.core.tracing.utils import extract_trace_context, get_tracer, inject_trace_context from app.db.docs import DLQMessageDocument from app.dlq.models import ( + DLQBatchRetryResult, DLQMessage, DLQMessageStatus, DLQMessageUpdate, + DLQRetryResult, RetryPolicy, RetryStrategy, ) @@ -459,6 +461,35 @@ async def retry_message_manually(self, event_id: str) -> bool: await self._retry_message(message) return True + async def retry_messages_batch(self, event_ids: list[str]) -> DLQBatchRetryResult: + """Retry multiple DLQ messages in batch. + + Args: + event_ids: List of event IDs to retry + + Returns: + Batch result with success/failure counts and details + """ + details: list[DLQRetryResult] = [] + successful = 0 + failed = 0 + + for event_id in event_ids: + try: + success = await self.retry_message_manually(event_id) + if success: + successful += 1 + details.append(DLQRetryResult(event_id=event_id, status="success")) + else: + failed += 1 + details.append(DLQRetryResult(event_id=event_id, status="failed", error="Retry failed")) + except Exception as e: + self.logger.error(f"Error retrying message {event_id}: {e}") + failed += 1 + details.append(DLQRetryResult(event_id=event_id, status="failed", error=str(e))) + + return DLQBatchRetryResult(total=len(event_ids), successful=successful, failed=failed, details=details) + def create_dlq_manager( settings: Settings, diff --git a/backend/app/domain/enums/__init__.py b/backend/app/domain/enums/__init__.py index 50722b89..f37aac67 100644 --- a/backend/app/domain/enums/__init__.py +++ b/backend/app/domain/enums/__init__.py @@ -7,6 +7,7 @@ NotificationStatus, ) from app.domain.enums.saga import SagaState +from app.domain.enums.sse import SSEControlEvent, SSENotificationEvent from app.domain.enums.user import UserRole __all__ = [ @@ -26,6 +27,9 @@ "NotificationStatus", # Saga "SagaState", + # SSE + "SSEControlEvent", + "SSENotificationEvent", # User "UserRole", ] diff --git a/backend/app/domain/enums/sse.py b/backend/app/domain/enums/sse.py new file mode 100644 index 00000000..7c7c1a03 --- /dev/null +++ b/backend/app/domain/enums/sse.py @@ -0,0 +1,21 @@ +from app.core.utils import StringEnum + + +class SSEControlEvent(StringEnum): + """Control events for execution SSE streams (not from Kafka).""" + + CONNECTED = "connected" + SUBSCRIBED = "subscribed" + HEARTBEAT = "heartbeat" + SHUTDOWN = "shutdown" + STATUS = "status" + ERROR = "error" + + +class SSENotificationEvent(StringEnum): + """Event types for notification SSE streams.""" + + CONNECTED = "connected" + SUBSCRIBED = "subscribed" + HEARTBEAT = "heartbeat" + NOTIFICATION = "notification" diff --git a/backend/app/events/admin_utils.py b/backend/app/events/admin_utils.py index a0a50679..ea924ade 100644 --- a/backend/app/events/admin_utils.py +++ b/backend/app/events/admin_utils.py @@ -4,18 +4,17 @@ from confluent_kafka.admin import AdminClient, NewTopic -from app.settings import get_settings +from app.settings import Settings class AdminUtils: """Minimal admin utilities using native AdminClient.""" - def __init__(self, logger: logging.Logger, bootstrap_servers: str | None = None): + def __init__(self, settings: Settings, logger: logging.Logger): self.logger = logger - settings = get_settings() self._admin = AdminClient( { - "bootstrap.servers": bootstrap_servers or settings.KAFKA_BOOTSTRAP_SERVERS, + "bootstrap.servers": settings.KAFKA_BOOTSTRAP_SERVERS, "client.id": "integr8scode-admin", } ) @@ -63,6 +62,6 @@ def get_admin_client(self) -> AdminClient: return self._admin -def create_admin_utils(logger: logging.Logger, bootstrap_servers: str | None = None) -> AdminUtils: +def create_admin_utils(settings: Settings, logger: logging.Logger) -> AdminUtils: """Create admin utilities.""" - return AdminUtils(logger, bootstrap_servers) + return AdminUtils(settings, logger) diff --git a/backend/app/events/consumer_group_monitor.py b/backend/app/events/consumer_group_monitor.py index 25e759cf..3ce95770 100644 --- a/backend/app/events/consumer_group_monitor.py +++ b/backend/app/events/consumer_group_monitor.py @@ -9,7 +9,7 @@ from app.core.utils import StringEnum from app.events.admin_utils import AdminUtils -from app.settings import get_settings +from app.settings import Settings class ConsumerGroupHealth(StringEnum): @@ -75,8 +75,8 @@ class NativeConsumerGroupMonitor: def __init__( self, + settings: Settings, logger: logging.Logger, - bootstrap_servers: str | None = None, client_id: str = "integr8scode-consumer-group-monitor", request_timeout_ms: int = 30000, # Health thresholds @@ -86,10 +86,9 @@ def __init__( min_members_threshold: int = 1, ): self.logger = logger - settings = get_settings() - self.bootstrap_servers = bootstrap_servers or settings.KAFKA_BOOTSTRAP_SERVERS + self.bootstrap_servers = settings.KAFKA_BOOTSTRAP_SERVERS - self.admin_client = AdminUtils(logger=logger, bootstrap_servers=self.bootstrap_servers) + self.admin_client = AdminUtils(settings=settings, logger=logger) # Health thresholds self.max_rebalance_time = max_rebalance_time_seconds @@ -434,6 +433,6 @@ def clear_cache(self) -> None: def create_consumer_group_monitor( - logger: logging.Logger, bootstrap_servers: str | None = None, **kwargs: Any + settings: Settings, logger: logging.Logger, **kwargs: Any ) -> NativeConsumerGroupMonitor: - return NativeConsumerGroupMonitor(logger=logger, bootstrap_servers=bootstrap_servers, **kwargs) + return NativeConsumerGroupMonitor(settings=settings, logger=logger, **kwargs) diff --git a/backend/app/events/core/producer.py b/backend/app/events/core/producer.py index b45858ea..defc65f9 100644 --- a/backend/app/events/core/producer.py +++ b/backend/app/events/core/producer.py @@ -15,7 +15,7 @@ from app.domain.enums.kafka import KafkaTopic from app.events.schema.schema_registry import SchemaRegistryManager from app.infrastructure.kafka.events import BaseEvent -from app.settings import get_settings +from app.settings import Settings from .types import ProducerConfig, ProducerMetrics, ProducerState @@ -29,11 +29,12 @@ class UnifiedProducer(LifecycleEnabled): def __init__( - self, - config: ProducerConfig, - schema_registry_manager: SchemaRegistryManager, - logger: logging.Logger, - stats_callback: StatsCallback | None = None, + self, + config: ProducerConfig, + schema_registry_manager: SchemaRegistryManager, + logger: logging.Logger, + settings: Settings, + stats_callback: StatsCallback | None = None, ): super().__init__() self._config = config @@ -45,8 +46,7 @@ def __init__( self._metrics = ProducerMetrics() self._event_metrics = get_event_metrics() # Singleton for Kafka metrics self._poll_task: asyncio.Task[None] | None = None - # Topic prefix (for tests/local isolation); cached on init - self._topic_prefix = get_settings().KAFKA_TOPIC_PREFIX + self._topic_prefix = settings.KAFKA_TOPIC_PREFIX @property def is_running(self) -> bool: @@ -173,7 +173,7 @@ async def _poll_loop(self) -> None: self.logger.info("Producer poll loop ended") async def produce( - self, event_to_produce: BaseEvent, key: str | None = None, headers: dict[str, str] | None = None + self, event_to_produce: BaseEvent, key: str | None = None, headers: dict[str, str] | None = None ) -> None: """ Produce a message to Kafka. @@ -206,7 +206,7 @@ async def produce( self.logger.debug(f"Message [{event_to_produce}] queued for topic: {topic}") async def send_to_dlq( - self, original_event: BaseEvent, original_topic: str, error: Exception, retry_count: int = 0 + self, original_event: BaseEvent, original_topic: str, error: Exception, retry_count: int = 0 ) -> None: """ Send a failed event to the Dead Letter Queue. diff --git a/backend/app/events/schema/schema_registry.py b/backend/app/events/schema/schema_registry.py index e149a192..fcc73eed 100644 --- a/backend/app/events/schema/schema_registry.py +++ b/backend/app/events/schema/schema_registry.py @@ -1,6 +1,5 @@ import json import logging -import os import struct from functools import lru_cache from typing import Any, Dict, Type, TypeVar @@ -60,7 +59,7 @@ def __init__(self, settings: Settings, logger: logging.Logger, schema_registry_u self.namespace = "com.integr8scode.events" # Optional per-session/worker subject prefix for tests/local isolation # e.g., "test..." -> subjects become "test.x.y.ExecutionRequestedEvent-value" - self.subject_prefix = os.getenv("SCHEMA_SUBJECT_PREFIX", "") + self.subject_prefix = settings.SCHEMA_SUBJECT_PREFIX config = {"url": self.url} if settings.SCHEMA_REGISTRY_AUTH: diff --git a/backend/app/main.py b/backend/app/main.py index 466df23e..52af39bb 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -31,14 +31,30 @@ from app.core.dishka_lifespan import lifespan from app.core.exceptions import configure_exception_handlers from app.core.logging import setup_logger +from app.core.metrics import ( + ConnectionMetrics, + CoordinatorMetrics, + DatabaseMetrics, + DLQMetrics, + EventMetrics, + ExecutionMetrics, + HealthMetrics, + KubernetesMetrics, + NotificationMetrics, + RateLimitMetrics, + ReplayMetrics, + SecurityMetrics, +) +from app.core.metrics.context import MetricsContext from app.core.middlewares import ( CacheControlMiddleware, + CSRFMiddleware, MetricsMiddleware, RateLimitMiddleware, RequestSizeLimitMiddleware, setup_metrics, ) -from app.settings import Settings, get_settings +from app.settings import Settings def create_app(settings: Settings | None = None) -> FastAPI: @@ -47,10 +63,28 @@ def create_app(settings: Settings | None = None) -> FastAPI: Args: settings: Optional pre-configured settings (e.g., TestSettings for testing). - If None, uses get_settings() which reads from .env. + If None, creates Settings() which reads from env vars then .env file. """ - settings = settings or get_settings() + settings = settings or Settings() logger = setup_logger(settings.LOG_LEVEL) + + # Initialize metrics context for all services + MetricsContext.initialize_all( + logger, + connection=ConnectionMetrics(settings), + coordinator=CoordinatorMetrics(settings), + database=DatabaseMetrics(settings), + dlq=DLQMetrics(settings), + event=EventMetrics(settings), + execution=ExecutionMetrics(settings), + health=HealthMetrics(settings), + kubernetes=KubernetesMetrics(settings), + notification=NotificationMetrics(settings), + rate_limit=RateLimitMetrics(settings), + replay=ReplayMetrics(settings), + security=SecurityMetrics(settings), + ) + # Disable OpenAPI/Docs in production for security; health endpoints provide readiness app = FastAPI( title=settings.PROJECT_NAME, @@ -63,11 +97,12 @@ def create_app(settings: Settings | None = None) -> FastAPI: container = create_app_container(settings) setup_dishka(container, app) - setup_metrics(app, logger) + setup_metrics(app, settings, logger) app.add_middleware(MetricsMiddleware) if settings.RATE_LIMIT_ENABLED: app.add_middleware(RateLimitMiddleware) + app.add_middleware(CSRFMiddleware, container=container) app.add_middleware(CorrelationMiddleware) app.add_middleware(RequestSizeLimitMiddleware) app.add_middleware(CacheControlMiddleware) @@ -127,10 +162,8 @@ def create_app(settings: Settings | None = None) -> FastAPI: return app -app = create_app() - if __name__ == "__main__": - settings = get_settings() + settings = Settings() logger = setup_logger(settings.LOG_LEVEL) logger.info( "Starting uvicorn server", @@ -144,7 +177,8 @@ def create_app(settings: Settings | None = None) -> FastAPI: }, ) uvicorn.run( - app, + "app.main:create_app", + factory=True, host=settings.SERVER_HOST, port=settings.SERVER_PORT, ssl_keyfile=settings.SSL_KEYFILE, diff --git a/backend/app/schemas_pydantic/execution.py b/backend/app/schemas_pydantic/execution.py index 843ca75f..12226be0 100644 --- a/backend/app/schemas_pydantic/execution.py +++ b/backend/app/schemas_pydantic/execution.py @@ -9,7 +9,7 @@ from app.domain.enums.events import EventType from app.domain.enums.execution import ExecutionStatus from app.domain.enums.storage import ExecutionErrorType -from app.settings import get_settings +from app.runtime_registry import SUPPORTED_RUNTIMES class ExecutionBase(BaseModel): @@ -74,9 +74,8 @@ class ExecutionRequest(BaseModel): @model_validator(mode="after") def validate_runtime_supported(self) -> "ExecutionRequest": # noqa: D401 - runtimes = get_settings().SUPPORTED_RUNTIMES - if not (lang_info := runtimes.get(self.lang)): - raise ValueError(f"Language '{self.lang}' not supported. Supported: {list(runtimes.keys())}") + if not (lang_info := SUPPORTED_RUNTIMES.get(self.lang)): + raise ValueError(f"Language '{self.lang}' not supported. Supported: {list(SUPPORTED_RUNTIMES.keys())}") if self.lang_version not in lang_info.versions: raise ValueError( f"Version '{self.lang_version}' not supported for {self.lang}. Supported: {lang_info.versions}" diff --git a/backend/app/schemas_pydantic/sse.py b/backend/app/schemas_pydantic/sse.py index 489156db..c420b209 100644 --- a/backend/app/schemas_pydantic/sse.py +++ b/backend/app/schemas_pydantic/sse.py @@ -1,16 +1,14 @@ from datetime import datetime -from typing import Any, Dict, Literal, TypeVar +from typing import Any, Dict, TypeVar from pydantic import BaseModel, Field from app.domain.enums.events import EventType from app.domain.enums.execution import ExecutionStatus from app.domain.enums.notification import NotificationSeverity, NotificationStatus +from app.domain.enums.sse import SSEControlEvent, SSENotificationEvent from app.schemas_pydantic.execution import ExecutionResult, ResourceUsage -# Control event types sent by SSE (not from Kafka) -SSEControlEventType = Literal["connected", "heartbeat", "shutdown", "status", "error"] - # Type variable for generic Redis message parsing T = TypeVar("T", bound=BaseModel) @@ -24,7 +22,7 @@ class SSEExecutionEventData(BaseModel): """ # Always present - identifies the event - event_type: EventType | SSEControlEventType = Field( + event_type: EventType | SSEControlEvent = Field( description="Event type identifier (business event or control event)" ) execution_id: str = Field(description="Execution ID this event relates to") @@ -34,9 +32,6 @@ class SSEExecutionEventData(BaseModel): # Present in business events from Kafka event_id: str | None = Field(default=None, description="Unique event identifier") - type: EventType | SSEControlEventType | None = Field( - default=None, description="Event type (legacy field, same as event_type)" - ) # Control event specific fields connection_id: str | None = Field(default=None, description="SSE connection ID (connected event)") @@ -68,10 +63,6 @@ class RedisSSEMessage(BaseModel): data: Dict[str, Any] = Field(description="Full event data from BaseEvent.model_dump()") -# Control event types for notification SSE stream -SSENotificationControlEventType = Literal["connected", "heartbeat", "notification"] - - class SSENotificationEventData(BaseModel): """Typed model for SSE notification stream event payload. @@ -79,9 +70,7 @@ class SSENotificationEventData(BaseModel): """ # Always present - identifies the event type - event_type: SSENotificationControlEventType = Field( - description="Event type identifier (connected, heartbeat, or notification)" - ) + event_type: SSENotificationEvent = Field(description="SSE notification event type") # Present in control events (connected, heartbeat) user_id: str | None = Field(default=None, description="User ID for the notification stream") diff --git a/backend/app/services/admin/admin_user_service.py b/backend/app/services/admin/admin_user_service.py index 88e2e8e2..d8975de9 100644 --- a/backend/app/services/admin/admin_user_service.py +++ b/backend/app/services/admin/admin_user_service.py @@ -23,12 +23,14 @@ def __init__( event_service: EventService, execution_service: ExecutionService, rate_limit_service: RateLimitService, + security_service: SecurityService, logger: logging.Logger, ) -> None: self._users = user_repository self._events = event_service self._executions = execution_service self._rate_limits = rate_limit_service + self._security = security_service self.logger = logger async def get_user_overview(self, user_id: str, hours: int = 24) -> AdminUserOverviewDomain: @@ -126,8 +128,7 @@ async def create_user(self, *, admin_username: str, user_data: UserCreate) -> Us if user.username == user_data.username: raise ValueError("Username already exists") - security = SecurityService() - hashed_password = security.get_password_hash(user_data.password) + hashed_password = self._security.get_password_hash(user_data.password) create_data = DomainUserCreate( username=user_data.username, diff --git a/backend/app/services/auth_service.py b/backend/app/services/auth_service.py index 50e6f98f..dd628673 100644 --- a/backend/app/services/auth_service.py +++ b/backend/app/services/auth_service.py @@ -2,7 +2,7 @@ from fastapi import Request -from app.core.security import security_service +from app.core.security import SecurityService from app.db.repositories.user_repository import UserRepository from app.domain.enums.user import UserRole from app.domain.user import AdminAccessRequiredError, AuthenticationRequiredError @@ -10,8 +10,9 @@ class AuthService: - def __init__(self, user_repo: UserRepository, logger: logging.Logger): + def __init__(self, user_repo: UserRepository, security_service: SecurityService, logger: logging.Logger): self.user_repo = user_repo + self.security_service = security_service self.logger = logger async def get_current_user(self, request: Request) -> UserResponse: @@ -19,7 +20,7 @@ async def get_current_user(self, request: Request) -> UserResponse: if not token: raise AuthenticationRequiredError() - user = await security_service.get_current_user(token, self.user_repo) + user = await self.security_service.get_current_user(token, self.user_repo) return UserResponse( user_id=user.user_id, diff --git a/backend/app/services/event_bus.py b/backend/app/services/event_bus.py index 93208612..a1ea1107 100644 --- a/backend/app/services/event_bus.py +++ b/backend/app/services/event_bus.py @@ -13,7 +13,7 @@ from app.core.lifecycle import LifecycleEnabled from app.core.metrics.context import get_connection_metrics from app.domain.enums.kafka import KafkaTopic -from app.settings import get_settings +from app.settings import Settings @dataclass @@ -45,10 +45,10 @@ class EventBus(LifecycleEnabled): - *.completed - matches all completed events """ - def __init__(self, logger: logging.Logger) -> None: + def __init__(self, settings: Settings, logger: logging.Logger) -> None: super().__init__() self.logger = logger - self.settings = get_settings() + self.settings = settings self.metrics = get_connection_metrics() self.producer: Optional[Producer] = None self.consumer: Optional[Consumer] = None @@ -323,7 +323,8 @@ async def get_statistics(self) -> dict[str, Any]: class EventBusManager: """Manages EventBus lifecycle as a singleton.""" - def __init__(self, logger: logging.Logger) -> None: + def __init__(self, settings: Settings, logger: logging.Logger) -> None: + self.settings = settings self.logger = logger self._event_bus: Optional[EventBus] = None self._lock = asyncio.Lock() @@ -332,7 +333,7 @@ async def get_event_bus(self) -> EventBus: """Get or create the event bus instance.""" async with self._lock: if self._event_bus is None: - self._event_bus = EventBus(self.logger) + self._event_bus = EventBus(self.settings, self.logger) await self._event_bus.__aenter__() return self._event_bus diff --git a/backend/app/services/event_replay/replay_service.py b/backend/app/services/event_replay/replay_service.py index 837d8147..856cdea6 100644 --- a/backend/app/services/event_replay/replay_service.py +++ b/backend/app/services/event_replay/replay_service.py @@ -17,6 +17,7 @@ from app.events.core import UnifiedProducer from app.events.event_store import EventStore from app.infrastructure.kafka.events.base import BaseEvent +from app.settings import Settings class EventReplayService: @@ -25,6 +26,7 @@ def __init__( repository: ReplayRepository, producer: UnifiedProducer, event_store: EventStore, + settings: Settings, logger: logging.Logger, ) -> None: self._sessions: Dict[str, ReplaySessionState] = {} @@ -35,7 +37,7 @@ def __init__( self.logger = logger self._callbacks: Dict[ReplayTarget, Callable[..., Any]] = {} self._file_locks: Dict[str, asyncio.Lock] = {} - self._metrics = ReplayMetrics() + self._metrics = ReplayMetrics(settings) self.logger.info("Event replay service initialized") async def create_replay_session(self, config: ReplayConfig) -> str: diff --git a/backend/app/services/k8s_worker/config.py b/backend/app/services/k8s_worker/config.py index dad0f6f3..ebc2c953 100644 --- a/backend/app/services/k8s_worker/config.py +++ b/backend/app/services/k8s_worker/config.py @@ -7,7 +7,6 @@ @dataclass class K8sWorkerConfig: # Kafka settings - kafka_bootstrap_servers: str | None = None consumer_group: str = "kubernetes-worker" topics: list[KafkaTopic] = field(default_factory=lambda: [KafkaTopic.EXECUTION_TASKS]) diff --git a/backend/app/services/k8s_worker/worker.py b/backend/app/services/k8s_worker/worker.py index 8bad97c2..177c9e46 100644 --- a/backend/app/services/k8s_worker/worker.py +++ b/backend/app/services/k8s_worker/worker.py @@ -58,12 +58,12 @@ def __init__( ): super().__init__() self.logger = logger - self.metrics = KubernetesMetrics() - self.execution_metrics = ExecutionMetrics() + self.metrics = KubernetesMetrics(settings) + self.execution_metrics = ExecutionMetrics(settings) self.config = config or K8sWorkerConfig() self._settings = settings - self.kafka_servers = self.config.kafka_bootstrap_servers or self._settings.KAFKA_BOOTSTRAP_SERVERS + self.kafka_servers = self._settings.KAFKA_BOOTSTRAP_SERVERS self._event_store = event_store # Kubernetes clients diff --git a/backend/app/services/kafka_event_service.py b/backend/app/services/kafka_event_service.py index 26bdbd21..83fe6d20 100644 --- a/backend/app/services/kafka_event_service.py +++ b/backend/app/services/kafka_event_service.py @@ -17,18 +17,24 @@ from app.infrastructure.kafka.events.base import BaseEvent from app.infrastructure.kafka.events.metadata import AvroEventMetadata from app.infrastructure.kafka.mappings import get_event_class_for_type -from app.settings import get_settings +from app.settings import Settings tracer = trace.get_tracer(__name__) class KafkaEventService: - def __init__(self, event_repository: EventRepository, kafka_producer: UnifiedProducer, logger: logging.Logger): + def __init__( + self, + event_repository: EventRepository, + kafka_producer: UnifiedProducer, + settings: Settings, + logger: logging.Logger, + ): self.event_repository = event_repository self.kafka_producer = kafka_producer self.logger = logger self.metrics = get_event_metrics() - self.settings = get_settings() + self.settings = settings async def publish_event( self, diff --git a/backend/app/services/notification_service.py b/backend/app/services/notification_service.py index 084de72c..4d709b31 100644 --- a/backend/app/services/notification_service.py +++ b/backend/app/services/notification_service.py @@ -805,13 +805,13 @@ async def _publish_notification_sse(self, notification: DomainNotification) -> N await self.sse_bus.publish_notification(notification.user_id, message) async def _should_skip_notification( - self, notification: DomainNotification, subscription: DomainNotificationSubscription | None + self, notification: DomainNotification, subscription: DomainNotificationSubscription ) -> str | None: """Check if notification should be skipped based on subscription filters. Returns skip reason if should skip, None otherwise. """ - if not subscription or not subscription.enabled: + if not subscription.enabled: return f"User {notification.user_id} has {notification.channel} disabled; skipping delivery." if subscription.severities and notification.severity not in subscription.severities: @@ -860,9 +860,6 @@ async def _deliver_notification(self, notification: DomainNotification) -> None: ) return - # At this point, subscription is guaranteed to be non-None (checked in _should_skip_notification) - assert subscription is not None - # Send through channel start_time = asyncio.get_running_loop().time() try: diff --git a/backend/app/services/pod_monitor/config.py b/backend/app/services/pod_monitor/config.py index 686ad406..44159037 100644 --- a/backend/app/services/pod_monitor/config.py +++ b/backend/app/services/pod_monitor/config.py @@ -11,7 +11,6 @@ class PodMonitorConfig: """Configuration for PodMonitor service""" # Kafka settings - kafka_bootstrap_servers: str | None = None pod_events_topic: str = str(get_topic_for_event(EventType.POD_CREATED)) execution_events_topic: str = str(get_topic_for_event(EventType.EXECUTION_REQUESTED)) execution_completed_topic: str = str(get_topic_for_event(EventType.EXECUTION_COMPLETED)) diff --git a/backend/app/services/pod_monitor/monitor.py b/backend/app/services/pod_monitor/monitor.py index bdc61583..095d8449 100644 --- a/backend/app/services/pod_monitor/monitor.py +++ b/backend/app/services/pod_monitor/monitor.py @@ -8,11 +8,9 @@ from typing import Any from kubernetes import client as k8s_client -from kubernetes import config as k8s_config -from kubernetes import watch from kubernetes.client.rest import ApiException -from app.core.k8s_clients import K8sClients +from app.core.k8s_clients import K8sClients, close_k8s_clients, create_k8s_clients from app.core.lifecycle import LifecycleEnabled from app.core.metrics.context import get_kubernetes_metrics from app.core.utils import StringEnum @@ -104,20 +102,25 @@ def __init__( config: PodMonitorConfig, kafka_event_service: KafkaEventService, logger: logging.Logger, - k8s_clients: K8sClients | None = None, + k8s_clients: K8sClients, + event_mapper: PodEventMapper, ) -> None: - """Initialize the pod monitor.""" + """Initialize the pod monitor with all required dependencies. + + All dependencies must be provided - use create_pod_monitor() factory + for automatic dependency creation in production. + """ super().__init__() self.logger = logger - self.config = config or PodMonitorConfig() + self.config = config - # Kubernetes clients (initialized on start) - self._v1: k8s_client.CoreV1Api | None = None - self._watch: watch.Watch | None = None - self._clients: K8sClients | None = k8s_clients + # Kubernetes clients (required, no nullability) + self._clients = k8s_clients + self._v1 = k8s_clients.v1 + self._watch = k8s_clients.watch - # Components - self._event_mapper = PodEventMapper(logger=self.logger) + # Components (required, no nullability) + self._event_mapper = event_mapper self._kafka_event_service = kafka_event_service # State @@ -142,8 +145,9 @@ async def _on_start(self) -> None: """Start the pod monitor.""" self.logger.info("Starting PodMonitor service...") - # Initialize components - self._initialize_kubernetes_client() + # Verify K8s connectivity (all clients already injected via __init__) + await asyncio.to_thread(self._v1.get_api_resources) + self.logger.info("Successfully connected to Kubernetes API") # Start monitoring self._state = MonitorState.RUNNING @@ -180,34 +184,6 @@ async def _on_stop(self) -> None: self._state = MonitorState.STOPPED self.logger.info("PodMonitor service stopped") - def _initialize_kubernetes_client(self) -> None: - """Initialize Kubernetes API clients.""" - if self._clients is None: - match (self.config.in_cluster, self.config.kubeconfig_path): - case (True, _): - self.logger.info("Using in-cluster Kubernetes configuration") - k8s_config.load_incluster_config() - case (False, path) if path: - self.logger.info(f"Using kubeconfig from {path}") - k8s_config.load_kube_config(config_file=path) - case _: - self.logger.info("Using default kubeconfig") - k8s_config.load_kube_config() - - configuration = k8s_client.Configuration.get_default_copy() - self.logger.info(f"Kubernetes API host: {configuration.host}") - self.logger.info(f"SSL CA cert configured: {configuration.ssl_ca_cert is not None}") - - api_client = k8s_client.ApiClient(configuration) - self._v1 = k8s_client.CoreV1Api(api_client) - else: - self._v1 = self._clients.v1 - - self._watch = watch.Watch() - self._v1.get_api_resources() - self.logger.info("Successfully connected to Kubernetes API") - self._event_mapper = PodEventMapper(logger=self.logger, k8s_api=self._v1) - async def _watch_pods(self) -> None: """Main watch loop for pods.""" while self._state == MonitorState.RUNNING: @@ -259,10 +235,7 @@ async def _watch_pod_events(self) -> None: if context.resource_version: kwargs["resource_version"] = context.resource_version - # Watch stream - if not self._watch or not self._v1: - raise RuntimeError("Watch or API not initialized") - + # Watch stream (clients guaranteed by __init__) stream = self._watch.stream(self._v1.list_namespaced_pod, **kwargs) try: @@ -405,24 +378,12 @@ async def _reconciliation_loop(self) -> None: async def _reconcile_state(self) -> ReconciliationResult: """Reconcile tracked pods with actual state.""" - # self._v1 is guaranteed initialized by start() - start_time = time.time() try: self.logger.info("Starting pod state reconciliation") - # List all pods matching selector - if not self._v1: - self.logger.warning("K8s API not initialized, skipping reconciliation") - return ReconciliationResult( - missing_pods=set(), - extra_pods=set(), - duration_seconds=time.time() - start_time, - success=False, - error="K8s API not initialized", - ) - + # List all pods matching selector (clients guaranteed by __init__) pods = await asyncio.to_thread( self._v1.list_namespaced_pod, namespace=self.config.namespace, label_selector=self.config.label_selector ) @@ -502,14 +463,39 @@ async def create_pod_monitor( kafka_event_service: KafkaEventService, logger: logging.Logger, k8s_clients: K8sClients | None = None, + event_mapper: PodEventMapper | None = None, ) -> AsyncIterator[PodMonitor]: - """Create and manage a pod monitor instance.""" + """Create and manage a pod monitor instance. + + This factory handles production dependency creation: + - Creates K8sClients if not provided (using config settings) + - Creates PodEventMapper if not provided + - Cleans up created K8sClients on exit + """ + # Track whether we created clients (so we know to close them) + owns_clients = k8s_clients is None + + if k8s_clients is None: + k8s_clients = create_k8s_clients( + logger=logger, + kubeconfig_path=config.kubeconfig_path, + in_cluster=config.in_cluster, + ) + + if event_mapper is None: + event_mapper = PodEventMapper(logger=logger, k8s_api=k8s_clients.v1) + monitor = PodMonitor( config=config, kafka_event_service=kafka_event_service, logger=logger, k8s_clients=k8s_clients, + event_mapper=event_mapper, ) - async with monitor: - yield monitor + try: + async with monitor: + yield monitor + finally: + if owns_clients: + close_k8s_clients(k8s_clients) diff --git a/backend/app/services/sse/kafka_redis_bridge.py b/backend/app/services/sse/kafka_redis_bridge.py index 478e3420..43473556 100644 --- a/backend/app/services/sse/kafka_redis_bridge.py +++ b/backend/app/services/sse/kafka_redis_bridge.py @@ -1,7 +1,6 @@ from __future__ import annotations import logging -import os from app.core.lifecycle import LifecycleEnabled from app.core.metrics.events import EventMetrics @@ -62,13 +61,9 @@ async def _on_stop(self) -> None: self.logger.info("SSE Kafka→Redis bridge stopped") async def _create_consumer(self, consumer_index: int) -> UnifiedConsumer: - suffix = os.environ.get("KAFKA_GROUP_SUFFIX", "") - group_id = "sse-bridge-pool" - if suffix: - group_id = f"{group_id}.{suffix}" - client_id = f"sse-bridge-{consumer_index}" - if suffix: - client_id = f"{client_id}-{suffix}" + suffix = self.settings.KAFKA_GROUP_SUFFIX + group_id = f"sse-bridge-pool.{suffix}" + client_id = f"sse-bridge-{consumer_index}.{suffix}" config = ConsumerConfig( bootstrap_servers=self.settings.KAFKA_BOOTSTRAP_SERVERS, diff --git a/backend/app/services/sse/sse_service.py b/backend/app/services/sse/sse_service.py index 3608ec2e..19055a7b 100644 --- a/backend/app/services/sse/sse_service.py +++ b/backend/app/services/sse/sse_service.py @@ -7,6 +7,7 @@ from app.core.metrics.context import get_connection_metrics from app.db.repositories.sse_repository import SSERepository from app.domain.enums.events import EventType +from app.domain.enums.sse import SSEControlEvent, SSENotificationEvent from app.domain.sse import SSEHealthDomain from app.schemas_pydantic.execution import ExecutionResult from app.schemas_pydantic.sse import ( @@ -55,7 +56,7 @@ async def create_execution_stream(self, execution_id: str, user_id: str) -> Asyn if shutdown_event is None: yield self._format_sse_event( SSEExecutionEventData( - event_type="error", + event_type=SSEControlEvent.ERROR, execution_id=execution_id, timestamp=datetime.now(timezone.utc).isoformat(), error="Server is shutting down", @@ -69,7 +70,7 @@ async def create_execution_stream(self, execution_id: str, user_id: str) -> Asyn sub_task = asyncio.create_task(self.sse_bus.open_subscription(execution_id)) yield self._format_sse_event( SSEExecutionEventData( - event_type="connected", + event_type=SSEControlEvent.CONNECTED, execution_id=execution_id, timestamp=datetime.now(timezone.utc).isoformat(), connection_id=connection_id, @@ -81,11 +82,21 @@ async def create_execution_stream(self, execution_id: str, user_id: str) -> Asyn subscription = await sub_task self.logger.info("Redis subscription opened for execution", extra={"execution_id": execution_id}) + # Signal that subscription is ready - safe to publish events now + yield self._format_sse_event( + SSEExecutionEventData( + event_type=SSEControlEvent.SUBSCRIBED, + execution_id=execution_id, + timestamp=datetime.now(timezone.utc).isoformat(), + message="Redis subscription established", + ) + ) + initial_status = await self.repository.get_execution_status(execution_id) if initial_status: yield self._format_sse_event( SSEExecutionEventData( - event_type="status", + event_type=SSEControlEvent.STATUS, execution_id=initial_status.execution_id, timestamp=initial_status.timestamp, status=initial_status.status, @@ -119,7 +130,7 @@ async def _stream_events_redis( if shutdown_event.is_set(): yield self._format_sse_event( SSEExecutionEventData( - event_type="shutdown", + event_type=SSEControlEvent.SHUTDOWN, execution_id=execution_id, timestamp=datetime.now(timezone.utc).isoformat(), message="Server is shutting down", @@ -132,7 +143,7 @@ async def _stream_events_redis( if include_heartbeat and (now - last_heartbeat).total_seconds() >= self.heartbeat_interval: yield self._format_sse_event( SSEExecutionEventData( - event_type="heartbeat", + event_type=SSEControlEvent.HEARTBEAT, execution_id=execution_id, timestamp=now.isoformat(), message="SSE connection active", @@ -179,7 +190,6 @@ async def _build_sse_event_from_redis(self, execution_id: str, msg: RedisSSEMess **msg.data, "event_type": msg.event_type, "execution_id": execution_id, - "type": msg.event_type, "result": result, } ) @@ -192,7 +202,7 @@ async def create_notification_stream(self, user_id: str) -> AsyncGenerator[Dict[ sub_task = asyncio.create_task(self.sse_bus.open_notification_subscription(user_id)) yield self._format_notification_event( SSENotificationEventData( - event_type="connected", + event_type=SSENotificationEvent.CONNECTED, user_id=user_id, timestamp=datetime.now(timezone.utc).isoformat(), message="Connected to notification stream", @@ -202,6 +212,16 @@ async def create_notification_stream(self, user_id: str) -> AsyncGenerator[Dict[ # Complete Redis subscription after handshake subscription = await sub_task + # Signal that subscription is ready - safe to publish notifications now + yield self._format_notification_event( + SSENotificationEventData( + event_type=SSENotificationEvent.SUBSCRIBED, + user_id=user_id, + timestamp=datetime.now(timezone.utc).isoformat(), + message="Redis subscription established", + ) + ) + last_heartbeat = datetime.now(timezone.utc) while not self.shutdown_manager.is_shutting_down(): # Heartbeat @@ -209,7 +229,7 @@ async def create_notification_stream(self, user_id: str) -> AsyncGenerator[Dict[ if (now - last_heartbeat).total_seconds() >= self.heartbeat_interval: yield self._format_notification_event( SSENotificationEventData( - event_type="heartbeat", + event_type=SSENotificationEvent.HEARTBEAT, user_id=user_id, timestamp=now.isoformat(), message="Notification stream active", @@ -222,7 +242,7 @@ async def create_notification_stream(self, user_id: str) -> AsyncGenerator[Dict[ if redis_msg: yield self._format_notification_event( SSENotificationEventData( - event_type="notification", + event_type=SSENotificationEvent.NOTIFICATION, notification_id=redis_msg.notification_id, severity=redis_msg.severity, status=redis_msg.status, diff --git a/backend/app/services/sse/sse_shutdown_manager.py b/backend/app/services/sse/sse_shutdown_manager.py index dcc74873..1e29b60a 100644 --- a/backend/app/services/sse/sse_shutdown_manager.py +++ b/backend/app/services/sse/sse_shutdown_manager.py @@ -4,9 +4,9 @@ from enum import Enum from typing import Dict, Set +from app.core.lifecycle import LifecycleEnabled from app.core.metrics.context import get_connection_metrics from app.domain.sse import ShutdownStatus -from app.services.sse.kafka_redis_bridge import SSEKafkaRedisBridge class ShutdownPhase(Enum): @@ -57,7 +57,7 @@ def __init__( self._draining_connections: Set[str] = set() # Router reference (set during initialization) - self._router: SSEKafkaRedisBridge | None = None + self._router: LifecycleEnabled | None = None # Synchronization self._lock = asyncio.Lock() @@ -69,7 +69,7 @@ def __init__( extra={"drain_timeout": drain_timeout, "notification_timeout": notification_timeout}, ) - def set_router(self, router: "SSEKafkaRedisBridge") -> None: + def set_router(self, router: LifecycleEnabled) -> None: """Set the router reference for shutdown coordination.""" self._router = router diff --git a/backend/app/settings.py b/backend/app/settings.py index 6e80b55f..89bcad33 100644 --- a/backend/app/settings.py +++ b/backend/app/settings.py @@ -1,4 +1,4 @@ -from functools import lru_cache +import os from pydantic import Field from pydantic_settings import BaseSettings, SettingsConfigDict @@ -44,6 +44,9 @@ class Settings(BaseSettings): TESTING: bool = False + # Security: bcrypt rounds (lower in tests for speed, higher in production for security) + BCRYPT_ROUNDS: int = 12 + # Event-Driven Design Configuration KAFKA_BOOTSTRAP_SERVERS: str = "kafka:29092" KAFKA_GROUP_SUFFIX: str = "suff" # Suffix to append to consumer group IDs for test/parallel isolation @@ -73,6 +76,7 @@ class Settings(BaseSettings): SCHEMA_BASE_PATH: str = "app/schemas_avro" SCHEMA_AVRO_PATH: str = "app/schemas_avro" SCHEMA_CONFIG_PATH: str | None = None + SCHEMA_SUBJECT_PREFIX: str = "" # OpenTelemetry / Jaeger Configuration ENABLE_TRACING: bool = True @@ -126,6 +130,7 @@ class Settings(BaseSettings): # Service metadata SERVICE_NAME: str = "integr8scode-backend" SERVICE_VERSION: str = "1.0.0" + ENVIRONMENT: str = "production" # deployment environment (production, staging, development) # OpenTelemetry Configuration OTEL_EXPORTER_OTLP_ENDPOINT: str | None = None @@ -153,13 +158,8 @@ class Settings(BaseSettings): LOG_LEVEL: str = Field(default="DEBUG", description="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)") model_config = SettingsConfigDict( - env_file=".env", + env_file=os.environ.get("DOTENV_PATH", ".env"), env_file_encoding="utf-8", case_sensitive=True, extra="forbid", # Raise error on extra fields ) - - -@lru_cache(maxsize=1) -def get_settings() -> Settings: - return Settings() # type: ignore[call-arg] diff --git a/backend/pyproject.toml b/backend/pyproject.toml index b8a3b5ec..6726f5db 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -141,10 +141,10 @@ dev = [ "mypy_extensions==1.1.0", "pipdeptree==2.23.4", "pluggy==1.5.0", - "pytest==8.3.3", + "pytest==8.4.2", "pytest-asyncio==1.3.0", "pytest-cov==5.0.0", - "pytest-env>=1.1.5", + "pytest-env==1.2.0", "pytest-xdist==3.6.1", "ruff==0.14.10", "types-cachetools==6.2.0.20250827", @@ -182,8 +182,12 @@ warn_unused_configs = true disallow_untyped_defs = true disallow_incomplete_defs = true disable_error_code = ["import-untyped", "import-not-found"] -# TODO: REMOVE NEXT LINE -exclude = '(^tests/|/tests/)' +plugins = ["pydantic.mypy"] + +[tool.pydantic-mypy] +init_forbid_extra = true +init_typed = true +warn_required_dynamic_aliases = true # Pytest configuration [tool.pytest.ini_options] @@ -209,8 +213,11 @@ asyncio_default_test_loop_scope = "session" log_cli = false log_cli_level = "ERROR" log_level = "ERROR" -addopts = "-n 4 --dist loadfile --tb=short -q --no-header -q" -env = ["OTEL_SDK_DISABLED=true"] +addopts = "--tb=short -n auto --dist=loadfile" + +# pytest-env: Set DOTENV_PATH so Settings loads .env.test instead of .env +[tool.pytest_env] +DOTENV_PATH = ".env.test" # Coverage configuration [tool.coverage.run] diff --git a/backend/scripts/__init__.py b/backend/scripts/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/scripts/create_topics.py b/backend/scripts/create_topics.py index 0620a010..6b473e25 100755 --- a/backend/scripts/create_topics.py +++ b/backend/scripts/create_topics.py @@ -10,16 +10,15 @@ from app.core.logging import setup_logger from app.infrastructure.kafka.topics import get_all_topics, get_topic_configs -from app.settings import get_settings +from app.settings import Settings from confluent_kafka import KafkaException from confluent_kafka.admin import AdminClient, NewTopic logger = setup_logger(os.environ.get("LOG_LEVEL", "INFO")) -async def create_topics() -> None: - """Create all required Kafka topics""" - settings = get_settings() +async def create_topics(settings: Settings) -> None: + """Create all required Kafka topics using provided settings.""" # Create admin client admin_client = AdminClient( @@ -103,11 +102,11 @@ async def create_topics() -> None: async def main() -> None: - """Main entry point""" + """Main entry point - creates Settings() which reads from env vars then .env file.""" logger.info("Starting Kafka topic creation...") try: - await create_topics() + await create_topics(Settings()) logger.info("Topic creation completed successfully") except Exception as e: logger.error(f"Topic creation failed: {e}") diff --git a/backend/scripts/seed_users.py b/backend/scripts/seed_users.py index a8450954..f8e33422 100755 --- a/backend/scripts/seed_users.py +++ b/backend/scripts/seed_users.py @@ -6,8 +6,7 @@ 1. Default user (role=user) for testing/demo 2. Admin user (role=admin, is_superuser=True) for administration -Environment Variables: - MONGODB_URL: Connection string (default: mongodb://mongo:27017/integr8scode) +Uses main Settings for MongoDB connection. Password env vars are script-specific: DEFAULT_USER_PASSWORD: Default user password (default: user123) ADMIN_USER_PASSWORD: Admin user password (default: admin123) """ @@ -17,6 +16,7 @@ from datetime import datetime, timezone from typing import Any +from app.settings import Settings from bson import ObjectId from passlib.context import CryptContext from pymongo.asynchronous.database import AsyncDatabase @@ -68,21 +68,22 @@ async def upsert_user( ) -async def seed_users() -> None: - mongodb_url = os.getenv("MONGODB_URL", "mongodb://mongo:27017/integr8scode") - db_name = os.getenv("DATABASE_NAME", "integr8scode_db") +async def seed_users(settings: Settings) -> None: + """Seed default users using provided settings for MongoDB connection.""" + default_password = os.environ.get("DEFAULT_USER_PASSWORD", "user123") + admin_password = os.environ.get("ADMIN_USER_PASSWORD", "admin123") - print(f"Connecting to MongoDB (database: {db_name})...") + print(f"Connecting to MongoDB (database: {settings.DATABASE_NAME})...") - client: AsyncMongoClient[dict[str, Any]] = AsyncMongoClient(mongodb_url) - db = client[db_name] + client: AsyncMongoClient[dict[str, Any]] = AsyncMongoClient(settings.MONGODB_URL) + db = client[settings.DATABASE_NAME] # Default user await upsert_user( db, username="user", email="user@integr8scode.com", - password=os.getenv("DEFAULT_USER_PASSWORD", "user123"), + password=default_password, role="user", is_superuser=False, ) @@ -92,7 +93,7 @@ async def seed_users() -> None: db, username="admin", email="admin@integr8scode.com", - password=os.getenv("ADMIN_USER_PASSWORD", "admin123"), + password=admin_password, role="admin", is_superuser=True, ) @@ -108,4 +109,4 @@ async def seed_users() -> None: if __name__ == "__main__": - asyncio.run(seed_users()) + asyncio.run(seed_users(Settings())) diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 2b1b00a1..7b7c1a61 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -1,3 +1,4 @@ +import asyncio import os import uuid from contextlib import asynccontextmanager @@ -11,199 +12,216 @@ from app.main import create_app from app.settings import Settings from dishka import AsyncContainer +from fastapi import FastAPI from httpx import ASGITransport -from pydantic_settings import SettingsConfigDict - - -class TestSettings(Settings): - """Test configuration - loads from .env.test instead of .env""" - - model_config = SettingsConfigDict( - env_file=".env.test", - env_file_encoding="utf-8", - case_sensitive=True, - extra="ignore", - ) - +from scripts.create_topics import create_topics # ===== Worker-specific isolation for pytest-xdist ===== -def _compute_worker_id() -> str: - return os.environ.get("PYTEST_XDIST_WORKER", "gw0") - - -def _setup_worker_env() -> None: - """Set worker-specific environment variables for pytest-xdist isolation. - - Must be called BEFORE TestSettings is instantiated so env vars are picked up. - """ - session_id = os.environ.get("PYTEST_SESSION_ID") or uuid.uuid4().hex[:8] - worker_id = _compute_worker_id() - os.environ["PYTEST_SESSION_ID"] = session_id - - # Unique database name per worker - os.environ["DATABASE_NAME"] = f"integr8scode_test_{session_id}_{worker_id}" - - # Distribute Redis DBs across workers (0-15) - try: - worker_num = int(worker_id[2:]) if worker_id.startswith("gw") else 0 - os.environ["REDIS_DB"] = str(worker_num % 16) - except Exception: - os.environ.setdefault("REDIS_DB", "0") +# Redis has 16 DBs (0-15); each xdist worker gets one, limiting parallel workers to 16. +_WORKER_ID = os.environ.get("PYTEST_XDIST_WORKER", "gw0") +_WORKER_NUM = int(_WORKER_ID.removeprefix("gw") or "0") +assert _WORKER_NUM < 16, f"xdist worker {_WORKER_NUM} >= 16 exceeds Redis DB limit; use -n 16 or fewer" - # Unique Kafka consumer group per worker - os.environ["KAFKA_GROUP_SUFFIX"] = f"{session_id}.{worker_id}" - # Unique Schema Registry prefix per worker - os.environ["SCHEMA_SUBJECT_PREFIX"] = f"test.{session_id}.{worker_id}." +# ===== Pytest hooks ===== +@pytest.hookimpl(trylast=True) +def pytest_configure(config: pytest.Config) -> None: + """Create Kafka topics once before any tests run. - # Disable OpenTelemetry exporters to prevent "otel-collector:4317" retry noise - os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = "" - os.environ["OTEL_METRICS_EXPORTER"] = "none" - os.environ["OTEL_TRACES_EXPORTER"] = "none" - os.environ["OTEL_LOGS_EXPORTER"] = "none" + Uses trylast=True to ensure pytest-env has set DOTENV_PATH first. + Runs in master process before xdist workers spawn. - -# Set up worker env at module load time (before any Settings instantiation) -_setup_worker_env() + Silently skips if Kafka is unavailable (e.g., unit tests). + """ + # Only run in master process (not in xdist workers) + if not hasattr(config, "workerinput"): + try: + asyncio.run(create_topics(Settings())) + except Exception: + # Kafka not available (unit tests) - silently skip + pass # ===== Settings fixture ===== @pytest.fixture(scope="session") def test_settings() -> Settings: - """Provide TestSettings for tests that need to create their own components.""" - return TestSettings() + """Provide test settings with per-worker isolation where needed. + + pytest-env sets DOTENV_PATH=.env.test (configured in pyproject.toml). + Settings class uses this to load the correct env file via pydantic-settings. + + What gets isolated per worker (to prevent interference): + - DATABASE_NAME: Each worker gets its own MongoDB database + - REDIS_DB: Each worker gets its own Redis database (0-15) + - KAFKA_GROUP_SUFFIX: Each worker gets unique consumer groups + + What's SHARED (from env, no per-worker suffix): + - KAFKA_TOPIC_PREFIX: Topics created once by CI/scripts + - SCHEMA_SUBJECT_PREFIX: Schemas shared across workers + """ + base = Settings() # Uses DOTENV_PATH from pytest-env to load .env.test + session_id = uuid.uuid4().hex[:8] + return base.model_copy( + update={ + # Per-worker isolation + "DATABASE_NAME": f"integr8scode_test_{session_id}_{_WORKER_ID}", + "REDIS_DB": _WORKER_NUM, + "KAFKA_GROUP_SUFFIX": f"{session_id}.{_WORKER_ID}", + # Disable telemetry in tests + "OTEL_EXPORTER_OTLP_ENDPOINT": None, + "ENABLE_TRACING": False, + } + ) # ===== App fixture ===== @pytest_asyncio.fixture(scope="session") -async def app(): - """Create FastAPI app with TestSettings. +async def app(test_settings: Settings) -> AsyncGenerator[FastAPI, None]: + """Create FastAPI app with test settings and run lifespan. Session-scoped to avoid Pydantic schema validator memory issues when FastAPI recreates OpenAPI schemas hundreds of times with pytest-xdist. - """ - application = create_app(settings=TestSettings()) - yield application + Uses lifespan_context to trigger startup/shutdown events, which initializes + Beanie, metrics, and other services through the normal DI flow. + """ + application = create_app(settings=test_settings) - if hasattr(application.state, "dishka_container"): - await application.state.dishka_container.close() + async with application.router.lifespan_context(application): + yield application @pytest_asyncio.fixture(scope="session") -async def app_container(app): +async def app_container(app: FastAPI) -> AsyncContainer: """Expose the Dishka container attached to the app.""" container: AsyncContainer = app.state.dishka_container return container @pytest_asyncio.fixture -async def client(app) -> AsyncGenerator[httpx.AsyncClient, None]: +async def client(app: FastAPI) -> AsyncGenerator[httpx.AsyncClient, None]: """HTTP client for testing API endpoints.""" async with httpx.AsyncClient( - transport=ASGITransport(app=app), - base_url="https://test", - timeout=30.0, - follow_redirects=True, + transport=ASGITransport(app=app), + base_url="https://test", + timeout=30.0, + follow_redirects=True, ) as c: yield c @asynccontextmanager -async def _container_scope(container: AsyncContainer): +async def _container_scope(container: AsyncContainer) -> AsyncGenerator[AsyncContainer, None]: async with container() as scope: yield scope @pytest_asyncio.fixture -async def scope(app_container: AsyncContainer): +async def scope(app_container: AsyncContainer) -> AsyncGenerator[AsyncContainer, None]: async with _container_scope(app_container) as s: yield s @pytest_asyncio.fixture -async def db(scope) -> AsyncGenerator[Database, None]: +async def db(scope: AsyncContainer) -> AsyncGenerator[Database, None]: database: Database = await scope.get(Database) yield database @pytest_asyncio.fixture -async def redis_client(scope) -> AsyncGenerator[redis.Redis, None]: +async def redis_client(scope: AsyncContainer) -> AsyncGenerator[redis.Redis, None]: + # Don't close here - Dishka's RedisProvider handles cleanup when scope exits client: redis.Redis = await scope.get(redis.Redis) yield client -# ===== HTTP helpers (auth) ===== -async def _http_login(client: httpx.AsyncClient, username: str, password: str) -> str: - data = {"username": username, "password": password} - resp = await client.post("/api/v1/auth/login", data=data) - resp.raise_for_status() - return resp.json().get("csrf_token", "") - +# ===== Authenticated client fixtures ===== +# Return httpx.AsyncClient with CSRF header pre-set. Just use test_user.post(...) directly. -@pytest.fixture -def test_user_credentials(): - uid = uuid.uuid4().hex[:8] - return { - "username": f"test_user_{uid}", - "email": f"test_user_{uid}@example.com", - "password": "TestPass123!", - "role": "user", - } +async def _create_authenticated_client( + app: FastAPI, username: str, email: str, password: str, role: str +) -> httpx.AsyncClient: + """Create and return an authenticated client with CSRF header set.""" + c = httpx.AsyncClient( + transport=ASGITransport(app=app), + base_url="https://test", + timeout=30.0, + follow_redirects=True, + ) + try: + r = await c.post("/api/v1/auth/register", json={ + "username": username, + "email": email, + "password": password, + "role": role, + }) + # 200: created, 400: username exists, 409: email exists - all OK to proceed to login + if r.status_code not in (200, 400, 409): + pytest.fail(f"Cannot create {role} (status {r.status_code}): {r.text}") -@pytest.fixture -def test_admin_credentials(): - uid = uuid.uuid4().hex[:8] - return { - "username": f"admin_user_{uid}", - "email": f"admin_user_{uid}@example.com", - "password": "AdminPass123!", - "role": "admin", - } + login_resp = await c.post("/api/v1/auth/login", data={ + "username": username, + "password": password, + }) + login_resp.raise_for_status() + + login_data = login_resp.json() + csrf = login_data.get("csrf_token") + if not csrf: + await c.aclose() + pytest.fail( + f"Login succeeded but csrf_token missing or empty for {role} '{username}'. " + f"Response: {login_resp.text}" + ) + + c.headers["X-CSRF-Token"] = csrf + return c + except Exception: + await c.aclose() + raise @pytest_asyncio.fixture -async def test_user(client: httpx.AsyncClient, test_user_credentials): - """Function-scoped authenticated user.""" - creds = test_user_credentials - r = await client.post("/api/v1/auth/register", json=creds) - if r.status_code not in (200, 201, 400): - pytest.fail(f"Cannot create test user (status {r.status_code}): {r.text}") - csrf = await _http_login(client, creds["username"], creds["password"]) - return {**creds, "csrf_token": csrf, "headers": {"X-CSRF-Token": csrf}} +async def test_user(app: FastAPI) -> AsyncGenerator[httpx.AsyncClient, None]: + """Authenticated user client. CSRF header is set automatically.""" + uid = uuid.uuid4().hex[:8] + c = await _create_authenticated_client( + app, + username=f"test_user_{uid}", + email=f"test_user_{uid}@example.com", + password="TestPass123!", + role="user", + ) + yield c + await c.aclose() @pytest_asyncio.fixture -async def test_admin(client: httpx.AsyncClient, test_admin_credentials): - """Function-scoped authenticated admin.""" - creds = test_admin_credentials - r = await client.post("/api/v1/auth/register", json=creds) - if r.status_code not in (200, 201, 400): - pytest.fail(f"Cannot create test admin (status {r.status_code}): {r.text}") - csrf = await _http_login(client, creds["username"], creds["password"]) - return {**creds, "csrf_token": csrf, "headers": {"X-CSRF-Token": csrf}} +async def test_admin(app: FastAPI) -> AsyncGenerator[httpx.AsyncClient, None]: + """Authenticated admin client. CSRF header is set automatically.""" + uid = uuid.uuid4().hex[:8] + c = await _create_authenticated_client( + app, + username=f"admin_user_{uid}", + email=f"admin_user_{uid}@example.com", + password="AdminPass123!", + role="admin", + ) + yield c + await c.aclose() @pytest_asyncio.fixture -async def another_user(client: httpx.AsyncClient): - username = f"test_user_{uuid.uuid4().hex[:8]}" - email = f"{username}@example.com" - password = "TestPass123!" - await client.post( - "/api/v1/auth/register", - json={ - "username": username, - "email": email, - "password": password, - "role": "user", - }, +async def another_user(app: FastAPI) -> AsyncGenerator[httpx.AsyncClient, None]: + """Another authenticated user client (for multi-user tests).""" + uid = uuid.uuid4().hex[:8] + c = await _create_authenticated_client( + app, + username=f"test_user_{uid}", + email=f"test_user_{uid}@example.com", + password="TestPass123!", + role="user", ) - csrf = await _http_login(client, username, password) - return { - "username": username, - "email": email, - "password": password, - "csrf_token": csrf, - "headers": {"X-CSRF-Token": csrf}, - } + yield c + await c.aclose() diff --git a/backend/tests/e2e/conftest.py b/backend/tests/e2e/conftest.py index e8243e1c..b753b352 100644 --- a/backend/tests/e2e/conftest.py +++ b/backend/tests/e2e/conftest.py @@ -1,4 +1,6 @@ """E2E tests conftest - with infrastructure cleanup.""" +from collections.abc import AsyncGenerator + import pytest_asyncio import redis.asyncio as redis @@ -7,7 +9,7 @@ @pytest_asyncio.fixture(autouse=True) -async def _cleanup(db: Database, redis_client: redis.Redis): +async def _cleanup(db: Database, redis_client: redis.Redis) -> AsyncGenerator[None, None]: """Clean DB and Redis before each E2E test. Only pre-test cleanup - post-test cleanup causes event loop issues diff --git a/backend/tests/e2e/test_execution_routes.py b/backend/tests/e2e/test_execution_routes.py index 2cb1fa7a..dd0bfa2f 100644 --- a/backend/tests/e2e/test_execution_routes.py +++ b/backend/tests/e2e/test_execution_routes.py @@ -1,6 +1,4 @@ import asyncio -import os -from typing import Dict from uuid import UUID import pytest @@ -10,7 +8,7 @@ from app.schemas_pydantic.execution import ( ExecutionResponse, ExecutionResult, - ResourceUsage + ResourceUsage, ) pytestmark = [pytest.mark.e2e, pytest.mark.k8s] @@ -37,16 +35,8 @@ async def test_execute_requires_authentication(self, client: AsyncClient) -> Non for word in ["not authenticated", "unauthorized", "login"]) @pytest.mark.asyncio - async def test_execute_simple_python_script(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_execute_simple_python_script(self, test_user: AsyncClient) -> None: """Test executing a simple Python script.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Execute script execution_request = { "script": "print('Hello from real backend!')", @@ -54,7 +44,7 @@ async def test_execute_simple_python_script(self, client: AsyncClient, test_user "lang_version": "3.11" } - response = await client.post("/api/v1/execute", json=execution_request) + response = await test_user.post("/api/v1/execute", json=execution_request) assert response.status_code == 200 # Validate response structure @@ -80,16 +70,8 @@ async def test_execute_simple_python_script(self, client: AsyncClient, test_user ] @pytest.mark.asyncio - async def test_get_execution_result(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_get_execution_result(self, test_user: AsyncClient) -> None: """Test getting execution result after completion using SSE (event-driven).""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Execute a simple script execution_request = { "script": "print('Test output')\nprint('Line 2')", @@ -97,13 +79,13 @@ async def test_get_execution_result(self, client: AsyncClient, test_user: Dict[s "lang_version": "3.11" } - exec_response = await client.post("/api/v1/execute", json=execution_request) + exec_response = await test_user.post("/api/v1/execute", json=execution_request) assert exec_response.status_code == 200 execution_id = exec_response.json()["execution_id"] # Immediately fetch result - no waiting - result_response = await client.get(f"/api/v1/result/{execution_id}") + result_response = await test_user.get(f"/api/v1/result/{execution_id}") assert result_response.status_code == 200 result_data = result_response.json() @@ -120,16 +102,8 @@ async def test_get_execution_result(self, client: AsyncClient, test_user: Dict[s assert "Line 2" in execution_result.stdout @pytest.mark.asyncio - async def test_execute_with_error(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_execute_with_error(self, test_user: AsyncClient) -> None: """Test executing a script that produces an error.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Execute script with intentional error execution_request = { "script": "print('Before error')\nraise ValueError('Test error')\nprint('After error')", @@ -137,7 +111,7 @@ async def test_execute_with_error(self, client: AsyncClient, test_user: Dict[str "lang_version": "3.11" } - exec_response = await client.post("/api/v1/execute", json=execution_request) + exec_response = await test_user.post("/api/v1/execute", json=execution_request) assert exec_response.status_code == 200 execution_id = exec_response.json()["execution_id"] @@ -145,16 +119,8 @@ async def test_execute_with_error(self, client: AsyncClient, test_user: Dict[str # No waiting - execution was accepted, error will be processed asynchronously @pytest.mark.asyncio - async def test_execute_with_resource_tracking(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_execute_with_resource_tracking(self, test_user: AsyncClient) -> None: """Test that execution tracks resource usage.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Execute script that uses some resources execution_request = { "script": """ @@ -169,7 +135,7 @@ async def test_execute_with_resource_tracking(self, client: AsyncClient, test_us "lang_version": "3.11" } - exec_response = await client.post("/api/v1/execute", json=execution_request) + exec_response = await test_user.post("/api/v1/execute", json=execution_request) assert exec_response.status_code == 200 execution_id = exec_response.json()["execution_id"] @@ -177,7 +143,7 @@ async def test_execute_with_resource_tracking(self, client: AsyncClient, test_us # No waiting - execution was accepted, error will be processed asynchronously # Fetch result and validate resource usage if present - result_response = await client.get(f"/api/v1/result/{execution_id}") + result_response = await test_user.get(f"/api/v1/result/{execution_id}") if result_response.status_code == 200 and result_response.json().get("resource_usage"): resource_usage = ResourceUsage(**result_response.json()["resource_usage"]) if resource_usage.execution_time_wall_seconds is not None: @@ -186,17 +152,8 @@ async def test_execute_with_resource_tracking(self, client: AsyncClient, test_us assert resource_usage.peak_memory_kb >= 0 @pytest.mark.asyncio - async def test_execute_with_different_language_versions(self, client: AsyncClient, - test_user: Dict[str, str]) -> None: + async def test_execute_with_different_language_versions(self, test_user: AsyncClient) -> None: """Test execution with different Python versions.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Test different Python versions (if supported) test_cases = [ ("3.10", "import sys; print(f'Python {sys.version}')"), @@ -211,7 +168,7 @@ async def test_execute_with_different_language_versions(self, client: AsyncClien "lang_version": version } - response = await client.post("/api/v1/execute", json=execution_request) + response = await test_user.post("/api/v1/execute", json=execution_request) # Should either accept (200) or reject unsupported version (400/422) assert response.status_code in [200, 400, 422] @@ -220,16 +177,8 @@ async def test_execute_with_different_language_versions(self, client: AsyncClien assert "execution_id" in data @pytest.mark.asyncio - async def test_execute_with_large_output(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_execute_with_large_output(self, test_user: AsyncClient) -> None: """Test execution with large output.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Script that produces large output execution_request = { "script": """ @@ -242,14 +191,14 @@ async def test_execute_with_large_output(self, client: AsyncClient, test_user: D "lang_version": "3.11" } - exec_response = await client.post("/api/v1/execute", json=execution_request) + exec_response = await test_user.post("/api/v1/execute", json=execution_request) assert exec_response.status_code == 200 execution_id = exec_response.json()["execution_id"] # No waiting - execution was accepted, error will be processed asynchronously # Validate output from result endpoint (best-effort) - result_response = await client.get(f"/api/v1/result/{execution_id}") + result_response = await test_user.get(f"/api/v1/result/{execution_id}") if result_response.status_code == 200: result_data = result_response.json() if result_data.get("status") == "COMPLETED": @@ -258,16 +207,8 @@ async def test_execute_with_large_output(self, client: AsyncClient, test_user: D assert "End of output" in result_data["stdout"] or len(result_data["stdout"]) > 10000 @pytest.mark.asyncio - async def test_cancel_running_execution(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_cancel_running_execution(self, test_user: AsyncClient) -> None: """Test cancelling a running execution.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Start a long-running script execution_request = { "script": """ @@ -282,7 +223,7 @@ async def test_cancel_running_execution(self, client: AsyncClient, test_user: Di "lang_version": "3.11" } - exec_response = await client.post("/api/v1/execute", json=execution_request) + exec_response = await test_user.post("/api/v1/execute", json=execution_request) assert exec_response.status_code == 200 execution_id = exec_response.json()["execution_id"] @@ -293,7 +234,9 @@ async def test_cancel_running_execution(self, client: AsyncClient, test_user: Di } try: - cancel_response = await client.post(f"/api/v1/{execution_id}/cancel", json=cancel_request) + cancel_response = await test_user.post( + f"/api/v1/{execution_id}/cancel", json=cancel_request + ) except Exception: pytest.skip("Cancel endpoint not available or connection dropped") if cancel_response.status_code >= 500: @@ -304,21 +247,13 @@ async def test_cancel_running_execution(self, client: AsyncClient, test_user: Di # Cancel response of 200 means cancellation was accepted @pytest.mark.asyncio - async def test_execution_with_timeout(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_execution_with_timeout(self, test_user: AsyncClient) -> None: """Bounded check: long-running executions don't finish immediately. The backend's default timeout is 300s. To keep integration fast, assert that within a short window the execution is either still running or has transitioned to a terminal state due to platform limits. """ - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Script that would run forever execution_request = { "script": """ @@ -332,7 +267,7 @@ async def test_execution_with_timeout(self, client: AsyncClient, test_user: Dict "lang_version": "3.11" } - exec_response = await client.post("/api/v1/execute", json=execution_request) + exec_response = await test_user.post("/api/v1/execute", json=execution_request) assert exec_response.status_code == 200 execution_id = exec_response.json()["execution_id"] @@ -341,16 +276,8 @@ async def test_execution_with_timeout(self, client: AsyncClient, test_user: Dict # No need to wait or observe states @pytest.mark.asyncio - async def test_sandbox_restrictions(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_sandbox_restrictions(self, test_user: AsyncClient) -> None: """Test that dangerous operations are blocked by sandbox.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Try dangerous operations that should be blocked dangerous_scripts = [ # File system access @@ -370,14 +297,14 @@ async def test_sandbox_restrictions(self, client: AsyncClient, test_user: Dict[s "lang_version": "3.11" } - exec_response = await client.post("/api/v1/execute", json=execution_request) + exec_response = await test_user.post("/api/v1/execute", json=execution_request) # Should either reject immediately or fail during execution if exec_response.status_code == 200: execution_id = exec_response.json()["execution_id"] # Immediately check result - no waiting - result_resp = await client.get(f"/api/v1/result/{execution_id}") + result_resp = await test_user.get(f"/api/v1/result/{execution_id}") if result_resp.status_code == 200: result_data = result_resp.json() # Dangerous operations should either: @@ -397,16 +324,8 @@ async def test_sandbox_restrictions(self, client: AsyncClient, test_user: Dict[s assert exec_response.status_code in [400, 422] @pytest.mark.asyncio - async def test_concurrent_executions_by_same_user(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_concurrent_executions_by_same_user(self, test_user: AsyncClient) -> None: """Test running multiple executions concurrently.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Submit multiple executions execution_request = { "script": "import time; time.sleep(1); print('Concurrent test')", @@ -415,8 +334,8 @@ async def test_concurrent_executions_by_same_user(self, client: AsyncClient, tes } tasks = [] - for i in range(3): - task = client.post("/api/v1/execute", json=execution_request) + for _ in range(3): + task = test_user.post("/api/v1/execute", json=execution_request) tasks.append(task) responses = await asyncio.gather(*tasks) @@ -464,44 +383,35 @@ async def test_get_k8s_resource_limits(self, client: AsyncClient) -> None: assert key in limits @pytest.mark.asyncio - async def test_get_user_executions_list(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_get_user_executions_list(self, test_user: AsyncClient) -> None: """User executions list returns paginated executions for current user.""" - # Login first - login_data = {"username": test_user["username"], "password": test_user["password"]} - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # List executions - response = await client.get("/api/v1/user/executions?limit=5&skip=0") + response = await test_user.get("/api/v1/user/executions?limit=5&skip=0") assert response.status_code == 200 payload = response.json() assert set(["executions", "total", "limit", "skip", "has_more"]).issubset(payload.keys()) @pytest.mark.asyncio - async def test_execution_idempotency_same_key_returns_same_execution(self, client: AsyncClient, - test_user: Dict[str, str]) -> None: + async def test_execution_idempotency_same_key_returns_same_execution( + self, test_user: AsyncClient + ) -> None: """Submitting the same request with the same Idempotency-Key yields the same execution_id.""" - # Login first - login_data = {"username": test_user["username"], "password": test_user["password"]} - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - execution_request = { "script": "print('Idempotency integration test')", "lang": "python", "lang_version": "3.11", } + # Add idempotency key header (CSRF is already set on test_user) headers = {"Idempotency-Key": "it-idem-key-123"} # Use idempotency header on both requests to guarantee keying - r1 = await client.post("/api/v1/execute", json=execution_request, headers=headers) - assert r1.status_code == 200 + r1 = await test_user.post("/api/v1/execute", json=execution_request, headers=headers) assert r1.status_code == 200 e1 = r1.json()["execution_id"] # Second request with same key must return the same execution id - r2 = await client.post("/api/v1/execute", json=execution_request, headers=headers) + r2 = await test_user.post("/api/v1/execute", json=execution_request, headers=headers) assert r2.status_code == 200 e2 = r2.json()["execution_id"] diff --git a/backend/tests/e2e/test_resource_cleaner_orphan.py b/backend/tests/e2e/test_resource_cleaner_orphan.py index 2cd36173..2cfb199b 100644 --- a/backend/tests/e2e/test_resource_cleaner_orphan.py +++ b/backend/tests/e2e/test_resource_cleaner_orphan.py @@ -13,7 +13,7 @@ _test_logger = logging.getLogger("test.k8s.resource_cleaner_orphan") -def _ensure_kubeconfig(): +def _ensure_kubeconfig() -> None: try: k8s_config.load_incluster_config() except Exception: @@ -21,7 +21,7 @@ def _ensure_kubeconfig(): @pytest.mark.asyncio -async def test_cleanup_orphaned_configmaps_dry_run(): +async def test_cleanup_orphaned_configmaps_dry_run() -> None: _ensure_kubeconfig() v1 = k8s_client.CoreV1Api() ns = "default" @@ -41,7 +41,7 @@ async def test_cleanup_orphaned_configmaps_dry_run(): cleaned = await cleaner.cleanup_orphaned_resources(namespace=ns, max_age_hours=0, dry_run=True) # We expect our configmap to be a candidate; poll the response - async def _has_cm(): + async def _has_cm() -> None: # If cleaner is non-deterministic across runs, re-invoke to reflect current state res = await cleaner.cleanup_orphaned_resources(namespace=ns, max_age_hours=0, dry_run=True) assert any(name == cm for cm in res.get("configmaps", [])) diff --git a/backend/tests/helpers/__init__.py b/backend/tests/helpers/__init__.py index f6e01139..31402bb5 100644 --- a/backend/tests/helpers/__init__.py +++ b/backend/tests/helpers/__init__.py @@ -1,3 +1,6 @@ """Helper utilities for tests (async polling, Kafka utilities, event factories).""" -from .events import make_execution_requested_event # re-export +from .auth import AuthResult, login_user +from .events import make_execution_requested_event + +__all__ = ["AuthResult", "login_user", "make_execution_requested_event"] diff --git a/backend/tests/helpers/auth.py b/backend/tests/helpers/auth.py new file mode 100644 index 00000000..a7d8b947 --- /dev/null +++ b/backend/tests/helpers/auth.py @@ -0,0 +1,42 @@ +from typing import TypedDict + +from httpx import AsyncClient + + +class AuthResult(TypedDict): + """Result of a login operation with CSRF token.""" + + csrf_token: str + headers: dict[str, str] + + +async def login_user(client: AsyncClient, username: str, password: str) -> AuthResult: + """Login a user and return CSRF token and headers for subsequent requests. + + Use this helper when tests need to switch users or re-authenticate. + The returned headers dict should be passed to POST/PUT/DELETE requests. + + Args: + client: The httpx AsyncClient + username: Username to login with + password: Password for the user + + Returns: + AuthResult with csrf_token and headers dict containing X-CSRF-Token + + Raises: + AssertionError: If login fails + """ + response = await client.post( + "/api/v1/auth/login", + data={"username": username, "password": password}, + ) + assert response.status_code == 200, f"Login failed: {response.text}" + + json_data: dict[str, str] = response.json() + csrf_token = json_data.get("csrf_token", "") + + return AuthResult( + csrf_token=csrf_token, + headers={"X-CSRF-Token": csrf_token}, + ) diff --git a/backend/tests/helpers/cleanup.py b/backend/tests/helpers/cleanup.py index 33a4cdfd..7dcccaa8 100644 --- a/backend/tests/helpers/cleanup.py +++ b/backend/tests/helpers/cleanup.py @@ -1,23 +1,23 @@ """Shared cleanup utilities for integration and E2E tests.""" import redis.asyncio as redis -from beanie import init_beanie from app.core.database_context import Database -from app.db.docs import ALL_DOCUMENTS async def cleanup_db_and_redis(db: Database, redis_client: redis.Redis) -> None: """Clean DB and Redis before a test. - NOTE: With pytest-xdist, each worker uses a separate Redis database - (gw0→db0, gw1→db1, etc.), so flushdb() is safe and only affects - that worker's database. See tests/conftest.py for REDIS_DB setup. + Beanie is already initialized once during app lifespan (dishka_lifespan.py). + We just delete documents to preserve indexes and avoid file descriptor exhaustion. + + NOTE: With pytest-xdist, each worker is assigned a dedicated Redis DB + derived from the worker id (sum(_WORKER_ID.encode()) % 16), so flushdb() + is safe and only affects that worker's database. See tests/conftest.py + for REDIS_DB setup. """ - collections = await db.list_collection_names() + collections = await db.list_collection_names(filter={"type": "collection"}) for name in collections: if not name.startswith("system."): - await db.drop_collection(name) + await db[name].delete_many({}) await redis_client.flushdb() - - await init_beanie(database=db, document_models=ALL_DOCUMENTS) diff --git a/backend/tests/helpers/events.py b/backend/tests/helpers/events.py index 63b6fc15..b4eef7ca 100644 --- a/backend/tests/helpers/events.py +++ b/backend/tests/helpers/events.py @@ -34,6 +34,7 @@ def make_execution_requested_event( metadata = AvroEventMetadata(service_name=service_name, service_version=service_version, user_id=user_id) return ExecutionRequestedEvent( execution_id=execution_id, + aggregate_id=execution_id, # Match production: aggregate_id == execution_id for execution events script=script, language=language, language_version=language_version, diff --git a/backend/tests/helpers/eventually.py b/backend/tests/helpers/eventually.py index f72689f3..ee5c525b 100644 --- a/backend/tests/helpers/eventually.py +++ b/backend/tests/helpers/eventually.py @@ -1,33 +1,36 @@ import asyncio -from typing import Awaitable, Callable, TypeVar +from collections.abc import Awaitable, Callable +from typing import TypeVar T = TypeVar("T") async def eventually( - fn: Callable[[], Awaitable[T]] | Callable[[], T], + fn: Callable[[], Awaitable[T]], *, timeout: float = 10.0, interval: float = 0.1, exceptions: tuple[type[BaseException], ...] = (AssertionError,), ) -> T: - """Polls `fn` until it succeeds or timeout elapses. + """Poll async `fn` until it succeeds or timeout elapses. - - `fn` may be sync or async. If it raises one of `exceptions`, it is retried. - - Returns the value of `fn` on success. - - Raises the last exception after timeout. + Args: + fn: Async callable to poll. Retried if it raises one of `exceptions`. + timeout: Maximum time to wait in seconds. + interval: Time between retries in seconds. + exceptions: Exception types that trigger a retry. + + Returns: + The return value of `fn` on success. + + Raises: + The last exception raised by `fn` after timeout. """ deadline = asyncio.get_running_loop().time() + timeout - last_exc: BaseException | None = None while True: try: - res = fn() - if asyncio.iscoroutine(res): - return await res # type: ignore[return-value] - return res # type: ignore[return-value] - except exceptions as exc: # type: ignore[misc] - last_exc = exc + return await fn() + except exceptions: if asyncio.get_running_loop().time() >= deadline: raise await asyncio.sleep(interval) - diff --git a/backend/tests/helpers/k8s_fakes.py b/backend/tests/helpers/k8s_fakes.py index 835e29e3..7104dc4f 100644 --- a/backend/tests/helpers/k8s_fakes.py +++ b/backend/tests/helpers/k8s_fakes.py @@ -131,30 +131,95 @@ class FakeApi: def __init__(self, logs: str) -> None: self._logs = logs - def read_namespaced_pod_log(self, name: str, namespace: str, tail_lines: int = 10000): # noqa: ARG002 + def read_namespaced_pod_log(self, name: str, namespace: str, tail_lines: int = 10000) -> str: # noqa: ARG002 return self._logs -def make_watch(events: list[dict[str, Any]], resource_version: str = "rv2"): - class _StopEvent: - def __init__(self, rv: str) -> None: - self.resource_version = rv +class StopEvent: + """Fake stop event for FakeWatch - holds resource_version.""" - class _Stream(list): - def __init__(self, ev: list[dict[str, Any]], rv: str) -> None: - super().__init__(ev) - self._stop_event = _StopEvent(rv) + def __init__(self, resource_version: str) -> None: + self.resource_version = resource_version + + +class FakeWatchStream: + """Fake watch stream object returned by FakeWatch.stream(). + + The real kubernetes watch stream has a _stop_event attribute that + holds the resource_version for use by _update_resource_version. + """ + + def __init__(self, events: list[dict[str, Any]], resource_version: str) -> None: + self._events = events + self._stop_event = StopEvent(resource_version) + self._index = 0 + + def __iter__(self) -> "FakeWatchStream": + return self + + def __next__(self) -> dict[str, Any]: + if self._index >= len(self._events): + raise StopIteration + event = self._events[self._index] + self._index += 1 + return event + + +class FakeWatch: + """Fake kubernetes Watch for testing.""" + + def __init__(self, events: list[dict[str, Any]], resource_version: str) -> None: + self._events = events + self._rv = resource_version + + def stream( + self, func: Any, **kwargs: Any # noqa: ARG002 + ) -> FakeWatchStream: + return FakeWatchStream(self._events, self._rv) + + def stop(self) -> None: + return None + + +def make_watch(events: list[dict[str, Any]], resource_version: str = "rv2") -> FakeWatch: + return FakeWatch(events, resource_version) + + +class FakeV1Api: + """Fake CoreV1Api for testing PodMonitor.""" + + def __init__(self, logs: str = "{}", pods: list[Pod] | None = None) -> None: + self._logs = logs + self._pods = pods or [] + + def read_namespaced_pod_log(self, name: str, namespace: str, tail_lines: int = 10000) -> str: # noqa: ARG002 + return self._logs + + def get_api_resources(self) -> None: + """Stub for connectivity check.""" + return None + + def list_namespaced_pod(self, namespace: str, label_selector: str) -> Any: # noqa: ARG002 + """Return configured pods for reconciliation tests.""" + + class PodList: + def __init__(self, items: list[Pod]) -> None: + self.items = items - class _Watch: - def __init__(self, ev: list[dict[str, Any]], rv: str) -> None: - self._events = ev - self._rv = rv + return PodList(list(self._pods)) - def stream(self, func, **kwargs): # noqa: ARG002 - return _Stream(list(self._events), self._rv) - def stop(self) -> None: - return None +def make_k8s_clients( + logs: str = "{}", + events: list[dict[str, Any]] | None = None, + resource_version: str = "rv1", + pods: list[Pod] | None = None, +) -> tuple[FakeV1Api, FakeWatch]: + """Create fake K8s clients for testing. - return _Watch(events, resource_version) + Returns (v1_api, watch) tuple for pure DI into PodMonitor. + """ + v1 = FakeV1Api(logs=logs, pods=pods) + watch = make_watch(events or [], resource_version) + return v1, watch diff --git a/backend/tests/helpers/kafka.py b/backend/tests/helpers/kafka.py index 4ceefb22..531e0bc9 100644 --- a/backend/tests/helpers/kafka.py +++ b/backend/tests/helpers/kafka.py @@ -1,19 +1,21 @@ -from typing import Awaitable, Callable +from collections.abc import Awaitable, Callable import pytest +from dishka import AsyncContainer from app.events.core import UnifiedProducer from app.infrastructure.kafka.events.base import BaseEvent @pytest.fixture(scope="function") -async def producer(scope) -> UnifiedProducer: # type: ignore[valid-type] +async def producer(scope: AsyncContainer) -> UnifiedProducer: """Real Kafka producer from DI scope.""" - return await scope.get(UnifiedProducer) + prod: UnifiedProducer = await scope.get(UnifiedProducer) + return prod @pytest.fixture(scope="function") -def send_event(producer: UnifiedProducer) -> Callable[[BaseEvent], Awaitable[None]]: # type: ignore[valid-type] +def send_event(producer: UnifiedProducer) -> Callable[[BaseEvent], Awaitable[None]]: async def _send(ev: BaseEvent) -> None: await producer.produce(ev) return _send diff --git a/backend/tests/helpers/sse.py b/backend/tests/helpers/sse.py index e167467c..e72670f1 100644 --- a/backend/tests/helpers/sse.py +++ b/backend/tests/helpers/sse.py @@ -5,7 +5,7 @@ from httpx import AsyncClient -async def stream_sse(client: AsyncClient, url: str, timeout: float = 20.0) -> AsyncIterator[dict]: +async def stream_sse(client: AsyncClient, url: str, timeout: float = 20.0) -> AsyncIterator[dict[str, object]]: """Yield parsed SSE event dicts from the given URL within a timeout. Expects lines in the form "data: {...json...}" and ignores keepalives. @@ -31,11 +31,11 @@ async def wait_for_event_type( url: str, wanted_types: Iterable[str], timeout: float = 20.0, -) -> dict: - """Return first event whose type/event_type is in wanted_types, otherwise timeout.""" +) -> dict[str, object]: + """Return first event whose event_type is in wanted_types, otherwise timeout.""" wanted = {str(t).lower() for t in wanted_types} async for ev in stream_sse(client, url, timeout=timeout): - et = str(ev.get("type") or ev.get("event_type") or "").lower() + et = str(ev.get("event_type") or "").lower() if et in wanted: return ev raise TimeoutError(f"No event of types {wanted} seen on {url} within {timeout}s") @@ -45,7 +45,7 @@ async def wait_for_execution_terminal( client: AsyncClient, execution_id: str, timeout: float = 30.0, -) -> dict: +) -> dict[str, object]: terminal = {"execution_completed", "result_stored", "execution_failed", "execution_timeout", "execution_cancelled"} url = f"/api/v1/events/executions/{execution_id}" return await wait_for_event_type(client, url, terminal, timeout=timeout) @@ -55,7 +55,7 @@ async def wait_for_execution_running( client: AsyncClient, execution_id: str, timeout: float = 15.0, -) -> dict: +) -> dict[str, object]: running = {"execution_running", "execution_started", "execution_scheduled", "execution_queued"} url = f"/api/v1/events/executions/{execution_id}" return await wait_for_event_type(client, url, running, timeout=timeout) diff --git a/backend/tests/integration/app/test_main_app.py b/backend/tests/integration/app/test_main_app.py index 36af7d12..529df5f2 100644 --- a/backend/tests/integration/app/test_main_app.py +++ b/backend/tests/integration/app/test_main_app.py @@ -3,6 +3,7 @@ import pytest from fastapi import FastAPI from starlette.middleware.cors import CORSMiddleware +from starlette.routing import Route from app.core.correlation import CorrelationMiddleware from app.core.middlewares import ( @@ -11,30 +12,32 @@ RateLimitMiddleware, RequestSizeLimitMiddleware, ) +from app.settings import Settings pytestmark = pytest.mark.integration -def test_create_app_real_instance(app) -> None: # type: ignore[valid-type] +def test_create_app_real_instance(app: FastAPI, test_settings: Settings) -> None: assert isinstance(app, FastAPI) # Verify API routes are configured - paths = {r.path for r in app.router.routes} + paths = {r.path for r in app.router.routes if isinstance(r, Route)} assert any(p.startswith("/api/") for p in paths) # Verify required middlewares are actually present in the stack - middleware_classes = {m.cls for m in app.user_middleware} + middleware_class_names = {getattr(m.cls, "__name__", str(m.cls)) for m in app.user_middleware} # Check that all required middlewares are configured - assert CORSMiddleware in middleware_classes, "CORS middleware not configured" - assert CorrelationMiddleware in middleware_classes, "Correlation middleware not configured" - assert RequestSizeLimitMiddleware in middleware_classes, "Request size limit middleware not configured" - assert CacheControlMiddleware in middleware_classes, "Cache control middleware not configured" - assert MetricsMiddleware in middleware_classes, "Metrics middleware not configured" - assert RateLimitMiddleware in middleware_classes, "Rate limit middleware not configured" + assert "CORSMiddleware" in middleware_class_names, "CORS middleware not configured" + assert "CorrelationMiddleware" in middleware_class_names, "Correlation middleware not configured" + assert "RequestSizeLimitMiddleware" in middleware_class_names, "Request size limit middleware not configured" + assert "CacheControlMiddleware" in middleware_class_names, "Cache control middleware not configured" + assert "MetricsMiddleware" in middleware_class_names, "Metrics middleware not configured" + if test_settings.RATE_LIMIT_ENABLED: + assert "RateLimitMiddleware" in middleware_class_names, "Rate limit middleware not configured" -def test_create_app_function_constructs(app) -> None: # type: ignore[valid-type] +def test_create_app_function_constructs(test_settings: Settings) -> None: # Sanity: calling create_app returns a FastAPI instance (lazy import) - inst = import_module("app.main").create_app() + inst = import_module("app.main").create_app(settings=test_settings) assert isinstance(inst, FastAPI) diff --git a/backend/tests/integration/conftest.py b/backend/tests/integration/conftest.py index a59a32a9..4ae85086 100644 --- a/backend/tests/integration/conftest.py +++ b/backend/tests/integration/conftest.py @@ -1,4 +1,6 @@ """Integration tests conftest - with infrastructure cleanup.""" +from collections.abc import AsyncGenerator + import pytest_asyncio import redis.asyncio as redis @@ -7,7 +9,7 @@ @pytest_asyncio.fixture(autouse=True) -async def _cleanup(db: Database, redis_client: redis.Redis): +async def _cleanup(db: Database, redis_client: redis.Redis) -> AsyncGenerator[None, None]: """Clean DB and Redis before each integration test. Only pre-test cleanup - post-test cleanup causes event loop issues diff --git a/backend/tests/integration/core/test_container.py b/backend/tests/integration/core/test_container.py index 36bad89a..e6b6a1ca 100644 --- a/backend/tests/integration/core/test_container.py +++ b/backend/tests/integration/core/test_container.py @@ -1,14 +1,14 @@ import pytest from dishka import AsyncContainer -from app.core.database_context import Database +from app.core.database_context import Database from app.services.event_service import EventService pytestmark = [pytest.mark.integration, pytest.mark.mongodb] @pytest.mark.asyncio -async def test_container_resolves_services(app_container, scope) -> None: # type: ignore[valid-type] +async def test_container_resolves_services(app_container: AsyncContainer, scope: AsyncContainer) -> None: # Container is the real Dishka container assert isinstance(app_container, AsyncContainer) diff --git a/backend/tests/integration/core/test_dishka_lifespan.py b/backend/tests/integration/core/test_dishka_lifespan.py index bdb5c38c..b1948131 100644 --- a/backend/tests/integration/core/test_dishka_lifespan.py +++ b/backend/tests/integration/core/test_dishka_lifespan.py @@ -1,14 +1,16 @@ +from importlib import import_module + +from app.settings import Settings from fastapi import FastAPI -def test_lifespan_container_attached(app) -> None: # type: ignore[valid-type] +def test_lifespan_container_attached(app: FastAPI) -> None: # App fixture uses real lifespan; container is attached to app.state assert isinstance(app, FastAPI) assert hasattr(app.state, "dishka_container") -def test_create_app_attaches_container() -> None: - from importlib import import_module - app = import_module("app.main").create_app() +def test_create_app_attaches_container(test_settings: Settings) -> None: + app = import_module("app.main").create_app(settings=test_settings) assert isinstance(app, FastAPI) assert hasattr(app.state, "dishka_container") diff --git a/backend/tests/integration/db/repositories/test_admin_settings_repository.py b/backend/tests/integration/db/repositories/test_admin_settings_repository.py index 7c19cf50..ecbe6b30 100644 --- a/backend/tests/integration/db/repositories/test_admin_settings_repository.py +++ b/backend/tests/integration/db/repositories/test_admin_settings_repository.py @@ -1,4 +1,7 @@ import pytest +from dishka import AsyncContainer + +from app.core.database_context import Database from app.db.repositories.admin.admin_settings_repository import AdminSettingsRepository from app.domain.admin import SystemSettings @@ -6,7 +9,7 @@ @pytest.fixture() -async def repo(scope) -> AdminSettingsRepository: # type: ignore[valid-type] +async def repo(scope: AsyncContainer) -> AdminSettingsRepository: return await scope.get(AdminSettingsRepository) @@ -24,7 +27,7 @@ async def test_get_system_settings_existing(repo: AdminSettingsRepository) -> No @pytest.mark.asyncio -async def test_update_and_reset_settings(repo: AdminSettingsRepository, db) -> None: # type: ignore[valid-type] +async def test_update_and_reset_settings(repo: AdminSettingsRepository, db: Database) -> None: s = SystemSettings() updated = await repo.update_system_settings(s, updated_by="admin", user_id="u1") assert isinstance(updated, SystemSettings) diff --git a/backend/tests/integration/db/repositories/test_dlq_repository.py b/backend/tests/integration/db/repositories/test_dlq_repository.py index 07d3711f..06bbb5f8 100644 --- a/backend/tests/integration/db/repositories/test_dlq_repository.py +++ b/backend/tests/integration/db/repositories/test_dlq_repository.py @@ -17,7 +17,7 @@ def repo() -> DLQRepository: return DLQRepository(_test_logger) -async def insert_test_dlq_docs(): +async def insert_test_dlq_docs() -> None: """Insert test DLQ documents using Beanie.""" now = datetime.now(timezone.utc) @@ -92,14 +92,3 @@ async def test_stats_list_get_and_updates(repo: DLQRepository) -> None: topics = await repo.get_topics_summary() assert any(t.topic == "t1" for t in topics) - - -@pytest.mark.asyncio -async def test_retry_batch(repo: DLQRepository) -> None: - class Manager: - async def retry_message_manually(self, eid: str) -> bool: # noqa: ARG002 - return True - - result = await repo.retry_messages_batch(["missing"], Manager()) - # Missing messages cause failures - assert result.total == 1 and result.failed >= 1 diff --git a/backend/tests/integration/db/repositories/test_saved_script_repository.py b/backend/tests/integration/db/repositories/test_saved_script_repository.py index 85fc2b58..92d26699 100644 --- a/backend/tests/integration/db/repositories/test_saved_script_repository.py +++ b/backend/tests/integration/db/repositories/test_saved_script_repository.py @@ -1,4 +1,6 @@ import pytest +from dishka import AsyncContainer + from app.db.repositories.saved_script_repository import SavedScriptRepository from app.domain.saved_script import DomainSavedScriptCreate, DomainSavedScriptUpdate @@ -6,7 +8,7 @@ @pytest.fixture() -async def repo(scope) -> SavedScriptRepository: # type: ignore[valid-type] +async def repo(scope: AsyncContainer) -> SavedScriptRepository: return await scope.get(SavedScriptRepository) diff --git a/backend/tests/integration/dlq/test_dlq_discard.py b/backend/tests/integration/dlq/test_dlq_discard.py new file mode 100644 index 00000000..0549fd7f --- /dev/null +++ b/backend/tests/integration/dlq/test_dlq_discard.py @@ -0,0 +1,161 @@ +import logging +import uuid +from datetime import datetime, timezone + +import pytest +from dishka import AsyncContainer + +from app.db.docs import DLQMessageDocument +from app.db.repositories.dlq_repository import DLQRepository +from app.dlq.models import DLQMessageStatus +from app.domain.enums.events import EventType +from app.domain.enums.kafka import KafkaTopic +from tests.helpers import make_execution_requested_event + +pytestmark = [pytest.mark.integration, pytest.mark.mongodb] + +_test_logger = logging.getLogger("test.dlq.discard") + + +async def _create_dlq_document( + event_id: str | None = None, + status: DLQMessageStatus = DLQMessageStatus.PENDING, +) -> DLQMessageDocument: + """Helper to create a DLQ document directly in MongoDB.""" + if event_id is None: + event_id = str(uuid.uuid4()) + + event = make_execution_requested_event(execution_id=f"exec-{uuid.uuid4().hex[:8]}") + now = datetime.now(timezone.utc) + + doc = DLQMessageDocument( + event=event.model_dump(), + event_id=event_id, + event_type=EventType.EXECUTION_REQUESTED, + original_topic=str(KafkaTopic.EXECUTION_EVENTS), + error="Test error", + retry_count=0, + failed_at=now, + status=status, + producer_id="test-producer", + created_at=now, + ) + await doc.insert() + return doc + + +@pytest.mark.asyncio +async def test_dlq_repository_marks_message_discarded(scope: AsyncContainer) -> None: + """Test that DLQRepository.mark_message_discarded() updates status correctly.""" + repository: DLQRepository = await scope.get(DLQRepository) + + # Create a DLQ document + event_id = f"dlq-discard-{uuid.uuid4().hex[:8]}" + doc = await _create_dlq_document(event_id=event_id, status=DLQMessageStatus.PENDING) + + # Discard the message + reason = "max_retries_exceeded" + result = await repository.mark_message_discarded(event_id, reason) + + assert result is True + + # Verify the status changed + updated_doc = await DLQMessageDocument.find_one({"event_id": event_id}) + assert updated_doc is not None + assert updated_doc.status == DLQMessageStatus.DISCARDED + assert updated_doc.discard_reason == reason + assert updated_doc.discarded_at is not None + + +@pytest.mark.asyncio +async def test_dlq_discard_nonexistent_message_returns_false(scope: AsyncContainer) -> None: + """Test that discarding a nonexistent message returns False.""" + repository: DLQRepository = await scope.get(DLQRepository) + + # Try to discard a message that doesn't exist + result = await repository.mark_message_discarded( + f"nonexistent-{uuid.uuid4().hex[:8]}", + "test_reason", + ) + + assert result is False + + +@pytest.mark.asyncio +async def test_dlq_discard_sets_timestamp(scope: AsyncContainer) -> None: + """Test that discarding sets the discarded_at timestamp.""" + repository: DLQRepository = await scope.get(DLQRepository) + + # Create a DLQ document + event_id = f"dlq-ts-{uuid.uuid4().hex[:8]}" + before_discard = datetime.now(timezone.utc) + await _create_dlq_document(event_id=event_id) + + # Discard the message + await repository.mark_message_discarded(event_id, "manual_discard") + after_discard = datetime.now(timezone.utc) + + # Verify timestamp is set correctly + doc = await DLQMessageDocument.find_one({"event_id": event_id}) + assert doc is not None + assert doc.discarded_at is not None + assert before_discard <= doc.discarded_at <= after_discard + + +@pytest.mark.asyncio +async def test_dlq_discard_with_custom_reason(scope: AsyncContainer) -> None: + """Test that custom discard reasons are stored.""" + repository: DLQRepository = await scope.get(DLQRepository) + + # Create a DLQ document + event_id = f"dlq-reason-{uuid.uuid4().hex[:8]}" + await _create_dlq_document(event_id=event_id) + + # Discard with custom reason + custom_reason = "manual: User requested deletion due to invalid payload" + await repository.mark_message_discarded(event_id, custom_reason) + + # Verify the reason is stored + doc = await DLQMessageDocument.find_one({"event_id": event_id}) + assert doc is not None + assert doc.discard_reason == custom_reason + + +@pytest.mark.asyncio +async def test_dlq_discard_from_scheduled_status(scope: AsyncContainer) -> None: + """Test that scheduled messages can be discarded.""" + repository: DLQRepository = await scope.get(DLQRepository) + + # Create a SCHEDULED DLQ document + event_id = f"dlq-scheduled-{uuid.uuid4().hex[:8]}" + await _create_dlq_document(event_id=event_id, status=DLQMessageStatus.SCHEDULED) + + # Discard the message + result = await repository.mark_message_discarded(event_id, "policy_change") + + assert result is True + + # Verify status transition + doc = await DLQMessageDocument.find_one({"event_id": event_id}) + assert doc is not None + assert doc.status == DLQMessageStatus.DISCARDED + + +@pytest.mark.asyncio +async def test_dlq_stats_reflect_discarded_messages(scope: AsyncContainer) -> None: + """Test that DLQ statistics correctly count discarded messages.""" + repository: DLQRepository = await scope.get(DLQRepository) + + # Capture count before to ensure our discard is what increments the stat + stats_before = await repository.get_dlq_stats() + count_before = stats_before.by_status.get(DLQMessageStatus.DISCARDED.value, 0) + + # Create and discard a message + event_id = f"dlq-stats-{uuid.uuid4().hex[:8]}" + await _create_dlq_document(event_id=event_id, status=DLQMessageStatus.PENDING) + await repository.mark_message_discarded(event_id, "test") + + # Get stats after - verify the count incremented by exactly 1 + stats_after = await repository.get_dlq_stats() + count_after = stats_after.by_status.get(DLQMessageStatus.DISCARDED.value, 0) + assert count_after == count_before + 1 diff --git a/backend/tests/integration/dlq/test_dlq_discard_policy.py b/backend/tests/integration/dlq/test_dlq_discard_policy.py deleted file mode 100644 index ba625f58..00000000 --- a/backend/tests/integration/dlq/test_dlq_discard_policy.py +++ /dev/null @@ -1,61 +0,0 @@ -import json -import logging -import uuid -from datetime import datetime, timezone - -import pytest -from app.db.docs import DLQMessageDocument -from app.dlq.manager import create_dlq_manager -from app.dlq.models import DLQMessageStatus, RetryPolicy, RetryStrategy -from app.domain.enums.kafka import KafkaTopic -from app.events.schema.schema_registry import create_schema_registry_manager -from confluent_kafka import Producer - -from tests.helpers import make_execution_requested_event -from tests.helpers.eventually import eventually - -# xdist_group: DLQ tests share a Kafka consumer group. When running in parallel, -# different workers' managers consume each other's messages and apply wrong policies. -# Serial execution ensures each test's manager processes only its own messages. -pytestmark = [pytest.mark.integration, pytest.mark.kafka, pytest.mark.mongodb, pytest.mark.xdist_group("dlq")] - -_test_logger = logging.getLogger("test.dlq.discard_policy") - - -@pytest.mark.asyncio -async def test_dlq_manager_discards_with_manual_policy(db, test_settings) -> None: # type: ignore[valid-type] - schema_registry = create_schema_registry_manager(test_settings, _test_logger) - manager = create_dlq_manager(settings=test_settings, schema_registry=schema_registry, logger=_test_logger) - # Use prefix from test_settings to match what the manager uses - prefix = test_settings.KAFKA_TOPIC_PREFIX - topic = f"{prefix}{str(KafkaTopic.EXECUTION_EVENTS)}" - manager.set_retry_policy(topic, RetryPolicy(topic=topic, strategy=RetryStrategy.MANUAL)) - - # Use unique execution_id to avoid conflicts with parallel test workers - ev = make_execution_requested_event(execution_id=f"exec-dlq-discard-{uuid.uuid4().hex[:8]}") - - payload = { - "event": ev.to_dict(), - "original_topic": topic, - "error": "boom", - "retry_count": 0, - "failed_at": datetime.now(timezone.utc).isoformat(), - "producer_id": "tests", - } - - producer = Producer({"bootstrap.servers": "localhost:9092"}) - producer.produce( - topic=f"{prefix}{str(KafkaTopic.DEAD_LETTER_QUEUE)}", - key=ev.event_id.encode(), - value=json.dumps(payload).encode(), - ) - producer.flush(5) - - async with manager: - - async def _discarded() -> None: - doc = await DLQMessageDocument.find_one({"event_id": ev.event_id}) - assert doc is not None - assert doc.status == DLQMessageStatus.DISCARDED - - await eventually(_discarded, timeout=10.0, interval=0.2) diff --git a/backend/tests/integration/dlq/test_dlq_manager.py b/backend/tests/integration/dlq/test_dlq_manager.py index b6da245e..63b69b0a 100644 --- a/backend/tests/integration/dlq/test_dlq_manager.py +++ b/backend/tests/integration/dlq/test_dlq_manager.py @@ -4,12 +4,14 @@ from datetime import datetime, timezone import pytest +from confluent_kafka import Producer + +from app.core.database_context import Database from app.db.docs import DLQMessageDocument from app.dlq.manager import create_dlq_manager from app.domain.enums.kafka import KafkaTopic from app.events.schema.schema_registry import create_schema_registry_manager -from confluent_kafka import Producer - +from app.settings import Settings from tests.helpers import make_execution_requested_event from tests.helpers.eventually import eventually @@ -22,7 +24,7 @@ @pytest.mark.asyncio -async def test_dlq_manager_persists_in_mongo(db, test_settings) -> None: # type: ignore[valid-type] +async def test_dlq_manager_persists_in_mongo(db: Database, test_settings: Settings) -> None: schema_registry = create_schema_registry_manager(test_settings, _test_logger) manager = create_dlq_manager(settings=test_settings, schema_registry=schema_registry, logger=_test_logger) @@ -52,7 +54,7 @@ async def test_dlq_manager_persists_in_mongo(db, test_settings) -> None: # type # Run the manager briefly to consume and persist async with manager: - async def _exists(): + async def _exists() -> None: doc = await DLQMessageDocument.find_one({"event_id": ev.event_id}) assert doc is not None diff --git a/backend/tests/integration/dlq/test_dlq_retry.py b/backend/tests/integration/dlq/test_dlq_retry.py new file mode 100644 index 00000000..cde2ac8c --- /dev/null +++ b/backend/tests/integration/dlq/test_dlq_retry.py @@ -0,0 +1,225 @@ +import logging +import uuid +from datetime import datetime, timezone + +import pytest +from app.db.docs import DLQMessageDocument +from app.db.repositories.dlq_repository import DLQRepository +from app.dlq.models import DLQMessageStatus +from app.domain.enums.events import EventType +from app.domain.enums.kafka import KafkaTopic +from dishka import AsyncContainer + +from tests.helpers import make_execution_requested_event + +pytestmark = [pytest.mark.integration, pytest.mark.mongodb] + +_test_logger = logging.getLogger("test.dlq.retry") + + +async def _create_dlq_document( + event_id: str | None = None, + status: DLQMessageStatus = DLQMessageStatus.PENDING, +) -> DLQMessageDocument: + """Helper to create a DLQ document directly in MongoDB.""" + if event_id is None: + event_id = str(uuid.uuid4()) + + event = make_execution_requested_event(execution_id=f"exec-{uuid.uuid4().hex[:8]}") + now = datetime.now(timezone.utc) + + doc = DLQMessageDocument( + event=event.model_dump(), + event_id=event_id, + event_type=EventType.EXECUTION_REQUESTED, + original_topic=str(KafkaTopic.EXECUTION_EVENTS), + error="Test error", + retry_count=0, + failed_at=now, + status=status, + producer_id="test-producer", + created_at=now, + ) + await doc.insert() + return doc + + +@pytest.mark.asyncio +async def test_dlq_repository_marks_message_retried(scope: AsyncContainer) -> None: + """Test that DLQRepository.mark_message_retried() updates status correctly.""" + repository: DLQRepository = await scope.get(DLQRepository) + + # Create a DLQ document + event_id = f"dlq-retry-{uuid.uuid4().hex[:8]}" + await _create_dlq_document(event_id=event_id, status=DLQMessageStatus.SCHEDULED) + + # Mark as retried + result = await repository.mark_message_retried(event_id) + + assert result is True + + # Verify the status changed + updated_doc = await DLQMessageDocument.find_one({"event_id": event_id}) + assert updated_doc is not None + assert updated_doc.status == DLQMessageStatus.RETRIED + assert updated_doc.retried_at is not None + + +@pytest.mark.asyncio +async def test_dlq_retry_nonexistent_message_returns_false(scope: AsyncContainer) -> None: + """Test that retrying a nonexistent message returns False.""" + repository: DLQRepository = await scope.get(DLQRepository) + + # Try to retry a message that doesn't exist + result = await repository.mark_message_retried(f"nonexistent-{uuid.uuid4().hex[:8]}") + + assert result is False + + +@pytest.mark.asyncio +async def test_dlq_retry_sets_timestamp(scope: AsyncContainer) -> None: + """Test that retrying sets the retried_at timestamp.""" + repository: DLQRepository = await scope.get(DLQRepository) + + # Create a DLQ document + event_id = f"dlq-retry-ts-{uuid.uuid4().hex[:8]}" + before_retry = datetime.now(timezone.utc) + await _create_dlq_document(event_id=event_id, status=DLQMessageStatus.SCHEDULED) + + # Retry the message + await repository.mark_message_retried(event_id) + after_retry = datetime.now(timezone.utc) + + # Verify timestamp is set correctly + doc = await DLQMessageDocument.find_one({"event_id": event_id}) + assert doc is not None + assert doc.retried_at is not None + assert before_retry <= doc.retried_at <= after_retry + + +@pytest.mark.asyncio +async def test_dlq_retry_from_pending_status(scope: AsyncContainer) -> None: + """Test that pending messages can be retried.""" + repository: DLQRepository = await scope.get(DLQRepository) + + # Create a PENDING DLQ document + event_id = f"dlq-pending-{uuid.uuid4().hex[:8]}" + await _create_dlq_document(event_id=event_id, status=DLQMessageStatus.PENDING) + + # Retry the message + result = await repository.mark_message_retried(event_id) + + assert result is True + + # Verify status transition + doc = await DLQMessageDocument.find_one({"event_id": event_id}) + assert doc is not None + assert doc.status == DLQMessageStatus.RETRIED + + +@pytest.mark.asyncio +async def test_dlq_stats_reflect_retried_messages(scope: AsyncContainer) -> None: + """Test that DLQ statistics correctly count retried messages.""" + repository: DLQRepository = await scope.get(DLQRepository) + + # Capture count before to ensure our retry is what increments the stat + stats_before = await repository.get_dlq_stats() + count_before = stats_before.by_status.get(DLQMessageStatus.RETRIED.value, 0) + + # Create and retry a message + event_id = f"dlq-stats-retry-{uuid.uuid4().hex[:8]}" + await _create_dlq_document(event_id=event_id, status=DLQMessageStatus.SCHEDULED) + await repository.mark_message_retried(event_id) + + # Get stats after - verify the count incremented by exactly 1 + stats_after = await repository.get_dlq_stats() + count_after = stats_after.by_status.get(DLQMessageStatus.RETRIED.value, 0) + assert count_after == count_before + 1 + + +@pytest.mark.asyncio +async def test_dlq_retry_already_retried_message(scope: AsyncContainer) -> None: + """Test that retrying an already RETRIED message still succeeds at repository level. + + Note: The DLQManager.retry_message_manually guards against this, but the + repository method doesn't - it's a low-level operation that always succeeds. + """ + repository: DLQRepository = await scope.get(DLQRepository) + + # Create an already RETRIED document + event_id = f"dlq-already-retried-{uuid.uuid4().hex[:8]}" + await _create_dlq_document(event_id=event_id, status=DLQMessageStatus.RETRIED) + + # Repository method still succeeds (no guard at this level) + result = await repository.mark_message_retried(event_id) + assert result is True + + # Status remains RETRIED + doc = await DLQMessageDocument.find_one({"event_id": event_id}) + assert doc is not None + assert doc.status == DLQMessageStatus.RETRIED + + +@pytest.mark.asyncio +async def test_dlq_retry_discarded_message(scope: AsyncContainer) -> None: + """Test that retrying a DISCARDED message still succeeds at repository level. + + Note: The DLQManager.retry_message_manually guards against this and returns False, + but the repository method is a low-level operation that doesn't validate transitions. + """ + repository: DLQRepository = await scope.get(DLQRepository) + + # Create a DISCARDED document + event_id = f"dlq-discarded-retry-{uuid.uuid4().hex[:8]}" + await _create_dlq_document(event_id=event_id, status=DLQMessageStatus.DISCARDED) + + # Repository method succeeds (transitions status back to RETRIED) + result = await repository.mark_message_retried(event_id) + assert result is True + + # Status is now RETRIED (repository doesn't guard transitions) + doc = await DLQMessageDocument.find_one({"event_id": event_id}) + assert doc is not None + assert doc.status == DLQMessageStatus.RETRIED + + +@pytest.mark.asyncio +async def test_dlq_discard_already_discarded_message(scope: AsyncContainer) -> None: + """Test that discarding an already DISCARDED message updates the reason.""" + repository: DLQRepository = await scope.get(DLQRepository) + + # Create an already DISCARDED document + event_id = f"dlq-already-discarded-{uuid.uuid4().hex[:8]}" + await _create_dlq_document(event_id=event_id, status=DLQMessageStatus.DISCARDED) + + # Discard again with a new reason + new_reason = "updated_discard_reason" + result = await repository.mark_message_discarded(event_id, new_reason) + assert result is True + + # Reason is updated + doc = await DLQMessageDocument.find_one({"event_id": event_id}) + assert doc is not None + assert doc.status == DLQMessageStatus.DISCARDED + assert doc.discard_reason == new_reason + + +@pytest.mark.asyncio +async def test_dlq_discard_retried_message(scope: AsyncContainer) -> None: + """Test that discarding a RETRIED message transitions to DISCARDED.""" + repository: DLQRepository = await scope.get(DLQRepository) + + # Create a RETRIED document + event_id = f"dlq-retried-discard-{uuid.uuid4().hex[:8]}" + await _create_dlq_document(event_id=event_id, status=DLQMessageStatus.RETRIED) + + # Discard it + reason = "manual_cleanup" + result = await repository.mark_message_discarded(event_id, reason) + assert result is True + + # Status is now DISCARDED + doc = await DLQMessageDocument.find_one({"event_id": event_id}) + assert doc is not None + assert doc.status == DLQMessageStatus.DISCARDED + assert doc.discard_reason == reason diff --git a/backend/tests/integration/dlq/test_dlq_retry_immediate.py b/backend/tests/integration/dlq/test_dlq_retry_immediate.py deleted file mode 100644 index 5c435b92..00000000 --- a/backend/tests/integration/dlq/test_dlq_retry_immediate.py +++ /dev/null @@ -1,66 +0,0 @@ -import json -import logging -import uuid -from datetime import datetime, timezone - -import pytest -from app.db.docs import DLQMessageDocument -from app.dlq.manager import create_dlq_manager -from app.dlq.models import DLQMessageStatus, RetryPolicy, RetryStrategy -from app.domain.enums.kafka import KafkaTopic -from app.events.schema.schema_registry import create_schema_registry_manager -from confluent_kafka import Producer - -from tests.helpers import make_execution_requested_event -from tests.helpers.eventually import eventually - -# xdist_group: DLQ tests share a Kafka consumer group. When running in parallel, -# different workers' managers consume each other's messages and apply wrong policies. -# Serial execution ensures each test's manager processes only its own messages. -pytestmark = [pytest.mark.integration, pytest.mark.kafka, pytest.mark.mongodb, pytest.mark.xdist_group("dlq")] - -_test_logger = logging.getLogger("test.dlq.retry_immediate") - - -@pytest.mark.asyncio -async def test_dlq_manager_immediate_retry_updates_doc(db, test_settings) -> None: # type: ignore[valid-type] - schema_registry = create_schema_registry_manager(test_settings, _test_logger) - manager = create_dlq_manager(settings=test_settings, schema_registry=schema_registry, logger=_test_logger) - # Use prefix from test_settings to match what the manager uses - prefix = test_settings.KAFKA_TOPIC_PREFIX - topic = f"{prefix}{str(KafkaTopic.EXECUTION_EVENTS)}" - manager.set_retry_policy( - topic, - RetryPolicy(topic=topic, strategy=RetryStrategy.IMMEDIATE, max_retries=1, base_delay_seconds=0.1), - ) - - # Use unique execution_id to avoid conflicts with parallel test workers - ev = make_execution_requested_event(execution_id=f"exec-dlq-retry-{uuid.uuid4().hex[:8]}") - - payload = { - "event": ev.to_dict(), - "original_topic": topic, - "error": "boom", - "retry_count": 0, - "failed_at": datetime.now(timezone.utc).isoformat(), - "producer_id": "tests", - } - - prod = Producer({"bootstrap.servers": "localhost:9092"}) - prod.produce( - topic=f"{prefix}{str(KafkaTopic.DEAD_LETTER_QUEUE)}", - key=ev.event_id.encode(), - value=json.dumps(payload).encode(), - ) - prod.flush(5) - - async with manager: - - async def _retried() -> None: - doc = await DLQMessageDocument.find_one({"event_id": ev.event_id}) - assert doc is not None - assert doc.status == DLQMessageStatus.RETRIED - assert doc.retry_count == 1 - assert doc.retried_at is not None - - await eventually(_retried, timeout=10.0, interval=0.2) diff --git a/backend/tests/integration/events/test_admin_utils.py b/backend/tests/integration/events/test_admin_utils.py index 7ab34509..db03ac86 100644 --- a/backend/tests/integration/events/test_admin_utils.py +++ b/backend/tests/integration/events/test_admin_utils.py @@ -1,18 +1,17 @@ import logging -import os import pytest from app.events.admin_utils import AdminUtils +from app.settings import Settings _test_logger = logging.getLogger("test.events.admin_utils") @pytest.mark.kafka @pytest.mark.asyncio -async def test_admin_utils_real_topic_checks() -> None: - prefix = os.environ.get("KAFKA_TOPIC_PREFIX", "test.") - topic = f"{prefix}adminutils.{os.environ.get('PYTEST_SESSION_ID', 'sid')}" - au = AdminUtils(logger=_test_logger) +async def test_admin_utils_real_topic_checks(test_settings: Settings) -> None: + topic = f"{test_settings.KAFKA_TOPIC_PREFIX}adminutils.{test_settings.KAFKA_GROUP_SUFFIX}" + au = AdminUtils(settings=test_settings, logger=_test_logger) # Ensure topic exists (idempotent) res = await au.ensure_topics_exist([(topic, 1)]) diff --git a/backend/tests/integration/events/test_consume_roundtrip.py b/backend/tests/integration/events/test_consume_roundtrip.py index b2ceb48b..830a950b 100644 --- a/backend/tests/integration/events/test_consume_roundtrip.py +++ b/backend/tests/integration/events/test_consume_roundtrip.py @@ -3,23 +3,27 @@ import uuid import pytest +from dishka import AsyncContainer + from app.domain.enums.events import EventType from app.domain.enums.kafka import KafkaTopic from app.events.core import UnifiedConsumer, UnifiedProducer from app.events.core.dispatcher import EventDispatcher from app.events.core.types import ConsumerConfig from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas +from app.infrastructure.kafka.events.base import BaseEvent from app.settings import Settings - from tests.helpers import make_execution_requested_event -pytestmark = [pytest.mark.integration, pytest.mark.kafka] +# xdist_group: Kafka consumer creation can crash librdkafka when multiple workers +# instantiate Consumer() objects simultaneously. Serial execution prevents this. +pytestmark = [pytest.mark.integration, pytest.mark.kafka, pytest.mark.xdist_group("kafka_consumers")] _test_logger = logging.getLogger("test.events.consume_roundtrip") @pytest.mark.asyncio -async def test_produce_consume_roundtrip(scope) -> None: # type: ignore[valid-type] +async def test_produce_consume_roundtrip(scope: AsyncContainer) -> None: # Ensure schemas are registered registry: SchemaRegistryManager = await scope.get(SchemaRegistryManager) settings: Settings = await scope.get(Settings) @@ -33,7 +37,7 @@ async def test_produce_consume_roundtrip(scope) -> None: # type: ignore[valid-t received = asyncio.Event() @dispatcher.register(EventType.EXECUTION_REQUESTED) - async def _handle(_event) -> None: # noqa: ANN001 + async def _handle(_event: BaseEvent) -> None: received.set() group_id = f"test-consumer.{uuid.uuid4().hex[:6]}" @@ -51,7 +55,7 @@ async def _handle(_event) -> None: # noqa: ANN001 settings=settings, logger=_test_logger, ) - await consumer.start([str(KafkaTopic.EXECUTION_EVENTS)]) + await consumer.start([KafkaTopic.EXECUTION_EVENTS]) try: # Produce a request event diff --git a/backend/tests/integration/events/test_consumer_group_monitor.py b/backend/tests/integration/events/test_consumer_group_monitor.py index cfab3017..11d535dd 100644 --- a/backend/tests/integration/events/test_consumer_group_monitor.py +++ b/backend/tests/integration/events/test_consumer_group_monitor.py @@ -2,6 +2,7 @@ import pytest from app.events.consumer_group_monitor import ConsumerGroupHealth, NativeConsumerGroupMonitor +from app.settings import Settings _test_logger = logging.getLogger("test.events.consumer_group_monitor") @@ -9,8 +10,8 @@ @pytest.mark.integration @pytest.mark.kafka @pytest.mark.asyncio -async def test_list_groups_and_error_status(): - mon = NativeConsumerGroupMonitor(logger=_test_logger) +async def test_list_groups_and_error_status(test_settings: Settings) -> None: + mon = NativeConsumerGroupMonitor(settings=test_settings, logger=_test_logger) groups = await mon.list_consumer_groups() assert isinstance(groups, list) diff --git a/backend/tests/integration/events/test_consumer_group_monitor_real.py b/backend/tests/integration/events/test_consumer_group_monitor_real.py index a31ab4bf..233d0239 100644 --- a/backend/tests/integration/events/test_consumer_group_monitor_real.py +++ b/backend/tests/integration/events/test_consumer_group_monitor_real.py @@ -7,6 +7,7 @@ ConsumerGroupStatus, NativeConsumerGroupMonitor, ) +from app.settings import Settings pytestmark = [pytest.mark.integration, pytest.mark.kafka] @@ -14,8 +15,8 @@ @pytest.mark.asyncio -async def test_consumer_group_status_error_path_and_summary(): - monitor = NativeConsumerGroupMonitor(bootstrap_servers="localhost:9092", logger=_test_logger) +async def test_consumer_group_status_error_path_and_summary(test_settings: Settings) -> None: + monitor = NativeConsumerGroupMonitor(settings=test_settings, logger=_test_logger) # Non-existent group triggers error-handling path and returns minimal status gid = f"does-not-exist-{uuid4().hex[:8]}" status = await monitor.get_consumer_group_status(gid, timeout=5.0, include_lag=False) @@ -27,8 +28,8 @@ async def test_consumer_group_status_error_path_and_summary(): assert summary["group_id"] == gid and summary["health"] == ConsumerGroupHealth.UNHEALTHY.value -def test_assess_group_health_branches(): - m = NativeConsumerGroupMonitor(logger=_test_logger) +def test_assess_group_health_branches(test_settings: Settings) -> None: + m = NativeConsumerGroupMonitor(settings=test_settings, logger=_test_logger) # Error state s = ConsumerGroupStatus( group_id="g", @@ -81,8 +82,8 @@ def test_assess_group_health_branches(): @pytest.mark.asyncio -async def test_multiple_group_status_mixed_errors(): - m = NativeConsumerGroupMonitor(bootstrap_servers="localhost:9092", logger=_test_logger) +async def test_multiple_group_status_mixed_errors(test_settings: Settings) -> None: + m = NativeConsumerGroupMonitor(settings=test_settings, logger=_test_logger) gids = [f"none-{uuid4().hex[:6]}", f"none-{uuid4().hex[:6]}"] res = await m.get_multiple_group_status(gids, timeout=5.0, include_lag=False) assert set(res.keys()) == set(gids) diff --git a/backend/tests/integration/events/test_consumer_lifecycle.py b/backend/tests/integration/events/test_consumer_lifecycle.py index eb63b770..f1628142 100644 --- a/backend/tests/integration/events/test_consumer_lifecycle.py +++ b/backend/tests/integration/events/test_consumer_lifecycle.py @@ -2,21 +2,28 @@ from uuid import uuid4 import pytest +from dishka import AsyncContainer + from app.domain.enums.kafka import KafkaTopic from app.events.core import ConsumerConfig, EventDispatcher, UnifiedConsumer from app.events.schema.schema_registry import SchemaRegistryManager from app.settings import Settings -pytestmark = [pytest.mark.integration, pytest.mark.kafka] +# xdist_group: Kafka consumer creation can crash librdkafka when multiple workers +# instantiate Consumer() objects simultaneously. Serial execution prevents this. +pytestmark = [pytest.mark.integration, pytest.mark.kafka, pytest.mark.xdist_group("kafka_consumers")] _test_logger = logging.getLogger("test.events.consumer_lifecycle") @pytest.mark.asyncio -async def test_consumer_start_status_seek_and_stop(scope) -> None: # type: ignore[valid-type] +async def test_consumer_start_status_seek_and_stop(scope: AsyncContainer) -> None: registry: SchemaRegistryManager = await scope.get(SchemaRegistryManager) settings: Settings = await scope.get(Settings) - cfg = ConsumerConfig(bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS, group_id=f"test-consumer-{uuid4().hex[:6]}") + cfg = ConsumerConfig( + bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS, + group_id=f"test-consumer-{uuid4().hex[:6]}", + ) disp = EventDispatcher(logger=_test_logger) c = UnifiedConsumer( cfg, diff --git a/backend/tests/integration/events/test_dlq_handler.py b/backend/tests/integration/events/test_dlq_handler.py index 5659529b..eb930b17 100644 --- a/backend/tests/integration/events/test_dlq_handler.py +++ b/backend/tests/integration/events/test_dlq_handler.py @@ -1,7 +1,10 @@ import logging import pytest +from dishka import AsyncContainer + from app.events.core import UnifiedProducer, create_dlq_error_handler, create_immediate_dlq_handler +from app.infrastructure.kafka.events.base import BaseEvent from app.infrastructure.kafka.events.metadata import AvroEventMetadata from app.infrastructure.kafka.events.saga import SagaStartedEvent @@ -11,11 +14,11 @@ @pytest.mark.asyncio -async def test_dlq_handler_with_retries(scope, monkeypatch): # type: ignore[valid-type] +async def test_dlq_handler_with_retries(scope: AsyncContainer, monkeypatch: pytest.MonkeyPatch) -> None: p: UnifiedProducer = await scope.get(UnifiedProducer) calls: list[tuple[str | None, str, str, int]] = [] - async def _record_send_to_dlq(original_event, original_topic, error, retry_count): # noqa: ANN001 + async def _record_send_to_dlq(original_event: BaseEvent, original_topic: str, error: Exception, retry_count: int) -> None: calls.append((original_event.event_id, original_topic, str(error), retry_count)) monkeypatch.setattr(p, "send_to_dlq", _record_send_to_dlq) @@ -38,11 +41,11 @@ async def _record_send_to_dlq(original_event, original_topic, error, retry_count @pytest.mark.asyncio -async def test_immediate_dlq_handler(scope, monkeypatch): # type: ignore[valid-type] +async def test_immediate_dlq_handler(scope: AsyncContainer, monkeypatch: pytest.MonkeyPatch) -> None: p: UnifiedProducer = await scope.get(UnifiedProducer) calls: list[tuple[str | None, str, str, int]] = [] - async def _record_send_to_dlq(original_event, original_topic, error, retry_count): # noqa: ANN001 + async def _record_send_to_dlq(original_event: BaseEvent, original_topic: str, error: Exception, retry_count: int) -> None: calls.append((original_event.event_id, original_topic, str(error), retry_count)) monkeypatch.setattr(p, "send_to_dlq", _record_send_to_dlq) diff --git a/backend/tests/integration/events/test_event_dispatcher.py b/backend/tests/integration/events/test_event_dispatcher.py index aa65d181..0195acbf 100644 --- a/backend/tests/integration/events/test_event_dispatcher.py +++ b/backend/tests/integration/events/test_event_dispatcher.py @@ -3,23 +3,28 @@ import uuid import pytest +from dishka import AsyncContainer + from app.domain.enums.events import EventType from app.domain.enums.kafka import KafkaTopic from app.events.core import UnifiedConsumer, UnifiedProducer from app.events.core.dispatcher import EventDispatcher from app.events.core.types import ConsumerConfig from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas +from app.infrastructure.kafka.events.base import BaseEvent from app.settings import Settings from tests.helpers import make_execution_requested_event -pytestmark = [pytest.mark.integration, pytest.mark.kafka] +# xdist_group: Kafka consumer creation can crash librdkafka when multiple workers +# instantiate Consumer() objects simultaneously. Serial execution prevents this. +pytestmark = [pytest.mark.integration, pytest.mark.kafka, pytest.mark.xdist_group("kafka_consumers")] _test_logger = logging.getLogger("test.events.event_dispatcher") @pytest.mark.asyncio -async def test_dispatcher_with_multiple_handlers(scope) -> None: # type: ignore[valid-type] +async def test_dispatcher_with_multiple_handlers(scope: AsyncContainer) -> None: # Ensure schema registry is ready registry: SchemaRegistryManager = await scope.get(SchemaRegistryManager) settings: Settings = await scope.get(Settings) @@ -31,11 +36,11 @@ async def test_dispatcher_with_multiple_handlers(scope) -> None: # type: ignore h2_called = asyncio.Event() @dispatcher.register(EventType.EXECUTION_REQUESTED) - async def h1(_e) -> None: # noqa: ANN001 + async def h1(_e: BaseEvent) -> None: h1_called.set() @dispatcher.register(EventType.EXECUTION_REQUESTED) - async def h2(_e) -> None: # noqa: ANN001 + async def h2(_e: BaseEvent) -> None: h2_called.set() # Real consumer against execution-events @@ -52,7 +57,7 @@ async def h2(_e) -> None: # noqa: ANN001 settings=settings, logger=_test_logger, ) - await consumer.start([str(KafkaTopic.EXECUTION_EVENTS)]) + await consumer.start([KafkaTopic.EXECUTION_EVENTS]) # Produce a request event via DI producer: UnifiedProducer = await scope.get(UnifiedProducer) diff --git a/backend/tests/integration/events/test_event_store.py b/backend/tests/integration/events/test_event_store.py new file mode 100644 index 00000000..45b18304 --- /dev/null +++ b/backend/tests/integration/events/test_event_store.py @@ -0,0 +1,153 @@ +import logging +import uuid +from datetime import datetime, timedelta, timezone + +import pytest +from dishka import AsyncContainer + +from app.db.docs import EventDocument +from app.domain.enums.events import EventType +from app.events.event_store import EventStore +from app.infrastructure.kafka.events.base import BaseEvent +from tests.helpers import make_execution_requested_event + +pytestmark = [pytest.mark.integration, pytest.mark.mongodb] + +_test_logger = logging.getLogger("test.events.event_store") + + +@pytest.mark.asyncio +async def test_event_store_stores_single_event(scope: AsyncContainer) -> None: + """Test that EventStore.store_event() persists an event to MongoDB.""" + store: EventStore = await scope.get(EventStore) + + # Create a unique event + execution_id = f"exec-{uuid.uuid4().hex[:8]}" + event = make_execution_requested_event(execution_id=execution_id) + + # Store the event + result = await store.store_event(event) + assert result is True + + # Verify it's in MongoDB + doc = await EventDocument.find_one({"event_id": event.event_id}) + assert doc is not None + assert doc.event_id == event.event_id + assert doc.event_type == EventType.EXECUTION_REQUESTED + assert doc.aggregate_id == execution_id + assert doc.stored_at is not None + assert doc.ttl_expires_at is not None + # TTL should be ~90 days in the future + assert doc.ttl_expires_at > datetime.now(timezone.utc) + timedelta(days=89) + + +@pytest.mark.asyncio +async def test_event_store_stores_batch(scope: AsyncContainer) -> None: + """Test that EventStore.store_batch() persists multiple events.""" + store: EventStore = await scope.get(EventStore) + + # Create multiple unique events + events: list[BaseEvent] = [ + make_execution_requested_event(execution_id=f"exec-batch-{uuid.uuid4().hex[:8]}") + for _ in range(5) + ] + + # Store the batch + results = await store.store_batch(events) + + assert results["total"] == 5 + assert results["stored"] == 5 + assert results["duplicates"] == 0 + assert results["failed"] == 0 + + # Verify all events are in MongoDB + for event in events: + doc = await EventDocument.find_one({"event_id": event.event_id}) + assert doc is not None + assert doc.event_type == EventType.EXECUTION_REQUESTED + + +@pytest.mark.asyncio +async def test_event_store_handles_duplicates(scope: AsyncContainer) -> None: + """Test that EventStore handles duplicate event IDs gracefully.""" + store: EventStore = await scope.get(EventStore) + + # Create an event + event = make_execution_requested_event(execution_id=f"exec-dup-{uuid.uuid4().hex[:8]}") + + # Store it twice + result1 = await store.store_event(event) + result2 = await store.store_event(event) + + # Both should succeed (second is a no-op due to duplicate handling) + assert result1 is True + assert result2 is True + + # Only one document should exist + count = await EventDocument.find({"event_id": event.event_id}).count() + assert count == 1 + + +@pytest.mark.asyncio +async def test_event_store_batch_handles_duplicates(scope: AsyncContainer) -> None: + """Test that store_batch handles duplicates within the batch.""" + store: EventStore = await scope.get(EventStore) + + # Create an event and store it first + event = make_execution_requested_event(execution_id=f"exec-batch-dup-{uuid.uuid4().hex[:8]}") + await store.store_event(event) + + # Create a batch with one new event and one duplicate + new_event = make_execution_requested_event(execution_id=f"exec-batch-new-{uuid.uuid4().hex[:8]}") + batch: list[BaseEvent] = [new_event, event] # event is already stored + + results = await store.store_batch(batch) + + assert results["total"] == 2 + assert results["stored"] == 1 # Only the new one + assert results["duplicates"] == 1 # The duplicate + + +@pytest.mark.asyncio +async def test_event_store_retrieves_by_id(scope: AsyncContainer) -> None: + """Test that EventStore.get_event() retrieves a stored event.""" + store: EventStore = await scope.get(EventStore) + + # Create and store an event + execution_id = f"exec-get-{uuid.uuid4().hex[:8]}" + event = make_execution_requested_event(execution_id=execution_id, script="print('test')") + await store.store_event(event) + + # Retrieve it + retrieved = await store.get_event(event.event_id) + + assert retrieved is not None + assert retrieved.event_id == event.event_id + assert retrieved.event_type == EventType.EXECUTION_REQUESTED + + +@pytest.mark.asyncio +async def test_event_store_retrieves_by_type(scope: AsyncContainer) -> None: + """Test that EventStore.get_events_by_type() works correctly.""" + store: EventStore = await scope.get(EventStore) + + # Store a few events + unique_prefix = uuid.uuid4().hex[:8] + events: list[BaseEvent] = [ + make_execution_requested_event(execution_id=f"exec-type-{unique_prefix}-{i}") + for i in range(3) + ] + await store.store_batch(events) + + # Query by type + retrieved = await store.get_events_by_type( + EventType.EXECUTION_REQUESTED, + limit=100, + ) + + # Should find at least our 3 events + assert len(retrieved) >= 3 + + # All should be EXECUTION_REQUESTED + for ev in retrieved: + assert ev.event_type == EventType.EXECUTION_REQUESTED diff --git a/backend/tests/integration/events/test_event_store_consumer.py b/backend/tests/integration/events/test_event_store_consumer.py deleted file mode 100644 index ec35a99b..00000000 --- a/backend/tests/integration/events/test_event_store_consumer.py +++ /dev/null @@ -1,64 +0,0 @@ -import logging -import uuid - -import pytest -from app.core.database_context import Database -from app.domain.enums.kafka import KafkaTopic -from app.events.core import UnifiedProducer -from app.events.event_store import EventStore -from app.events.event_store_consumer import EventStoreConsumer, create_event_store_consumer -from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas -from app.domain.enums.auth import LoginMethod -from app.infrastructure.kafka.events.metadata import AvroEventMetadata -from app.infrastructure.kafka.events.user import UserLoggedInEvent -from app.settings import Settings - -pytestmark = [pytest.mark.integration, pytest.mark.kafka, pytest.mark.mongodb] - -_test_logger = logging.getLogger("test.events.event_store_consumer") - - -@pytest.mark.asyncio -async def test_event_store_consumer_stores_events(scope) -> None: # type: ignore[valid-type] - # Ensure schemas - registry: SchemaRegistryManager = await scope.get(SchemaRegistryManager) - await initialize_event_schemas(registry) - - # Resolve DI - producer: UnifiedProducer = await scope.get(UnifiedProducer) - db: Database = await scope.get(Database) - store: EventStore = await scope.get(EventStore) - settings: Settings = await scope.get(Settings) - - # Build an event - ev = UserLoggedInEvent( - user_id=f"u-{uuid.uuid4().hex[:6]}", - login_method=LoginMethod.PASSWORD, - metadata=AvroEventMetadata(service_name="tests", service_version="1.0.0"), - ) - - # Create a tuned consumer (fast batch timeout) limited to user-events - consumer: EventStoreConsumer = create_event_store_consumer( - event_store=store, - topics=[KafkaTopic.USER_EVENTS], - schema_registry_manager=registry, - settings=settings, - logger=_test_logger, - producer=producer, - batch_size=10, - batch_timeout_seconds=0.5, - ) - - # Start the consumer and publish - async with consumer: - await producer.produce(ev, key=ev.metadata.user_id or "u") - - # Wait until the event is persisted in Mongo - coll = db.get_collection("events") - from tests.helpers.eventually import eventually - - async def _exists() -> None: - doc = await coll.find_one({"event_id": ev.event_id}) - assert doc is not None - - await eventually(_exists, timeout=12.0, interval=0.2) diff --git a/backend/tests/integration/events/test_producer_roundtrip.py b/backend/tests/integration/events/test_producer_roundtrip.py index c35364b9..81ef5865 100644 --- a/backend/tests/integration/events/test_producer_roundtrip.py +++ b/backend/tests/integration/events/test_producer_roundtrip.py @@ -1,10 +1,11 @@ -import json import logging from uuid import uuid4 import pytest from app.events.core import ProducerConfig, UnifiedProducer from app.events.schema.schema_registry import SchemaRegistryManager +from app.settings import Settings +from dishka import AsyncContainer from tests.helpers import make_execution_requested_event @@ -14,9 +15,16 @@ @pytest.mark.asyncio -async def test_unified_producer_start_produce_send_to_dlq_stop(scope): # type: ignore[valid-type] +async def test_unified_producer_start_produce_send_to_dlq_stop( + scope: AsyncContainer, test_settings: Settings +) -> None: schema: SchemaRegistryManager = await scope.get(SchemaRegistryManager) - prod = UnifiedProducer(ProducerConfig(bootstrap_servers="localhost:9092"), schema, logger=_test_logger) + prod = UnifiedProducer( + ProducerConfig(bootstrap_servers=test_settings.KAFKA_BOOTSTRAP_SERVERS), + schema, + logger=_test_logger, + settings=test_settings, + ) async with prod: ev = make_execution_requested_event(execution_id=f"exec-{uuid4().hex[:8]}") @@ -27,18 +35,3 @@ async def test_unified_producer_start_produce_send_to_dlq_stop(scope): # type: st = prod.get_status() assert st["running"] is True and st["state"] == "running" - - -def test_producer_handle_stats_path(): - # Directly run stats parsing to cover branch logic; avoid relying on timing - from app.events.core.producer import ProducerMetrics - from app.events.core.producer import UnifiedProducer as UP - - m = ProducerMetrics() - p = object.__new__(UP) # bypass __init__ safely for method call - # Inject required attributes - p._metrics = m # type: ignore[attr-defined] - p._stats_callback = None # type: ignore[attr-defined] - payload = json.dumps({"msg_cnt": 1, "topics": {"t": {"partitions": {"0": {"msgq_cnt": 2, "rtt": {"avg": 5}}}}}}) - UP._handle_stats(p, payload) # type: ignore[misc] - assert m.queue_size == 1 and m.avg_latency_ms > 0 diff --git a/backend/tests/integration/events/test_schema_registry_roundtrip.py b/backend/tests/integration/events/test_schema_registry_roundtrip.py index 4791c16f..ffef9953 100644 --- a/backend/tests/integration/events/test_schema_registry_roundtrip.py +++ b/backend/tests/integration/events/test_schema_registry_roundtrip.py @@ -1,9 +1,10 @@ import logging import pytest +from dishka import AsyncContainer + from app.events.schema.schema_registry import MAGIC_BYTE, SchemaRegistryManager from app.settings import Settings - from tests.helpers import make_execution_requested_event pytestmark = [pytest.mark.integration] @@ -12,14 +13,14 @@ @pytest.mark.asyncio -async def test_schema_registry_serialize_deserialize_roundtrip(scope): # type: ignore[valid-type] +async def test_schema_registry_serialize_deserialize_roundtrip(scope: AsyncContainer) -> None: reg: SchemaRegistryManager = await scope.get(SchemaRegistryManager) # Schema registration happens lazily in serialize_event ev = make_execution_requested_event(execution_id="e-rt") data = reg.serialize_event(ev) assert data.startswith(MAGIC_BYTE) back = reg.deserialize_event(data, topic=str(ev.topic)) - assert back.event_id == ev.event_id and back.execution_id == ev.execution_id + assert back.event_id == ev.event_id and getattr(back, "execution_id", None) == ev.execution_id # initialize_schemas should be a no-op if already initialized; call to exercise path await reg.initialize_schemas() diff --git a/backend/tests/integration/idempotency/test_consumer_idempotent.py b/backend/tests/integration/idempotency/test_consumer_idempotent.py index bdcc04d9..f1554b50 100644 --- a/backend/tests/integration/idempotency/test_consumer_idempotent.py +++ b/backend/tests/integration/idempotency/test_consumer_idempotent.py @@ -3,25 +3,34 @@ import uuid import pytest +from dishka import AsyncContainer from app.domain.enums.events import EventType from app.domain.enums.kafka import KafkaTopic from app.events.core import ConsumerConfig, EventDispatcher, UnifiedConsumer, UnifiedProducer from app.events.core.dispatcher import EventDispatcher as Disp from app.events.schema.schema_registry import SchemaRegistryManager -from tests.helpers import make_execution_requested_event +from app.infrastructure.kafka.events.base import BaseEvent from app.services.idempotency.idempotency_manager import IdempotencyManager from app.services.idempotency.middleware import IdempotentConsumerWrapper from app.settings import Settings +from tests.helpers import make_execution_requested_event from tests.helpers.eventually import eventually -pytestmark = [pytest.mark.integration, pytest.mark.kafka, pytest.mark.redis] +# xdist_group: Kafka consumer creation can crash librdkafka when multiple workers +# instantiate Consumer() objects simultaneously. Serial execution prevents this. +pytestmark = [ + pytest.mark.integration, + pytest.mark.kafka, + pytest.mark.redis, + pytest.mark.xdist_group("kafka_consumers"), +] _test_logger = logging.getLogger("test.idempotency.consumer_idempotent") @pytest.mark.asyncio -async def test_consumer_idempotent_wrapper_blocks_duplicates(scope) -> None: # type: ignore[valid-type] +async def test_consumer_idempotent_wrapper_blocks_duplicates(scope: AsyncContainer) -> None: producer: UnifiedProducer = await scope.get(UnifiedProducer) idm: IdempotencyManager = await scope.get(IdempotencyManager) registry: SchemaRegistryManager = await scope.get(SchemaRegistryManager) @@ -32,7 +41,7 @@ async def test_consumer_idempotent_wrapper_blocks_duplicates(scope) -> None: # seen = {"n": 0} @disp.register(EventType.EXECUTION_REQUESTED) - async def handle(_ev): # noqa: ANN001 + async def handle(_ev: BaseEvent) -> None: seen["n"] += 1 # Real consumer with idempotent wrapper @@ -59,6 +68,8 @@ async def handle(_ev): # noqa: ANN001 ) await wrapper.start([KafkaTopic.EXECUTION_EVENTS]) + # Allow time for consumer to join group and get partition assignments + await asyncio.sleep(2) try: # Produce the same event twice (same event_id) execution_id = f"e-{uuid.uuid4().hex[:8]}" @@ -66,7 +77,7 @@ async def handle(_ev): # noqa: ANN001 await producer.produce(ev, key=execution_id) await producer.produce(ev, key=execution_id) - async def _one(): + async def _one() -> None: assert seen["n"] >= 1 await eventually(_one, timeout=10.0, interval=0.2) diff --git a/backend/tests/integration/idempotency/test_decorator_idempotent.py b/backend/tests/integration/idempotency/test_decorator_idempotent.py index 3f4d73ce..305a3500 100644 --- a/backend/tests/integration/idempotency/test_decorator_idempotent.py +++ b/backend/tests/integration/idempotency/test_decorator_idempotent.py @@ -1,9 +1,12 @@ import logging + import pytest +from dishka import AsyncContainer -from tests.helpers import make_execution_requested_event +from app.infrastructure.kafka.events.base import BaseEvent from app.services.idempotency.idempotency_manager import IdempotencyManager from app.services.idempotency.middleware import idempotent_handler +from tests.helpers import make_execution_requested_event _test_logger = logging.getLogger("test.idempotency.decorator_idempotent") @@ -12,13 +15,13 @@ @pytest.mark.asyncio -async def test_decorator_blocks_duplicate_event(scope) -> None: # type: ignore[valid-type] +async def test_decorator_blocks_duplicate_event(scope: AsyncContainer) -> None: idm: IdempotencyManager = await scope.get(IdempotencyManager) calls = {"n": 0} @idempotent_handler(idempotency_manager=idm, key_strategy="event_based", logger=_test_logger) - async def h(ev): # noqa: ANN001 + async def h(ev: BaseEvent) -> None: calls["n"] += 1 ev = make_execution_requested_event(execution_id="exec-deco-1") @@ -29,16 +32,16 @@ async def h(ev): # noqa: ANN001 @pytest.mark.asyncio -async def test_decorator_custom_key_blocks(scope) -> None: # type: ignore[valid-type] +async def test_decorator_custom_key_blocks(scope: AsyncContainer) -> None: idm: IdempotencyManager = await scope.get(IdempotencyManager) calls = {"n": 0} - def fixed_key(_ev): # noqa: ANN001 + def fixed_key(_ev: BaseEvent) -> str: return "fixed-key" @idempotent_handler(idempotency_manager=idm, key_strategy="custom", custom_key_func=fixed_key, logger=_test_logger) - async def h(ev): # noqa: ANN001 + async def h(ev: BaseEvent) -> None: calls["n"] += 1 e1 = make_execution_requested_event(execution_id="exec-deco-2a") diff --git a/backend/tests/integration/idempotency/test_idempotency.py b/backend/tests/integration/idempotency/test_idempotency.py index 6620ef6f..b93cdd79 100644 --- a/backend/tests/integration/idempotency/test_idempotency.py +++ b/backend/tests/integration/idempotency/test_idempotency.py @@ -2,16 +2,20 @@ import json import logging import uuid +from collections.abc import AsyncGenerator from datetime import datetime, timedelta, timezone +from typing import Any + import pytest +import redis.asyncio as redis -from app.domain.idempotency import IdempotencyRecord, IdempotencyStatus, IdempotencyStats +from app.domain.idempotency import IdempotencyRecord, IdempotencyStatus from app.infrastructure.kafka.events.base import BaseEvent -from tests.helpers import make_execution_requested_event +from app.infrastructure.kafka.events.execution import ExecutionRequestedEvent from app.services.idempotency.idempotency_manager import IdempotencyConfig, IdempotencyManager from app.services.idempotency.middleware import IdempotentEventHandler, idempotent_handler from app.services.idempotency.redis_repository import RedisIdempotencyRepository - +from tests.helpers import make_execution_requested_event pytestmark = [pytest.mark.integration, pytest.mark.redis] @@ -23,7 +27,7 @@ class TestIdempotencyManager: """IdempotencyManager backed by real Redis repository (DI-provided client).""" @pytest.fixture - async def manager(self, redis_client): # type: ignore[valid-type] + async def manager(self, redis_client: redis.Redis) -> AsyncGenerator[IdempotencyManager, None]: prefix = f"idemp_ut:{uuid.uuid4().hex[:6]}" cfg = IdempotencyConfig( key_prefix=prefix, @@ -42,7 +46,7 @@ async def manager(self, redis_client): # type: ignore[valid-type] await m.close() @pytest.mark.asyncio - async def test_complete_flow_new_event(self, manager): + async def test_complete_flow_new_event(self, manager: IdempotencyManager) -> None: """Test the complete flow for a new event""" real_event = make_execution_requested_event(execution_id="exec-123") # Check and reserve @@ -54,7 +58,7 @@ async def test_complete_flow_new_event(self, manager): assert result.key.startswith(f"{manager.config.key_prefix}:") # Verify it's in the repository - record = await manager._repo.find_by_key(result.key) # type: ignore[attr-defined] + record = await manager._repo.find_by_key(result.key) assert record is not None assert record.status == IdempotencyStatus.PROCESSING @@ -63,13 +67,14 @@ async def test_complete_flow_new_event(self, manager): assert success is True # Verify status updated - record = await manager._repo.find_by_key(result.key) # type: ignore[attr-defined] + record = await manager._repo.find_by_key(result.key) + assert record is not None assert record.status == IdempotencyStatus.COMPLETED assert record.completed_at is not None assert record.processing_duration_ms is not None @pytest.mark.asyncio - async def test_duplicate_detection(self, manager): + async def test_duplicate_detection(self, manager: IdempotencyManager) -> None: """Test that duplicates are properly detected""" real_event = make_execution_requested_event(execution_id="exec-dupe-1") # First request @@ -85,7 +90,7 @@ async def test_duplicate_detection(self, manager): assert result2.status == IdempotencyStatus.COMPLETED @pytest.mark.asyncio - async def test_concurrent_requests_race_condition(self, manager): + async def test_concurrent_requests_race_condition(self, manager: IdempotencyManager) -> None: """Test handling of concurrent requests for the same event""" real_event = make_execution_requested_event(execution_id="exec-race-1") # Simulate concurrent requests @@ -105,7 +110,7 @@ async def test_concurrent_requests_race_condition(self, manager): assert duplicate_count == 4 @pytest.mark.asyncio - async def test_processing_timeout_allows_retry(self, manager): + async def test_processing_timeout_allows_retry(self, manager: IdempotencyManager) -> None: """Test that stuck processing allows retry after timeout""" real_event = make_execution_requested_event(execution_id="exec-timeout-1") # First request @@ -113,9 +118,10 @@ async def test_processing_timeout_allows_retry(self, manager): assert result1.is_duplicate is False # Manually update the created_at to simulate old processing - record = await manager._repo.find_by_key(result1.key) # type: ignore[attr-defined] + record = await manager._repo.find_by_key(result1.key) + assert record is not None record.created_at = datetime.now(timezone.utc) - timedelta(seconds=10) - await manager._repo.update_record(record) # type: ignore[attr-defined] + await manager._repo.update_record(record) # Second request should be allowed due to timeout result2 = await manager.check_and_reserve(real_event, key_strategy="event_based") @@ -123,7 +129,7 @@ async def test_processing_timeout_allows_retry(self, manager): assert result2.status == IdempotencyStatus.PROCESSING @pytest.mark.asyncio - async def test_content_hash_strategy(self, manager): + async def test_content_hash_strategy(self, manager: IdempotencyManager) -> None: """Test content-based deduplication""" # Two events with same content and same execution_id event1 = make_execution_requested_event( @@ -147,7 +153,7 @@ async def test_content_hash_strategy(self, manager): assert result2.is_duplicate is True @pytest.mark.asyncio - async def test_failed_event_handling(self, manager): + async def test_failed_event_handling(self, manager: IdempotencyManager) -> None: """Test marking events as failed""" real_event = make_execution_requested_event(execution_id="exec-failed-1") # Reserve @@ -160,13 +166,14 @@ async def test_failed_event_handling(self, manager): assert success is True # Verify status and error - record = await manager._repo.find_by_key(result.key) # type: ignore[attr-defined] + record = await manager._repo.find_by_key(result.key) + assert record is not None assert record.status == IdempotencyStatus.FAILED assert record.error == error_msg assert record.completed_at is not None @pytest.mark.asyncio - async def test_result_caching(self, manager): + async def test_result_caching(self, manager: IdempotencyManager) -> None: """Test caching of results""" real_event = make_execution_requested_event(execution_id="exec-cache-1") # Reserve @@ -192,7 +199,7 @@ async def test_result_caching(self, manager): assert duplicate_result.has_cached_result is True @pytest.mark.asyncio - async def test_stats_aggregation(self, manager): + async def test_stats_aggregation(self, manager: IdempotencyManager) -> None: """Test statistics aggregation""" # Create various events with different statuses events = [] @@ -224,7 +231,7 @@ async def test_stats_aggregation(self, manager): assert stats.prefix == manager.config.key_prefix @pytest.mark.asyncio - async def test_remove_key(self, manager): + async def test_remove_key(self, manager: IdempotencyManager) -> None: """Test removing idempotency keys""" real_event = make_execution_requested_event(execution_id="exec-remove-1") # Add a key @@ -236,7 +243,7 @@ async def test_remove_key(self, manager): assert removed is True # Verify it's gone - record = await manager._repo.find_by_key(result.key) # type: ignore[attr-defined] + record = await manager._repo.find_by_key(result.key) assert record is None # Can process again @@ -248,7 +255,7 @@ class TestIdempotentEventHandlerIntegration: """Test IdempotentEventHandler with real components""" @pytest.fixture - async def manager(self, redis_client): # type: ignore[valid-type] + async def manager(self, redis_client: redis.Redis) -> AsyncGenerator[IdempotencyManager, None]: prefix = f"handler_test:{uuid.uuid4().hex[:6]}" config = IdempotencyConfig(key_prefix=prefix, enable_metrics=False) repo = RedisIdempotencyRepository(redis_client, key_prefix=prefix) @@ -260,11 +267,11 @@ async def manager(self, redis_client): # type: ignore[valid-type] await m.close() @pytest.mark.asyncio - async def test_handler_processes_new_event(self, manager): + async def test_handler_processes_new_event(self, manager: IdempotencyManager) -> None: """Test that handler processes new events""" - processed_events = [] + processed_events: list[BaseEvent] = [] - async def actual_handler(event: BaseEvent): + async def actual_handler(event: BaseEvent) -> None: processed_events.append(event) # Create idempotent handler @@ -284,11 +291,11 @@ async def actual_handler(event: BaseEvent): assert processed_events[0] == real_event @pytest.mark.asyncio - async def test_handler_blocks_duplicate(self, manager): + async def test_handler_blocks_duplicate(self, manager: IdempotencyManager) -> None: """Test that handler blocks duplicate events""" - processed_events = [] + processed_events: list[BaseEvent] = [] - async def actual_handler(event: BaseEvent): + async def actual_handler(event: BaseEvent) -> None: processed_events.append(event) # Create idempotent handler @@ -308,10 +315,10 @@ async def actual_handler(event: BaseEvent): assert len(processed_events) == 1 @pytest.mark.asyncio - async def test_handler_with_failure(self, manager): + async def test_handler_with_failure(self, manager: IdempotencyManager) -> None: """Test handler marks failure on exception""" - async def failing_handler(event: BaseEvent): + async def failing_handler(event: BaseEvent) -> None: # noqa: ARG001 raise ValueError("Processing failed") handler = IdempotentEventHandler( @@ -328,19 +335,21 @@ async def failing_handler(event: BaseEvent): # Verify marked as failed key = f"{manager.config.key_prefix}:{real_event.event_type}:{real_event.event_id}" - record = await manager._repo.find_by_key(key) # type: ignore[attr-defined] + record = await manager._repo.find_by_key(key) + assert record is not None assert record.status == IdempotencyStatus.FAILED + assert record.error is not None assert "Processing failed" in record.error @pytest.mark.asyncio - async def test_handler_duplicate_callback(self, manager): + async def test_handler_duplicate_callback(self, manager: IdempotencyManager) -> None: """Test duplicate callback is invoked""" - duplicate_events = [] + duplicate_events: list[tuple[BaseEvent, Any]] = [] - async def actual_handler(event: BaseEvent): + async def actual_handler(event: BaseEvent) -> None: # noqa: ARG001 pass # Do nothing - async def on_duplicate(event: BaseEvent, result): + async def on_duplicate(event: BaseEvent, result: Any) -> None: duplicate_events.append((event, result)) handler = IdempotentEventHandler( @@ -362,9 +371,9 @@ async def on_duplicate(event: BaseEvent, result): assert duplicate_events[0][1].is_duplicate is True @pytest.mark.asyncio - async def test_decorator_integration(self, manager): + async def test_decorator_integration(self, manager: IdempotencyManager) -> None: """Test the @idempotent_handler decorator""" - processed_events = [] + processed_events: list[BaseEvent] = [] @idempotent_handler( idempotency_manager=manager, @@ -372,7 +381,7 @@ async def test_decorator_integration(self, manager): ttl_seconds=300, logger=_test_logger, ) - async def my_handler(event: BaseEvent): + async def my_handler(event: BaseEvent) -> None: processed_events.append(event) # Process same event twice @@ -394,18 +403,18 @@ async def my_handler(event: BaseEvent): assert len(processed_events) == 1 # Still only one @pytest.mark.asyncio - async def test_custom_key_function(self, manager): + async def test_custom_key_function(self, manager: IdempotencyManager) -> None: """Test handler with custom key function""" - processed_scripts = [] + processed_scripts: list[str] = [] async def process_script(event: BaseEvent) -> None: - processed_scripts.append(event.script) + script: str = getattr(event, "script", "") + processed_scripts.append(script) def extract_script_key(event: BaseEvent) -> str: # Custom key based on script content only - if hasattr(event, 'script'): - return f"script:{hash(event.script)}" - return str(event.event_id) + script: str = getattr(event, "script", "") + return f"script:{hash(script)}" handler = IdempotentEventHandler( handler=process_script, @@ -445,25 +454,25 @@ def extract_script_key(event: BaseEvent) -> str: assert processed_scripts[0] == "print('hello')" @pytest.mark.asyncio - async def test_invalid_key_strategy(self, manager): + async def test_invalid_key_strategy(self, manager: IdempotencyManager) -> None: """Test that invalid key strategy raises error""" real_event = make_execution_requested_event(execution_id="invalid-strategy-1") with pytest.raises(ValueError, match="Invalid key strategy"): await manager.check_and_reserve(real_event, key_strategy="invalid_strategy") @pytest.mark.asyncio - async def test_custom_key_without_custom_key_param(self, manager): + async def test_custom_key_without_custom_key_param(self, manager: IdempotencyManager) -> None: """Test that custom strategy without custom_key raises error""" real_event = make_execution_requested_event(execution_id="custom-key-missing-1") with pytest.raises(ValueError, match="Invalid key strategy"): await manager.check_and_reserve(real_event, key_strategy="custom") @pytest.mark.asyncio - async def test_get_cached_json_existing(self, manager): + async def test_get_cached_json_existing(self, manager: IdempotencyManager) -> None: """Test retrieving cached JSON result""" # First complete with cached result real_event = make_execution_requested_event(execution_id="cache-exist-1") - result = await manager.check_and_reserve(real_event, key_strategy="event_based") + await manager.check_and_reserve(real_event, key_strategy="event_based") cached_data = json.dumps({"output": "test", "code": 0}) await manager.mark_completed_with_json(real_event, cached_data, "event_based") @@ -472,7 +481,7 @@ async def test_get_cached_json_existing(self, manager): assert retrieved == cached_data @pytest.mark.asyncio - async def test_get_cached_json_non_existing(self, manager): + async def test_get_cached_json_non_existing(self, manager: IdempotencyManager) -> None: """Test retrieving non-existing cached result raises assertion""" real_event = make_execution_requested_event(execution_id="cache-miss-1") # Trying to get cached result for non-existent key should raise @@ -480,7 +489,7 @@ async def test_get_cached_json_non_existing(self, manager): await manager.get_cached_json(real_event, "event_based", None) @pytest.mark.asyncio - async def test_cleanup_expired_keys(self, manager): + async def test_cleanup_expired_keys(self, manager: IdempotencyManager) -> None: """Test cleanup of expired keys""" # Create expired record expired_key = f"{manager.config.key_prefix}:expired" @@ -493,15 +502,15 @@ async def test_cleanup_expired_keys(self, manager): ttl_seconds=3600, # 1 hour TTL completed_at=datetime.now(timezone.utc) - timedelta(hours=2) ) - await manager._repo.insert_processing(expired_record) # type: ignore[attr-defined] + await manager._repo.insert_processing(expired_record) # Cleanup should detect it as expired # Note: actual cleanup implementation depends on repository - record = await manager._repo.find_by_key(expired_key) # type: ignore[attr-defined] + record = await manager._repo.find_by_key(expired_key) assert record is not None # Still exists until explicit cleanup @pytest.mark.asyncio - async def test_metrics_enabled(self, redis_client): # type: ignore[valid-type] + async def test_metrics_enabled(self, redis_client: redis.Redis) -> None: """Test manager with metrics enabled""" config = IdempotencyConfig(key_prefix=f"metrics:{uuid.uuid4().hex[:6]}", enable_metrics=True) repository = RedisIdempotencyRepository(redis_client, key_prefix=config.key_prefix) @@ -515,7 +524,7 @@ async def test_metrics_enabled(self, redis_client): # type: ignore[valid-type] await manager.close() @pytest.mark.asyncio - async def test_content_hash_with_fields(self, manager): + async def test_content_hash_with_fields(self, manager: IdempotencyManager) -> None: """Test content hash with specific fields""" event1 = make_execution_requested_event( execution_id="exec-1", diff --git a/backend/tests/integration/idempotency/test_idempotent_handler.py b/backend/tests/integration/idempotency/test_idempotent_handler.py index 76ea369a..a7872c8b 100644 --- a/backend/tests/integration/idempotency/test_idempotent_handler.py +++ b/backend/tests/integration/idempotency/test_idempotent_handler.py @@ -1,12 +1,12 @@ import logging import pytest +from dishka import AsyncContainer -from app.events.schema.schema_registry import SchemaRegistryManager -from tests.helpers import make_execution_requested_event +from app.infrastructure.kafka.events.base import BaseEvent from app.services.idempotency.idempotency_manager import IdempotencyManager from app.services.idempotency.middleware import IdempotentEventHandler - +from tests.helpers import make_execution_requested_event pytestmark = [pytest.mark.integration] @@ -14,12 +14,12 @@ @pytest.mark.asyncio -async def test_idempotent_handler_blocks_duplicates(scope) -> None: # type: ignore[valid-type] +async def test_idempotent_handler_blocks_duplicates(scope: AsyncContainer) -> None: manager: IdempotencyManager = await scope.get(IdempotencyManager) - processed: list[str] = [] + processed: list[str | None] = [] - async def _handler(ev) -> None: # noqa: ANN001 + async def _handler(ev: BaseEvent) -> None: processed.append(ev.event_id) handler = IdempotentEventHandler( @@ -38,13 +38,13 @@ async def _handler(ev) -> None: # noqa: ANN001 @pytest.mark.asyncio -async def test_idempotent_handler_content_hash_blocks_same_content(scope) -> None: # type: ignore[valid-type] +async def test_idempotent_handler_content_hash_blocks_same_content(scope: AsyncContainer) -> None: manager: IdempotencyManager = await scope.get(IdempotencyManager) processed: list[str] = [] - async def _handler(ev) -> None: # noqa: ANN001 - processed.append(ev.execution_id) + async def _handler(ev: BaseEvent) -> None: + processed.append(getattr(ev, "execution_id", "")) handler = IdempotentEventHandler( handler=_handler, diff --git a/backend/tests/integration/notifications/test_notification_sse.py b/backend/tests/integration/notifications/test_notification_sse.py index c2fbb401..5beabd4f 100644 --- a/backend/tests/integration/notifications/test_notification_sse.py +++ b/backend/tests/integration/notifications/test_notification_sse.py @@ -1,7 +1,9 @@ import asyncio import json from uuid import uuid4 + import pytest +from dishka import AsyncContainer from app.domain.enums.notification import NotificationChannel, NotificationSeverity from app.schemas_pydantic.sse import RedisNotificationMessage @@ -13,7 +15,7 @@ @pytest.mark.asyncio -async def test_in_app_notification_published_to_sse(scope) -> None: # type: ignore[valid-type] +async def test_in_app_notification_published_to_sse(scope: AsyncContainer) -> None: svc: NotificationService = await scope.get(NotificationService) bus: SSERedisBus = await scope.get(SSERedisBus) diff --git a/backend/tests/integration/result_processor/test_result_processor.py b/backend/tests/integration/result_processor/test_result_processor.py index 5c9a98c4..cb897bbe 100644 --- a/backend/tests/integration/result_processor/test_result_processor.py +++ b/backend/tests/integration/result_processor/test_result_processor.py @@ -1,16 +1,16 @@ import asyncio import logging import uuid -from tests.helpers.eventually import eventually + import pytest +from dishka import AsyncContainer from app.core.database_context import Database - from app.db.repositories.execution_repository import ExecutionRepository from app.domain.enums.events import EventType from app.domain.enums.execution import ExecutionStatus -from app.domain.execution import DomainExecutionCreate from app.domain.enums.kafka import KafkaTopic +from app.domain.execution import DomainExecutionCreate from app.domain.execution.models import ResourceUsageDomain from app.events.core import UnifiedConsumer, UnifiedProducer from app.events.core.dispatcher import EventDispatcher @@ -18,17 +18,27 @@ from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas from app.infrastructure.kafka.events.execution import ExecutionCompletedEvent from app.infrastructure.kafka.events.metadata import AvroEventMetadata +from app.infrastructure.kafka.events.system import ResultStoredEvent from app.services.idempotency import IdempotencyManager from app.services.result_processor.processor import ResultProcessor from app.settings import Settings -pytestmark = [pytest.mark.integration, pytest.mark.kafka, pytest.mark.mongodb] +from tests.helpers.eventually import eventually + +# xdist_group: Kafka consumer creation can crash librdkafka when multiple workers +# instantiate Consumer() objects simultaneously. Serial execution prevents this. +pytestmark = [ + pytest.mark.integration, + pytest.mark.kafka, + pytest.mark.mongodb, + pytest.mark.xdist_group("kafka_consumers"), +] _test_logger = logging.getLogger("test.result_processor.processor") @pytest.mark.asyncio -async def test_result_processor_persists_and_emits(scope) -> None: # type: ignore[valid-type] +async def test_result_processor_persists_and_emits(scope: AsyncContainer) -> None: # Ensure schemas registry: SchemaRegistryManager = await scope.get(SchemaRegistryManager) settings: Settings = await scope.get(Settings) @@ -65,7 +75,7 @@ async def test_result_processor_persists_and_emits(scope) -> None: # type: igno stored_received = asyncio.Event() @dispatcher.register(EventType.RESULT_STORED) - async def _stored(_event) -> None: # noqa: ANN001 + async def _stored(_event: ResultStoredEvent) -> None: stored_received.set() group_id = f"rp-test.{uuid.uuid4().hex[:6]}" @@ -82,7 +92,7 @@ async def _stored(_event) -> None: # noqa: ANN001 settings=settings, logger=_test_logger, ) - await stored_consumer.start([str(KafkaTopic.EXECUTION_RESULTS)]) + await stored_consumer.start([KafkaTopic.EXECUTION_RESULTS]) try: async with processor: diff --git a/backend/tests/integration/services/admin/test_admin_user_service.py b/backend/tests/integration/services/admin/test_admin_user_service.py index a392a908..ed6b3dca 100644 --- a/backend/tests/integration/services/admin/test_admin_user_service.py +++ b/backend/tests/integration/services/admin/test_admin_user_service.py @@ -1,8 +1,9 @@ from datetime import datetime, timezone import pytest -from app.core.database_context import Database +from dishka import AsyncContainer +from app.core.database_context import Database from app.domain.enums.user import UserRole from app.services.admin import AdminUserService @@ -10,7 +11,7 @@ @pytest.mark.asyncio -async def test_get_user_overview_basic(scope) -> None: # type: ignore[valid-type] +async def test_get_user_overview_basic(scope: AsyncContainer) -> None: svc: AdminUserService = await scope.get(AdminUserService) db: Database = await scope.get(Database) await db.get_collection("users").insert_one({ @@ -30,7 +31,7 @@ async def test_get_user_overview_basic(scope) -> None: # type: ignore[valid-typ @pytest.mark.asyncio -async def test_get_user_overview_user_not_found(scope) -> None: # type: ignore[valid-type] +async def test_get_user_overview_user_not_found(scope: AsyncContainer) -> None: svc: AdminUserService = await scope.get(AdminUserService) with pytest.raises(ValueError): await svc.get_user_overview("missing") diff --git a/backend/tests/integration/services/coordinator/test_execution_coordinator.py b/backend/tests/integration/services/coordinator/test_execution_coordinator.py index 7131b2ab..6043ede7 100644 --- a/backend/tests/integration/services/coordinator/test_execution_coordinator.py +++ b/backend/tests/integration/services/coordinator/test_execution_coordinator.py @@ -1,4 +1,5 @@ import pytest +from dishka import AsyncContainer from app.services.coordinator.coordinator import ExecutionCoordinator from tests.helpers import make_execution_requested_event @@ -7,7 +8,7 @@ @pytest.mark.asyncio -async def test_handle_requested_and_schedule(scope) -> None: # type: ignore[valid-type] +async def test_handle_requested_and_schedule(scope: AsyncContainer) -> None: coord: ExecutionCoordinator = await scope.get(ExecutionCoordinator) ev = make_execution_requested_event(execution_id="e-real-1") diff --git a/backend/tests/integration/services/events/test_event_bus.py b/backend/tests/integration/services/events/test_event_bus.py index 398300c0..bc0d453f 100644 --- a/backend/tests/integration/services/events/test_event_bus.py +++ b/backend/tests/integration/services/events/test_event_bus.py @@ -1,4 +1,5 @@ import pytest +from dishka import AsyncContainer from app.services.event_bus import EventBusEvent, EventBusManager from tests.helpers.eventually import eventually @@ -7,7 +8,7 @@ @pytest.mark.asyncio -async def test_event_bus_publish_subscribe(scope) -> None: # type: ignore[valid-type] +async def test_event_bus_publish_subscribe(scope: AsyncContainer) -> None: manager: EventBusManager = await scope.get(EventBusManager) bus = await manager.get_event_bus() @@ -19,7 +20,7 @@ async def handler(event: EventBusEvent) -> None: await bus.subscribe("test.*", handler) await bus.publish("test.created", {"x": 1}) - async def _received(): + async def _received() -> None: assert any(e.event_type == "test.created" for e in received) await eventually(_received, timeout=2.0, interval=0.05) diff --git a/backend/tests/integration/services/events/test_kafka_event_service.py b/backend/tests/integration/services/events/test_kafka_event_service.py index 8a13fdee..42fac8ec 100644 --- a/backend/tests/integration/services/events/test_kafka_event_service.py +++ b/backend/tests/integration/services/events/test_kafka_event_service.py @@ -1,4 +1,5 @@ import pytest +from dishka import AsyncContainer from app.db.repositories import EventRepository from app.domain.enums.events import EventType @@ -9,7 +10,7 @@ @pytest.mark.asyncio -async def test_publish_user_registered_event(scope) -> None: # type: ignore[valid-type] +async def test_publish_user_registered_event(scope: AsyncContainer) -> None: svc: KafkaEventService = await scope.get(KafkaEventService) repo: EventRepository = await scope.get(EventRepository) @@ -24,7 +25,7 @@ async def test_publish_user_registered_event(scope) -> None: # type: ignore[val @pytest.mark.asyncio -async def test_publish_execution_event(scope) -> None: # type: ignore[valid-type] +async def test_publish_execution_event(scope: AsyncContainer) -> None: svc: KafkaEventService = await scope.get(KafkaEventService) repo: EventRepository = await scope.get(EventRepository) @@ -40,7 +41,7 @@ async def test_publish_execution_event(scope) -> None: # type: ignore[valid-typ @pytest.mark.asyncio -async def test_publish_pod_event_and_without_metadata(scope) -> None: # type: ignore[valid-type] +async def test_publish_pod_event_and_without_metadata(scope: AsyncContainer) -> None: svc: KafkaEventService = await scope.get(KafkaEventService) repo: EventRepository = await scope.get(EventRepository) diff --git a/backend/tests/integration/services/execution/test_execution_service.py b/backend/tests/integration/services/execution/test_execution_service.py index 184a3494..1fffd3be 100644 --- a/backend/tests/integration/services/execution/test_execution_service.py +++ b/backend/tests/integration/services/execution/test_execution_service.py @@ -1,4 +1,5 @@ import pytest +from dishka import AsyncContainer from app.domain.execution import ResourceLimitsDomain from app.services.execution_service import ExecutionService @@ -7,7 +8,7 @@ @pytest.mark.asyncio -async def test_execute_script_and_limits(scope) -> None: # type: ignore[valid-type] +async def test_execute_script_and_limits(scope: AsyncContainer) -> None: svc: ExecutionService = await scope.get(ExecutionService) limits = await svc.get_k8s_resource_limits() assert isinstance(limits, ResourceLimitsDomain) diff --git a/backend/tests/integration/services/idempotency/test_redis_repository.py b/backend/tests/integration/services/idempotency/test_redis_repository.py index 7f96b783..f9442539 100644 --- a/backend/tests/integration/services/idempotency/test_redis_repository.py +++ b/backend/tests/integration/services/idempotency/test_redis_repository.py @@ -1,6 +1,8 @@ import json from datetime import datetime, timedelta, timezone + import pytest +import redis.asyncio as redis from pymongo.errors import DuplicateKeyError from app.domain.idempotency import IdempotencyRecord, IdempotencyStatus @@ -16,41 +18,43 @@ class TestHelperFunctions: - def test_iso_datetime(self): + def test_iso_datetime(self) -> None: dt = datetime(2025, 1, 15, 10, 30, 45, tzinfo=timezone.utc) result = _iso(dt) assert result == "2025-01-15T10:30:45+00:00" - def test_iso_datetime_with_timezone(self): + def test_iso_datetime_with_timezone(self) -> None: dt = datetime(2025, 1, 15, 10, 30, 45, tzinfo=timezone(timedelta(hours=5))) result = _iso(dt) assert result == "2025-01-15T05:30:45+00:00" - def test_json_default_datetime(self): + def test_json_default_datetime(self) -> None: dt = datetime(2025, 1, 15, 10, 30, 45, tzinfo=timezone.utc) result = _json_default(dt) assert result == "2025-01-15T10:30:45+00:00" - def test_json_default_other(self): + def test_json_default_other(self) -> None: obj = {"key": "value"} result = _json_default(obj) assert result == "{'key': 'value'}" - def test_parse_iso_datetime_variants(self): - assert _parse_iso_datetime("2025-01-15T10:30:45+00:00").year == 2025 - assert _parse_iso_datetime("2025-01-15T10:30:45Z").tzinfo == timezone.utc + def test_parse_iso_datetime_variants(self) -> None: + result1 = _parse_iso_datetime("2025-01-15T10:30:45+00:00") + assert result1 is not None and result1.year == 2025 + result2 = _parse_iso_datetime("2025-01-15T10:30:45Z") + assert result2 is not None and result2.tzinfo == timezone.utc assert _parse_iso_datetime(None) is None assert _parse_iso_datetime("") is None assert _parse_iso_datetime("not-a-date") is None @pytest.fixture -def repository(redis_client): # type: ignore[valid-type] +def repository(redis_client: redis.Redis) -> RedisIdempotencyRepository: return RedisIdempotencyRepository(redis_client, key_prefix="idempotency") @pytest.fixture -def sample_record(): +def sample_record() -> IdempotencyRecord: return IdempotencyRecord( key="test-key", status=IdempotencyStatus.PROCESSING, @@ -65,12 +69,12 @@ def sample_record(): ) -def test_full_key_helpers(repository): +def test_full_key_helpers(repository: RedisIdempotencyRepository) -> None: assert repository._full_key("my") == "idempotency:my" assert repository._full_key("idempotency:my") == "idempotency:my" -def test_doc_record_roundtrip(repository): +def test_doc_record_roundtrip(repository: RedisIdempotencyRepository) -> None: rec = IdempotencyRecord( key="k", status=IdempotencyStatus.COMPLETED, @@ -89,7 +93,11 @@ def test_doc_record_roundtrip(repository): @pytest.mark.asyncio -async def test_insert_find_update_delete_flow(repository, redis_client, sample_record): # type: ignore[valid-type] +async def test_insert_find_update_delete_flow( + repository: RedisIdempotencyRepository, + redis_client: redis.Redis, + sample_record: IdempotencyRecord, +) -> None: # Insert processing (NX) await repository.insert_processing(sample_record) key = repository._full_key(sample_record.key) @@ -121,14 +129,18 @@ async def test_insert_find_update_delete_flow(repository, redis_client, sample_r @pytest.mark.asyncio -async def test_update_record_when_missing(repository, sample_record): +async def test_update_record_when_missing( + repository: RedisIdempotencyRepository, sample_record: IdempotencyRecord +) -> None: # If key missing, update returns 0 res = await repository.update_record(sample_record) assert res == 0 @pytest.mark.asyncio -async def test_aggregate_status_counts(repository, redis_client): # type: ignore[valid-type] +async def test_aggregate_status_counts( + repository: RedisIdempotencyRepository, redis_client: redis.Redis +) -> None: # Seed few keys directly using repository for i, status in enumerate((IdempotencyStatus.PROCESSING, IdempotencyStatus.PROCESSING, IdempotencyStatus.COMPLETED)): rec = IdempotencyRecord( @@ -146,5 +158,5 @@ async def test_aggregate_status_counts(repository, redis_client): # type: ignor @pytest.mark.asyncio -async def test_health_check(repository): +async def test_health_check(repository: RedisIdempotencyRepository) -> None: await repository.health_check() # should not raise diff --git a/backend/tests/integration/services/notifications/test_notification_service.py b/backend/tests/integration/services/notifications/test_notification_service.py index c1faa79a..1d93b259 100644 --- a/backend/tests/integration/services/notifications/test_notification_service.py +++ b/backend/tests/integration/services/notifications/test_notification_service.py @@ -1,26 +1,27 @@ import pytest +from dishka import AsyncContainer from app.db.repositories import NotificationRepository from app.domain.enums.notification import NotificationChannel, NotificationSeverity -from app.domain.notification import DomainNotification +from app.domain.notification import DomainNotificationCreate from app.services.notification_service import NotificationService pytestmark = [pytest.mark.integration, pytest.mark.mongodb] @pytest.mark.asyncio -async def test_notification_service_crud_and_subscription(scope) -> None: # type: ignore[valid-type] +async def test_notification_service_crud_and_subscription(scope: AsyncContainer) -> None: svc: NotificationService = await scope.get(NotificationService) repo: NotificationRepository = await scope.get(NotificationRepository) # Create a notification via repository and then use service to mark/delete - n = DomainNotification(user_id="u1", severity=NotificationSeverity.MEDIUM, tags=["x"], channel=NotificationChannel.IN_APP, subject="s", body="b") - _nid = await repo.create_notification(n) - got = await repo.get_notification(n.notification_id, "u1") + n = DomainNotificationCreate(user_id="u1", severity=NotificationSeverity.MEDIUM, tags=["x"], channel=NotificationChannel.IN_APP, subject="s", body="b") + created = await repo.create_notification(n) + got = await repo.get_notification(created.notification_id, "u1") assert got is not None # Mark as read through service - ok = await svc.mark_as_read("u1", got.notification_id) + ok = await svc.mark_as_read("u1", created.notification_id) assert ok is True # Subscriptions via service wrapper calls the repo @@ -29,4 +30,4 @@ async def test_notification_service_crud_and_subscription(scope) -> None: # typ assert sub and sub.enabled is True # Delete via service - assert await svc.delete_notification("u1", got.notification_id) is True + assert await svc.delete_notification("u1", created.notification_id) is True diff --git a/backend/tests/integration/services/rate_limit/test_rate_limit_service.py b/backend/tests/integration/services/rate_limit/test_rate_limit_service.py index 24f11477..4ce27ecf 100644 --- a/backend/tests/integration/services/rate_limit/test_rate_limit_service.py +++ b/backend/tests/integration/services/rate_limit/test_rate_limit_service.py @@ -1,10 +1,11 @@ import asyncio import json -import time -from datetime import datetime, timezone +from collections.abc import Awaitable +from typing import Any, cast from uuid import uuid4 import pytest +from dishka import AsyncContainer from app.domain.rate_limit import ( EndpointGroup, @@ -19,7 +20,7 @@ @pytest.mark.asyncio -async def test_normalize_and_disabled_and_bypass_and_no_rule(scope) -> None: # type: ignore[valid-type] +async def test_normalize_and_disabled_and_bypass_and_no_rule(scope: AsyncContainer) -> None: svc: RateLimitService = await scope.get(RateLimitService) svc.prefix = f"{svc.prefix}{uuid4().hex[:6]}:" # ensure disabled for first path @@ -48,7 +49,7 @@ async def test_normalize_and_disabled_and_bypass_and_no_rule(scope) -> None: # @pytest.mark.asyncio -async def test_sliding_window_allowed_and_rejected(scope) -> None: # type: ignore[valid-type] +async def test_sliding_window_allowed_and_rejected(scope: AsyncContainer) -> None: svc: RateLimitService = await scope.get(RateLimitService) svc.prefix = f"{svc.prefix}{uuid4().hex[:6]}:" svc.settings.RATE_LIMIT_ENABLED = True # Enable rate limiting for this test @@ -56,7 +57,7 @@ async def test_sliding_window_allowed_and_rejected(scope) -> None: # type: igno rule = RateLimitRule(endpoint_pattern=r"^/api/v1/x", group=EndpointGroup.API, requests=3, window_seconds=5, algorithm=RateLimitAlgorithm.SLIDING_WINDOW) await svc.update_config(RateLimitConfig(default_rules=[rule])) - + # Make 3 requests - all should be allowed for i in range(3): ok = await svc.check_rate_limit("u", "/api/v1/x") @@ -73,14 +74,14 @@ async def test_sliding_window_allowed_and_rejected(scope) -> None: # type: igno @pytest.mark.asyncio -async def test_token_bucket_paths(scope) -> None: # type: ignore[valid-type] +async def test_token_bucket_paths(scope: AsyncContainer) -> None: svc: RateLimitService = await scope.get(RateLimitService) svc.prefix = f"{svc.prefix}{uuid4().hex[:6]}:" svc.settings.RATE_LIMIT_ENABLED = True # Enable rate limiting for this test rule = RateLimitRule(endpoint_pattern=r"^/api/v1/t", group=EndpointGroup.API, requests=2, window_seconds=10, burst_multiplier=1.0, algorithm=RateLimitAlgorithm.TOKEN_BUCKET) await svc.update_config(RateLimitConfig(default_rules=[rule])) - + # Make 2 requests - both should be allowed for i in range(2): ok = await svc.check_rate_limit("u", "/api/v1/t") @@ -101,7 +102,7 @@ async def test_token_bucket_paths(scope) -> None: # type: ignore[valid-type] @pytest.mark.asyncio -async def test_config_update_and_user_helpers(scope) -> None: # type: ignore[valid-type] +async def test_config_update_and_user_helpers(scope: AsyncContainer) -> None: svc: RateLimitService = await scope.get(RateLimitService) svc.prefix = f"{svc.prefix}{uuid4().hex[:6]}:" cfg = RateLimitConfig( @@ -124,7 +125,7 @@ async def test_config_update_and_user_helpers(scope) -> None: # type: ignore[va @pytest.mark.asyncio -async def test_ip_based_rate_limiting(scope) -> None: # type: ignore[valid-type] +async def test_ip_based_rate_limiting(scope: AsyncContainer) -> None: """Test IP-based rate limiting.""" svc: RateLimitService = await scope.get(RateLimitService) svc.prefix = f"{svc.prefix}{uuid4().hex[:6]}:" @@ -151,7 +152,7 @@ async def test_ip_based_rate_limiting(scope) -> None: # type: ignore[valid-type @pytest.mark.asyncio -async def test_get_config_roundtrip(scope) -> None: # type: ignore[valid-type] +async def test_get_config_roundtrip(scope: AsyncContainer) -> None: svc: RateLimitService = await scope.get(RateLimitService) svc.prefix = f"{svc.prefix}{uuid4().hex[:6]}:" cfg = RateLimitConfig(default_rules=[RateLimitRule(endpoint_pattern=r"^/z", group=EndpointGroup.API, requests=1, window_seconds=1)]) @@ -161,7 +162,7 @@ async def test_get_config_roundtrip(scope) -> None: # type: ignore[valid-type] @pytest.mark.asyncio -async def test_sliding_window_edge(scope) -> None: # type: ignore[valid-type] +async def test_sliding_window_edge(scope: AsyncContainer) -> None: svc: RateLimitService = await scope.get(RateLimitService) svc.prefix = f"{svc.prefix}{uuid4().hex[:6]}:" svc.settings.RATE_LIMIT_ENABLED = True # Enable rate limiting for this test @@ -176,16 +177,25 @@ async def test_sliding_window_edge(scope) -> None: # type: ignore[valid-type] @pytest.mark.asyncio -async def test_sliding_window_pipeline_failure(scope, monkeypatch) -> None: # type: ignore[valid-type] +async def test_sliding_window_pipeline_failure(scope: AsyncContainer, monkeypatch: pytest.MonkeyPatch) -> None: svc: RateLimitService = await scope.get(RateLimitService) svc.prefix = f"{svc.prefix}{uuid4().hex[:6]}:" class FailingPipe: - def zremrangebyscore(self, *a, **k): return self # noqa: ANN001, D401 - def zadd(self, *a, **k): return self # noqa: ANN001, D401 - def zcard(self, *a, **k): return self # noqa: ANN001, D401 - def expire(self, *a, **k): return self # noqa: ANN001, D401 - async def execute(self): raise ConnectionError("Pipeline failed") + def zremrangebyscore(self, *a: Any, **k: Any) -> "FailingPipe": + return self + + def zadd(self, *a: Any, **k: Any) -> "FailingPipe": + return self + + def zcard(self, *a: Any, **k: Any) -> "FailingPipe": + return self + + def expire(self, *a: Any, **k: Any) -> "FailingPipe": + return self + + async def execute(self) -> None: + raise ConnectionError("Pipeline failed") monkeypatch.setattr(svc.redis, "pipeline", lambda: FailingPipe()) @@ -204,7 +214,7 @@ async def execute(self): raise ConnectionError("Pipeline failed") @pytest.mark.asyncio -async def test_token_bucket_invalid_data(scope) -> None: # type: ignore[valid-type] +async def test_token_bucket_invalid_data(scope: AsyncContainer) -> None: svc: RateLimitService = await scope.get(RateLimitService) key = f"{svc.prefix}tb:user:/api" await svc.redis.set(key, "invalid-json") @@ -224,10 +234,12 @@ async def test_token_bucket_invalid_data(scope) -> None: # type: ignore[valid-t @pytest.mark.asyncio -async def test_update_config_serialization_error(scope, monkeypatch) -> None: # type: ignore[valid-type] +async def test_update_config_serialization_error(scope: AsyncContainer, monkeypatch: pytest.MonkeyPatch) -> None: svc: RateLimitService = await scope.get(RateLimitService) - async def failing_setex(key, ttl, value): # noqa: ANN001 + + async def failing_setex(key: str, ttl: int, value: str) -> None: # noqa: ARG001 raise ValueError("Serialization failed") + monkeypatch.setattr(svc.redis, "setex", failing_setex) cfg = RateLimitConfig(default_rules=[]) @@ -236,35 +248,38 @@ async def failing_setex(key, ttl, value): # noqa: ANN001 @pytest.mark.asyncio -async def test_get_user_rate_limit_not_found(scope) -> None: # type: ignore[valid-type] +async def test_get_user_rate_limit_not_found(scope: AsyncContainer) -> None: svc: RateLimitService = await scope.get(RateLimitService) result = await svc.get_user_rate_limit("nonexistent") assert result is None @pytest.mark.asyncio -async def test_reset_user_limits_error(scope, monkeypatch) -> None: # type: ignore[valid-type] +async def test_reset_user_limits_error(scope: AsyncContainer, monkeypatch: pytest.MonkeyPatch) -> None: svc: RateLimitService = await scope.get(RateLimitService) - async def failing_smembers(key): # noqa: ANN001 + + async def failing_smembers(key: str) -> None: # noqa: ARG001 raise ConnectionError("smembers failed") + monkeypatch.setattr(svc.redis, "smembers", failing_smembers) with pytest.raises(ConnectionError): await svc.reset_user_limits("user") @pytest.mark.asyncio -async def test_get_usage_stats_with_keys(scope) -> None: # type: ignore[valid-type] +async def test_get_usage_stats_with_keys(scope: AsyncContainer) -> None: svc: RateLimitService = await scope.get(RateLimitService) user_id = "user" index_key = f"{svc.prefix}index:{user_id}" sw_key = f"{svc.prefix}sw:{user_id}:/api:key1" - await svc.redis.sadd(index_key, sw_key) + awaitable_result = cast("Awaitable[int]", svc.redis.sadd(index_key, sw_key)) + await awaitable_result stats = await svc.get_usage_stats(user_id) assert isinstance(stats, dict) @pytest.mark.asyncio -async def test_check_rate_limit_with_user_override(scope) -> None: # type: ignore[valid-type] +async def test_check_rate_limit_with_user_override(scope: AsyncContainer) -> None: svc: RateLimitService = await scope.get(RateLimitService) svc.settings.RATE_LIMIT_ENABLED = True # Enable rate limiting for this test rule = RateLimitRule( diff --git a/backend/tests/integration/services/replay/test_replay_service.py b/backend/tests/integration/services/replay/test_replay_service.py index de47f756..0705062e 100644 --- a/backend/tests/integration/services/replay/test_replay_service.py +++ b/backend/tests/integration/services/replay/test_replay_service.py @@ -1,4 +1,5 @@ import pytest +from dishka import AsyncContainer from app.domain.enums.replay import ReplayTarget, ReplayType from app.services.event_replay import ReplayConfig, ReplayFilter @@ -8,7 +9,7 @@ @pytest.mark.asyncio -async def test_replay_service_create_and_list(scope) -> None: # type: ignore[valid-type] +async def test_replay_service_create_and_list(scope: AsyncContainer) -> None: svc: ReplayService = await scope.get(ReplayService) cfg = ReplayConfig( diff --git a/backend/tests/integration/services/saga/test_saga_service.py b/backend/tests/integration/services/saga/test_saga_service.py index 21d6f3b1..b0d1a1a1 100644 --- a/backend/tests/integration/services/saga/test_saga_service.py +++ b/backend/tests/integration/services/saga/test_saga_service.py @@ -1,24 +1,25 @@ -import pytest from datetime import datetime, timezone +import pytest +from dishka import AsyncContainer + +from app.domain.enums.user import UserRole +from app.schemas_pydantic.user import User from app.services.saga.saga_service import SagaService pytestmark = [pytest.mark.integration, pytest.mark.mongodb] @pytest.mark.asyncio -async def test_saga_service_basic(scope) -> None: # type: ignore[valid-type] +async def test_saga_service_basic(scope: AsyncContainer) -> None: svc: SagaService = await scope.get(SagaService) - from app.domain.user import User as DomainUser - from app.domain.enums.user import UserRole - user = DomainUser( + user = User( user_id="u1", username="u1", email="u1@example.com", role=UserRole.USER, is_active=True, is_superuser=False, - hashed_password="x", created_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc), ) diff --git a/backend/tests/integration/services/saved_script/test_saved_script_service.py b/backend/tests/integration/services/saved_script/test_saved_script_service.py index 16d980c8..c016146f 100644 --- a/backend/tests/integration/services/saved_script/test_saved_script_service.py +++ b/backend/tests/integration/services/saved_script/test_saved_script_service.py @@ -1,4 +1,5 @@ import pytest +from dishka import AsyncContainer from app.domain.saved_script import DomainSavedScriptCreate, DomainSavedScriptUpdate, SavedScriptNotFoundError from app.services.saved_script_service import SavedScriptService @@ -11,7 +12,7 @@ def _create_payload() -> DomainSavedScriptCreate: @pytest.mark.asyncio -async def test_crud_saved_script(scope) -> None: # type: ignore[valid-type] +async def test_crud_saved_script(scope: AsyncContainer) -> None: service: SavedScriptService = await scope.get(SavedScriptService) created = await service.create_saved_script(_create_payload(), user_id="u1") assert created.user_id == "u1" diff --git a/backend/tests/integration/services/sse/test_partitioned_event_router.py b/backend/tests/integration/services/sse/test_partitioned_event_router.py index 040a62b5..18bc3f86 100644 --- a/backend/tests/integration/services/sse/test_partitioned_event_router.py +++ b/backend/tests/integration/services/sse/test_partitioned_event_router.py @@ -2,6 +2,8 @@ from uuid import uuid4 import pytest +import redis.asyncio as redis + from app.core.metrics.events import EventMetrics from app.events.core import EventDispatcher from app.events.schema.schema_registry import SchemaRegistryManager @@ -9,7 +11,6 @@ from app.services.sse.kafka_redis_bridge import SSEKafkaRedisBridge from app.services.sse.redis_bus import SSERedisBus from app.settings import Settings - from tests.helpers import make_execution_requested_event from tests.helpers.eventually import eventually @@ -19,7 +20,7 @@ @pytest.mark.asyncio -async def test_router_bridges_to_redis(redis_client, test_settings: Settings) -> None: +async def test_router_bridges_to_redis(redis_client: redis.Redis, test_settings: Settings) -> None: suffix = uuid4().hex[:6] bus = SSERedisBus( redis_client, @@ -30,7 +31,7 @@ async def test_router_bridges_to_redis(redis_client, test_settings: Settings) -> router = SSEKafkaRedisBridge( schema_registry=SchemaRegistryManager(settings=test_settings, logger=_test_logger), settings=test_settings, - event_metrics=EventMetrics(), + event_metrics=EventMetrics(test_settings), sse_bus=bus, logger=_test_logger, ) @@ -45,7 +46,7 @@ async def test_router_bridges_to_redis(redis_client, test_settings: Settings) -> handler = disp.get_handlers(ev.event_type)[0] await handler(ev) - async def _recv(): + async def _recv() -> RedisSSEMessage: m = await subscription.get(RedisSSEMessage) assert m is not None return m @@ -55,13 +56,13 @@ async def _recv(): @pytest.mark.asyncio -async def test_router_start_and_stop(redis_client, test_settings: Settings) -> None: +async def test_router_start_and_stop(redis_client: redis.Redis, test_settings: Settings) -> None: test_settings.SSE_CONSUMER_POOL_SIZE = 1 suffix = uuid4().hex[:6] router = SSEKafkaRedisBridge( schema_registry=SchemaRegistryManager(settings=test_settings, logger=_test_logger), settings=test_settings, - event_metrics=EventMetrics(), + event_metrics=EventMetrics(test_settings), sse_bus=SSERedisBus( redis_client, exec_prefix=f"sse:exec:{suffix}:", diff --git a/backend/tests/integration/services/sse/test_redis_bus.py b/backend/tests/integration/services/sse/test_redis_bus.py index ae54a6e4..74c05691 100644 --- a/backend/tests/integration/services/sse/test_redis_bus.py +++ b/backend/tests/integration/services/sse/test_redis_bus.py @@ -1,27 +1,31 @@ import asyncio -import json import logging -from typing import Any +from typing import Any, ClassVar, cast import pytest - -pytestmark = pytest.mark.integration +import redis.asyncio as redis_async from app.domain.enums.events import EventType +from app.domain.enums.kafka import KafkaTopic +from app.domain.enums.notification import NotificationSeverity, NotificationStatus +from app.infrastructure.kafka.events import BaseEvent +from app.infrastructure.kafka.events.metadata import AvroEventMetadata from app.schemas_pydantic.sse import RedisNotificationMessage, RedisSSEMessage from app.services.sse.redis_bus import SSERedisBus +pytestmark = pytest.mark.integration + _test_logger = logging.getLogger("test.services.sse.redis_bus") -class _DummyEvent: - def __init__(self, execution_id: str, event_type: EventType, extra: dict[str, Any] | None = None) -> None: - self.execution_id = execution_id - self.event_type = event_type - self._extra = extra or {} +class _DummyEvent(BaseEvent): + """Dummy event for testing.""" + execution_id: str = "" + status: str | None = None + topic: ClassVar[KafkaTopic] = KafkaTopic.EXECUTION_EVENTS - def model_dump(self, mode: str | None = None) -> dict[str, Any]: # noqa: ARG002 - return {"execution_id": self.execution_id, **self._extra} + def model_dump(self, **kwargs: object) -> dict[str, Any]: + return {"execution_id": self.execution_id, "status": self.status} class _FakePubSub: @@ -33,7 +37,7 @@ def __init__(self) -> None: async def subscribe(self, channel: str) -> None: self.subscribed.add(channel) - async def get_message(self, ignore_subscribe_messages: bool = True, timeout: float = 0.5): # noqa: ARG002 + async def get_message(self, ignore_subscribe_messages: bool = True, timeout: float = 0.5) -> dict[str, Any] | None: try: msg = await asyncio.wait_for(self._queue.get(), timeout=timeout) return msg @@ -51,21 +55,31 @@ async def aclose(self) -> None: class _FakeRedis: + """Fake Redis for testing - used in place of real Redis. + + Note: SSERedisBus uses duck-typing so this works without inheritance. + """ + def __init__(self) -> None: self.published: list[tuple[str, str]] = [] self._pubsub = _FakePubSub() - async def publish(self, channel: str, payload: str) -> None: + async def publish(self, channel: str, payload: str) -> int: self.published.append((channel, payload)) + return 1 def pubsub(self) -> _FakePubSub: return self._pubsub +def _make_metadata() -> AvroEventMetadata: + return AvroEventMetadata(service_name="test", service_version="1.0") + + @pytest.mark.asyncio async def test_publish_and_subscribe_round_trip() -> None: r = _FakeRedis() - bus = SSERedisBus(r, logger=_test_logger) + bus = SSERedisBus(cast(redis_async.Redis, r), logger=_test_logger) # Subscribe sub = await bus.open_subscription("exec-1") @@ -73,7 +87,12 @@ async def test_publish_and_subscribe_round_trip() -> None: assert "sse:exec:exec-1" in r._pubsub.subscribed # Publish event - evt = _DummyEvent("exec-1", EventType.EXECUTION_COMPLETED, {"status": "completed"}) + evt = _DummyEvent( + event_type=EventType.EXECUTION_COMPLETED, + metadata=_make_metadata(), + execution_id="exec-1", + status="completed" + ) await bus.publish_event("exec-1", evt) assert r.published, "nothing published" ch, payload = r.published[-1] @@ -96,14 +115,14 @@ async def test_publish_and_subscribe_round_trip() -> None: @pytest.mark.asyncio async def test_notifications_channels() -> None: r = _FakeRedis() - bus = SSERedisBus(r, logger=_test_logger) + bus = SSERedisBus(cast(redis_async.Redis, r), logger=_test_logger) nsub = await bus.open_notification_subscription("user-1") assert "sse:notif:user-1" in r._pubsub.subscribed notif = RedisNotificationMessage( notification_id="n1", - severity="low", - status="pending", + severity=NotificationSeverity.LOW, + status=NotificationStatus.PENDING, tags=[], subject="test", body="body", diff --git a/backend/tests/integration/services/user_settings/test_user_settings_service.py b/backend/tests/integration/services/user_settings/test_user_settings_service.py index dccc3b2b..d3a15d54 100644 --- a/backend/tests/integration/services/user_settings/test_user_settings_service.py +++ b/backend/tests/integration/services/user_settings/test_user_settings_service.py @@ -1,6 +1,7 @@ from datetime import datetime, timezone import pytest +from dishka import AsyncContainer from app.domain.enums import Theme from app.domain.user.settings_models import ( @@ -14,7 +15,7 @@ @pytest.mark.asyncio -async def test_get_update_and_history(scope) -> None: # type: ignore[valid-type] +async def test_get_update_and_history(scope: AsyncContainer) -> None: svc: UserSettingsService = await scope.get(UserSettingsService) user_id = "u1" diff --git a/backend/tests/integration/test_admin_routes.py b/backend/tests/integration/test_admin_routes.py index 03206678..1626589a 100644 --- a/backend/tests/integration/test_admin_routes.py +++ b/backend/tests/integration/test_admin_routes.py @@ -1,4 +1,3 @@ -from typing import Dict from uuid import uuid4 import pytest @@ -27,18 +26,10 @@ async def test_get_settings_requires_auth(self, client: AsyncClient) -> None: assert "not authenticated" in error["detail"].lower() or "unauthorized" in error["detail"].lower() @pytest.mark.asyncio - async def test_get_settings_with_admin_auth(self, client: AsyncClient, test_admin: Dict[str, str]) -> None: + async def test_get_settings_with_admin_auth(self, test_admin: AsyncClient) -> None: """Test getting system settings with admin authentication.""" - # Login and get cookies - login_data = { - "username": test_admin["username"], - "password": test_admin["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - - # Now get settings with auth cookie - response = await client.get("/api/v1/admin/settings/") + # Get settings with auth cookie (logged in via test_admin fixture) + response = await test_admin.get("/api/v1/admin/settings/") assert response.status_code == 200 # Validate response structure @@ -68,18 +59,10 @@ async def test_get_settings_with_admin_auth(self, client: AsyncClient, test_admi assert settings.monitoring_settings.sampling_rate == 0.1 @pytest.mark.asyncio - async def test_update_and_reset_settings(self, client: AsyncClient, test_admin: Dict[str, str]) -> None: + async def test_update_and_reset_settings(self, test_admin: AsyncClient) -> None: """Test updating and resetting system settings.""" - # Login as admin - login_data = { - "username": test_admin["username"], - "password": test_admin["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Get original settings - original_response = await client.get("/api/v1/admin/settings/") + original_response = await test_admin.get("/api/v1/admin/settings/") assert original_response.status_code == 200 original_settings = original_response.json() @@ -105,7 +88,9 @@ async def test_update_and_reset_settings(self, client: AsyncClient, test_admin: } } - update_response = await client.put("/api/v1/admin/settings/", json=updated_settings) + update_response = await test_admin.put( + "/api/v1/admin/settings/", json=updated_settings + ) assert update_response.status_code == 200 # Verify updates were applied @@ -115,7 +100,7 @@ async def test_update_and_reset_settings(self, client: AsyncClient, test_admin: assert returned_settings.monitoring_settings.log_level == "WARNING" # Reset settings - reset_response = await client.post("/api/v1/admin/settings/reset") + reset_response = await test_admin.post("/api/v1/admin/settings/reset") assert reset_response.status_code == 200 # Verify reset to defaults @@ -125,18 +110,10 @@ async def test_update_and_reset_settings(self, client: AsyncClient, test_admin: assert reset_settings.monitoring_settings.log_level == "INFO" @pytest.mark.asyncio - async def test_regular_user_cannot_access_settings(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_regular_user_cannot_access_settings(self, test_user: AsyncClient) -> None: """Test that regular users cannot access admin settings.""" - # Login as regular user - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - - # Try to access admin settings - response = await client.get("/api/v1/admin/settings/") + # Try to access admin settings (logged in as regular user via test_user fixture) + response = await test_user.get("/api/v1/admin/settings/") assert response.status_code == 403 error = response.json() @@ -149,18 +126,10 @@ class TestAdminUsers: """Test admin user management endpoints against real backend.""" @pytest.mark.asyncio - async def test_list_users_with_pagination(self, client: AsyncClient, test_admin: Dict[str, str]) -> None: + async def test_list_users_with_pagination(self, test_admin: AsyncClient) -> None: """Test listing users with pagination.""" - # Login as admin - login_data = { - "username": test_admin["username"], - "password": test_admin["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # List users - response = await client.get("/api/v1/admin/users/?limit=10&offset=0") + response = await test_admin.get("/api/v1/admin/users/?limit=10&offset=0") assert response.status_code == 200 data = response.json() @@ -188,16 +157,8 @@ async def test_list_users_with_pagination(self, client: AsyncClient, test_admin: assert "updated_at" in user @pytest.mark.asyncio - async def test_create_and_manage_user(self, client: AsyncClient, test_admin: Dict[str, str]) -> None: + async def test_create_and_manage_user(self, test_admin: AsyncClient) -> None: """Test full user CRUD operations.""" - # Login as admin - login_data = { - "username": test_admin["username"], - "password": test_admin["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Create a new user unique_id = str(uuid4())[:8] new_user_data = { @@ -206,7 +167,7 @@ async def test_create_and_manage_user(self, client: AsyncClient, test_admin: Dic "password": "SecureP@ssw0rd123" } - create_response = await client.post("/api/v1/admin/users/", json=new_user_data) + create_response = await test_admin.post("/api/v1/admin/users/", json=new_user_data) assert create_response.status_code in [200, 201] created_user = create_response.json() @@ -218,11 +179,11 @@ async def test_create_and_manage_user(self, client: AsyncClient, test_admin: Dic user_id = created_user["user_id"] # Get user details - get_response = await client.get(f"/api/v1/admin/users/{user_id}") + get_response = await test_admin.get(f"/api/v1/admin/users/{user_id}") assert get_response.status_code == 200 # Get user overview - overview_response = await client.get(f"/api/v1/admin/users/{user_id}/overview") + overview_response = await test_admin.get(f"/api/v1/admin/users/{user_id}/overview") assert overview_response.status_code == 200 overview_data = overview_response.json() @@ -236,7 +197,9 @@ async def test_create_and_manage_user(self, client: AsyncClient, test_admin: Dic "email": f"updated_{unique_id}@example.com" } - update_response = await client.put(f"/api/v1/admin/users/{user_id}", json=update_data) + update_response = await test_admin.put( + f"/api/v1/admin/users/{user_id}", json=update_data + ) assert update_response.status_code == 200 updated_user = update_response.json() @@ -244,11 +207,11 @@ async def test_create_and_manage_user(self, client: AsyncClient, test_admin: Dic assert updated_user["email"] == update_data["email"] # Delete user - delete_response = await client.delete(f"/api/v1/admin/users/{user_id}") + delete_response = await test_admin.delete(f"/api/v1/admin/users/{user_id}") assert delete_response.status_code in [200, 204] # Verify deletion - get_deleted_response = await client.get(f"/api/v1/admin/users/{user_id}") + get_deleted_response = await test_admin.get(f"/api/v1/admin/users/{user_id}") assert get_deleted_response.status_code == 404 @@ -257,16 +220,8 @@ class TestAdminEvents: """Test admin event management endpoints against real backend.""" @pytest.mark.asyncio - async def test_browse_events(self, client: AsyncClient, test_admin: Dict[str, str]) -> None: + async def test_browse_events(self, test_admin: AsyncClient) -> None: """Test browsing events with filters.""" - # Login as admin - login_data = { - "username": test_admin["username"], - "password": test_admin["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Browse events browse_payload = { "filters": { @@ -278,7 +233,7 @@ async def test_browse_events(self, client: AsyncClient, test_admin: Dict[str, st "sort_order": -1 } - response = await client.post("/api/v1/admin/events/browse", json=browse_payload) + response = await test_admin.post("/api/v1/admin/events/browse", json=browse_payload) assert response.status_code == 200 data = response.json() @@ -291,18 +246,10 @@ async def test_browse_events(self, client: AsyncClient, test_admin: Dict[str, st assert data["total"] >= 0 @pytest.mark.asyncio - async def test_event_statistics(self, client: AsyncClient, test_admin: Dict[str, str]) -> None: + async def test_event_statistics(self, test_admin: AsyncClient) -> None: """Test getting event statistics.""" - # Login as admin - login_data = { - "username": test_admin["username"], - "password": test_admin["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Get event statistics - response = await client.get("/api/v1/admin/events/stats?hours=24") + response = await test_admin.get("/api/v1/admin/events/stats?hours=24") assert response.status_code == 200 data = response.json() @@ -324,15 +271,10 @@ async def test_event_statistics(self, client: AsyncClient, test_admin: Dict[str, assert data["error_rate"] >= 0.0 @pytest.mark.asyncio - async def test_admin_events_export_csv_and_json(self, client: AsyncClient, test_admin: Dict[str, str]) -> None: + async def test_admin_events_export_csv_and_json(self, test_admin: AsyncClient) -> None: """Export admin events as CSV and JSON and validate basic structure.""" - # Login as admin - login_data = {"username": test_admin["username"], "password": test_admin["password"]} - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # CSV export - r_csv = await client.get("/api/v1/admin/events/export/csv?limit=10") + r_csv = await test_admin.get("/api/v1/admin/events/export/csv?limit=10") assert r_csv.status_code == 200, f"CSV export failed: {r_csv.status_code} - {r_csv.text[:200]}" ct_csv = r_csv.headers.get("content-type", "") assert "text/csv" in ct_csv @@ -341,7 +283,7 @@ async def test_admin_events_export_csv_and_json(self, client: AsyncClient, test_ assert "Event ID" in body_csv and "Timestamp" in body_csv # JSON export - r_json = await client.get("/api/v1/admin/events/export/json?limit=10") + r_json = await test_admin.get("/api/v1/admin/events/export/json?limit=10") assert r_json.status_code == 200, f"JSON export failed: {r_json.status_code} - {r_json.text[:200]}" ct_json = r_json.headers.get("content-type", "") assert "application/json" in ct_json @@ -351,14 +293,8 @@ async def test_admin_events_export_csv_and_json(self, client: AsyncClient, test_ assert "exported_at" in data["export_metadata"] @pytest.mark.asyncio - async def test_admin_user_rate_limits_and_password_reset(self, client: AsyncClient, - test_admin: Dict[str, str]) -> None: + async def test_admin_user_rate_limits_and_password_reset(self, test_admin: AsyncClient) -> None: """Create a user, manage rate limits, and reset password via admin endpoints.""" - # Login as admin - login_data = {"username": test_admin["username"], "password": test_admin["password"]} - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Create a new user to operate on unique_id = str(uuid4())[:8] new_user = { @@ -366,12 +302,12 @@ async def test_admin_user_rate_limits_and_password_reset(self, client: AsyncClie "email": f"rl_{unique_id}@example.com", "password": "TempP@ss1234" } - create_response = await client.post("/api/v1/admin/users/", json=new_user) + create_response = await test_admin.post("/api/v1/admin/users/", json=new_user) assert create_response.status_code in [200, 201] target_user_id = create_response.json()["user_id"] # Get current rate limits (may be None for fresh user) - rl_get = await client.get(f"/api/v1/admin/users/{target_user_id}/rate-limits") + rl_get = await test_admin.get(f"/api/v1/admin/users/{target_user_id}/rate-limits") assert rl_get.status_code == 200 rl_body = rl_get.json() assert rl_body.get("user_id") == target_user_id @@ -395,28 +331,33 @@ async def test_admin_user_rate_limits_and_password_reset(self, client: AsyncClie } ] } - rl_put = await client.put(f"/api/v1/admin/users/{target_user_id}/rate-limits", json=update_payload) + rl_put = await test_admin.put( + f"/api/v1/admin/users/{target_user_id}/rate-limits", + json=update_payload + ) assert rl_put.status_code == 200 put_body = rl_put.json() assert put_body.get("updated") is True assert put_body.get("config", {}).get("user_id") == target_user_id # Reset rate limits - rl_reset = await client.post(f"/api/v1/admin/users/{target_user_id}/rate-limits/reset") + rl_reset = await test_admin.post( + f"/api/v1/admin/users/{target_user_id}/rate-limits/reset" + ) assert rl_reset.status_code == 200 # Reset password for the user new_password = "NewPassw0rd!" - pw_reset = await client.post( + pw_reset = await test_admin.post( f"/api/v1/admin/users/{target_user_id}/reset-password", json={"new_password": new_password} ) assert pw_reset.status_code == 200 # Verify user can login with the new password - logout_resp = await client.post("/api/v1/auth/logout") + logout_resp = await test_admin.post("/api/v1/auth/logout") assert logout_resp.status_code in [200, 204] - login_new = await client.post( + login_new = await test_admin.post( "/api/v1/auth/login", data={"username": new_user["username"], "password": new_password} ) diff --git a/backend/tests/integration/test_alertmanager.py b/backend/tests/integration/test_alertmanager.py index c61304c1..f7b7dcd6 100644 --- a/backend/tests/integration/test_alertmanager.py +++ b/backend/tests/integration/test_alertmanager.py @@ -1,3 +1,4 @@ +import httpx import pytest from datetime import datetime, timezone @@ -6,7 +7,7 @@ @pytest.mark.asyncio -async def test_grafana_alert_endpoints(client): +async def test_grafana_alert_endpoints(client: httpx.AsyncClient) -> None: # Test endpoint r_test = await client.get("/api/v1/alerts/grafana/test") assert r_test.status_code == 200 diff --git a/backend/tests/integration/test_dlq_routes.py b/backend/tests/integration/test_dlq_routes.py index 5cc114a0..50b0f71b 100644 --- a/backend/tests/integration/test_dlq_routes.py +++ b/backend/tests/integration/test_dlq_routes.py @@ -1,19 +1,24 @@ from datetime import datetime -from typing import Dict +from typing import TypedDict import pytest from httpx import AsyncClient +from app.dlq import DLQMessageStatus from app.schemas_pydantic.dlq import ( - DLQStats, - DLQMessagesResponse, - DLQMessageResponse, - DLQMessageDetail, - DLQMessageStatus, DLQBatchRetryResponse, - DLQTopicSummaryResponse + DLQMessageDetail, + DLQMessageResponse, + DLQMessagesResponse, + DLQStats, + DLQTopicSummaryResponse, ) from app.schemas_pydantic.user import MessageResponse +from app.settings import Settings + + +class _RetryRequest(TypedDict): + event_ids: list[str] @pytest.mark.integration @@ -33,18 +38,10 @@ async def test_dlq_requires_authentication(self, client: AsyncClient) -> None: for word in ["not authenticated", "unauthorized", "login"]) @pytest.mark.asyncio - async def test_get_dlq_statistics(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_get_dlq_statistics(self, test_user: AsyncClient) -> None: """Test getting DLQ statistics.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Get DLQ stats - response = await client.get("/api/v1/dlq/stats") + response = await test_user.get("/api/v1/dlq/stats") assert response.status_code == 200 # Validate response structure @@ -86,18 +83,10 @@ async def test_get_dlq_statistics(self, client: AsyncClient, test_user: Dict[str assert stats.age_stats[key] >= 0 @pytest.mark.asyncio - async def test_list_dlq_messages(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_list_dlq_messages(self, test_user: AsyncClient) -> None: """Test listing DLQ messages with filters.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # List all DLQ messages - response = await client.get("/api/v1/dlq/messages?limit=10&offset=0") + response = await test_user.get("/api/v1/dlq/messages?limit=10&offset=0") assert response.status_code == 200 # Validate response structure @@ -125,24 +114,12 @@ async def test_list_dlq_messages(self, client: AsyncClient, test_user: Dict[str, if message.age_seconds is not None: assert message.age_seconds >= 0 - # Check details if present - if message.details: - assert isinstance(message.details, dict) - @pytest.mark.asyncio - async def test_filter_dlq_messages_by_status(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_filter_dlq_messages_by_status(self, test_user: AsyncClient) -> None: """Test filtering DLQ messages by status.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Test different status filters for status in ["pending", "scheduled", "retried", "discarded"]: - response = await client.get(f"/api/v1/dlq/messages?status={status}&limit=5") + response = await test_user.get(f"/api/v1/dlq/messages?status={status}&limit=5") assert response.status_code == 200 messages_data = response.json() @@ -153,19 +130,11 @@ async def test_filter_dlq_messages_by_status(self, client: AsyncClient, test_use assert message.status == status @pytest.mark.asyncio - async def test_filter_dlq_messages_by_topic(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_filter_dlq_messages_by_topic(self, test_user: AsyncClient) -> None: """Test filtering DLQ messages by topic.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Filter by a specific topic test_topic = "execution-events" - response = await client.get(f"/api/v1/dlq/messages?topic={test_topic}&limit=5") + response = await test_user.get(f"/api/v1/dlq/messages?topic={test_topic}&limit=5") assert response.status_code == 200 messages_data = response.json() @@ -176,18 +145,10 @@ async def test_filter_dlq_messages_by_topic(self, client: AsyncClient, test_user assert message.original_topic == test_topic @pytest.mark.asyncio - async def test_get_single_dlq_message_detail(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_get_single_dlq_message_detail(self, test_user: AsyncClient) -> None: """Test getting detailed information for a single DLQ message.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # First get list of messages to find an ID - list_response = await client.get("/api/v1/dlq/messages?limit=1") + list_response = await test_user.get("/api/v1/dlq/messages?limit=1") assert list_response.status_code == 200 messages_data = list_response.json() @@ -195,7 +156,7 @@ async def test_get_single_dlq_message_detail(self, client: AsyncClient, test_use # Get details for the first message event_id = messages_data["messages"][0]["event_id"] - detail_response = await client.get(f"/api/v1/dlq/messages/{event_id}") + detail_response = await test_user.get(f"/api/v1/dlq/messages/{event_id}") assert detail_response.status_code == 200 # Validate detailed response @@ -224,19 +185,11 @@ async def test_get_single_dlq_message_detail(self, client: AsyncClient, test_use assert message_detail.dlq_partition >= 0 @pytest.mark.asyncio - async def test_get_nonexistent_dlq_message(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_get_nonexistent_dlq_message(self, test_user: AsyncClient) -> None: """Test getting a non-existent DLQ message.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Try to get non-existent message fake_event_id = "00000000-0000-0000-0000-000000000000" - response = await client.get(f"/api/v1/dlq/messages/{fake_event_id}") + response = await test_user.get(f"/api/v1/dlq/messages/{fake_event_id}") assert response.status_code == 404 error_data = response.json() @@ -244,19 +197,14 @@ async def test_get_nonexistent_dlq_message(self, client: AsyncClient, test_user: assert "not found" in error_data["detail"].lower() @pytest.mark.asyncio - async def test_set_retry_policy(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_set_retry_policy( + self, test_user: AsyncClient, test_settings: Settings + ) -> None: """Test setting a retry policy for a topic.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Set retry policy + topic = f"{test_settings.KAFKA_TOPIC_PREFIX}test-topic" policy_data = { - "topic": "test-topic", + "topic": topic, "strategy": "exponential_backoff", "max_retries": 5, "base_delay_seconds": 10, @@ -264,28 +212,20 @@ async def test_set_retry_policy(self, client: AsyncClient, test_user: Dict[str, "retry_multiplier": 2.0 } - response = await client.post("/api/v1/dlq/retry-policy", json=policy_data) + response = await test_user.post("/api/v1/dlq/retry-policy", json=policy_data) assert response.status_code == 200 # Validate response result_data = response.json() result = MessageResponse(**result_data) assert "retry policy set" in result.message.lower() - assert policy_data["topic"] in result.message + assert topic in result.message @pytest.mark.asyncio - async def test_retry_dlq_messages_batch(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_retry_dlq_messages_batch(self, test_user: AsyncClient) -> None: """Test retrying a batch of DLQ messages.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Get some failed messages to retry - list_response = await client.get("/api/v1/dlq/messages?status=discarded&limit=3") + list_response = await test_user.get("/api/v1/dlq/messages?status=discarded&limit=3") assert list_response.status_code == 200 messages_data = list_response.json() @@ -298,7 +238,7 @@ async def test_retry_dlq_messages_batch(self, client: AsyncClient, test_user: Di "event_ids": event_ids } - retry_response = await client.post("/api/v1/dlq/retry", json=retry_request) + retry_response = await test_user.post("/api/v1/dlq/retry", json=retry_request) assert retry_response.status_code == 200 # Validate retry response @@ -319,18 +259,10 @@ async def test_retry_dlq_messages_batch(self, client: AsyncClient, test_user: Di assert "success" in detail @pytest.mark.asyncio - async def test_discard_dlq_message(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_discard_dlq_message(self, test_user: AsyncClient) -> None: """Test discarding a DLQ message.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Get a failed message to discard - list_response = await client.get("/api/v1/dlq/messages?status=discarded&limit=1") + list_response = await test_user.get("/api/v1/dlq/messages?status=discarded&limit=1") assert list_response.status_code == 200 messages_data = list_response.json() @@ -339,7 +271,7 @@ async def test_discard_dlq_message(self, client: AsyncClient, test_user: Dict[st # Discard the message discard_reason = "Test discard - message unrecoverable" - discard_response = await client.delete( + discard_response = await test_user.delete( f"/api/v1/dlq/messages/{event_id}?reason={discard_reason}" ) assert discard_response.status_code == 200 @@ -351,25 +283,17 @@ async def test_discard_dlq_message(self, client: AsyncClient, test_user: Dict[st assert event_id in result.message # Verify message is now discarded - detail_response = await client.get(f"/api/v1/dlq/messages/{event_id}") + detail_response = await test_user.get(f"/api/v1/dlq/messages/{event_id}") if detail_response.status_code == 200: detail_data = detail_response.json() # Status should be discarded assert detail_data["status"] == "discarded" @pytest.mark.asyncio - async def test_get_dlq_topics_summary(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_get_dlq_topics_summary(self, test_user: AsyncClient) -> None: """Test getting DLQ topics summary.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Get topics summary - response = await client.get("/api/v1/dlq/topics") + response = await test_user.get("/api/v1/dlq/topics") assert response.status_code == 200 # Validate response @@ -404,18 +328,10 @@ async def test_get_dlq_topics_summary(self, client: AsyncClient, test_user: Dict assert topic_summary.max_retry_count >= 0 @pytest.mark.asyncio - async def test_dlq_message_pagination(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_dlq_message_pagination(self, test_user: AsyncClient) -> None: """Test DLQ message pagination.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Get first page - page1_response = await client.get("/api/v1/dlq/messages?limit=5&offset=0") + page1_response = await test_user.get("/api/v1/dlq/messages?limit=5&offset=0") assert page1_response.status_code == 200 page1_data = page1_response.json() @@ -423,7 +339,7 @@ async def test_dlq_message_pagination(self, client: AsyncClient, test_user: Dict # If there are more than 5 messages, get second page if page1.total > 5: - page2_response = await client.get("/api/v1/dlq/messages?limit=5&offset=5") + page2_response = await test_user.get("/api/v1/dlq/messages?limit=5&offset=5") assert page2_response.status_code == 200 page2_data = page2_response.json() @@ -442,39 +358,31 @@ async def test_dlq_message_pagination(self, client: AsyncClient, test_user: Dict assert len(page1_ids.intersection(page2_ids)) == 0 @pytest.mark.asyncio - async def test_dlq_error_handling(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_dlq_error_handling(self, test_user: AsyncClient) -> None: """Test DLQ error handling for invalid requests.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Test invalid limit - response = await client.get("/api/v1/dlq/messages?limit=10000") # Too high + response = await test_user.get("/api/v1/dlq/messages?limit=10000") # Too high # Should either accept with max limit or reject assert response.status_code in [200, 400, 422] # Test negative offset - response = await client.get("/api/v1/dlq/messages?limit=10&offset=-1") + response = await test_user.get("/api/v1/dlq/messages?limit=10&offset=-1") assert response.status_code in [400, 422] # Test invalid status filter - response = await client.get("/api/v1/dlq/messages?status=invalid_status") + response = await test_user.get("/api/v1/dlq/messages?status=invalid_status") assert response.status_code in [400, 422] # Test retry with empty list - retry_request = { + retry_request: _RetryRequest = { "event_ids": [] } - response = await client.post("/api/v1/dlq/retry", json=retry_request) + response = await test_user.post("/api/v1/dlq/retry", json=retry_request) # Should handle gracefully or reject invalid input assert response.status_code in [200, 400, 404, 422] # Test discard without reason fake_event_id = "00000000-0000-0000-0000-000000000000" - response = await client.delete(f"/api/v1/dlq/messages/{fake_event_id}") + response = await test_user.delete(f"/api/v1/dlq/messages/{fake_event_id}") # Should require reason parameter assert response.status_code in [400, 422, 404] diff --git a/backend/tests/integration/test_events_routes.py b/backend/tests/integration/test_events_routes.py index 342bd8ad..992fbbc4 100644 --- a/backend/tests/integration/test_events_routes.py +++ b/backend/tests/integration/test_events_routes.py @@ -1,18 +1,16 @@ -from datetime import datetime, timezone, timedelta -from typing import Dict +from datetime import datetime, timedelta, timezone from uuid import uuid4 import pytest -from httpx import AsyncClient - from app.domain.enums.events import EventType from app.schemas_pydantic.events import ( EventListResponse, EventResponse, EventStatistics, PublishEventResponse, - ReplayAggregateResponse + ReplayAggregateResponse, ) +from httpx import AsyncClient @pytest.mark.integration @@ -32,12 +30,12 @@ async def test_events_require_authentication(self, client: AsyncClient) -> None: for word in ["not authenticated", "unauthorized", "login"]) @pytest.mark.asyncio - async def test_get_user_events(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_get_user_events(self, test_user: AsyncClient) -> None: """Test getting user's events.""" # Already authenticated via test_user fixture # Get user events - response = await client.get("/api/v1/events/user?limit=10&skip=0") + response = await test_user.get("/api/v1/events/user?limit=10&skip=0") # Some deployments may route this path under a dynamic segment and return 404. # Accept 200 with a valid payload or 404 (no such resource). assert response.status_code in [200, 404] @@ -61,19 +59,18 @@ async def test_get_user_events(self, client: AsyncClient, test_user: Dict[str, s assert event.event_type is not None assert event.aggregate_id is not None assert event.timestamp is not None - assert event.version is not None - assert event.user_id is not None + assert event.event_version is not None + assert event.metadata is not None + assert event.metadata.user_id is not None # Optional fields if event.payload: assert isinstance(event.payload, dict) - if event.metadata: - assert isinstance(event.metadata, dict) if event.correlation_id: assert isinstance(event.correlation_id, str) @pytest.mark.asyncio - async def test_get_user_events_with_filters(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_get_user_events_with_filters(self, test_user: AsyncClient) -> None: """Test filtering user events.""" # Already authenticated via test_user fixture @@ -83,18 +80,18 @@ async def test_get_user_events_with_filters(self, client: AsyncClient, test_user "lang": "python", "lang_version": "3.11" } - exec_response = await client.post("/api/v1/execute", json=execution_request) + exec_response = await test_user.post("/api/v1/execute", json=execution_request) assert exec_response.status_code == 200 # Filter by event types event_types = ["execution.requested", "execution.completed"] - params = { + params: dict[str, str | int | list[str]] = { "event_types": event_types, "limit": 20, "sort_order": "desc" } - response = await client.get("/api/v1/events/user", params=params) + response = await test_user.get("/api/v1/events/user", params=params) assert response.status_code in [200, 404] if response.status_code == 200: events_data = response.json() @@ -107,29 +104,21 @@ async def test_get_user_events_with_filters(self, client: AsyncClient, test_user events_response.events) == 0 @pytest.mark.asyncio - async def test_get_execution_events(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_get_execution_events(self, test_user: AsyncClient) -> None: """Test getting events for a specific execution.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Create an execution execution_request = { "script": "print('Test execution events')", "lang": "python", "lang_version": "3.11" } - exec_response = await client.post("/api/v1/execute", json=execution_request) + exec_response = await test_user.post("/api/v1/execute", json=execution_request) assert exec_response.status_code == 200 execution_id = exec_response.json()["execution_id"] # Get execution events (JSON, not SSE stream) - response = await client.get( + response = await test_user.get( f"/api/v1/events/executions/{execution_id}/events?include_system_events=true" ) assert response.status_code == 200 @@ -147,16 +136,8 @@ async def test_get_execution_events(self, client: AsyncClient, test_user: Dict[s assert execution_id in event.aggregate_id or event.aggregate_id == execution_id @pytest.mark.asyncio - async def test_query_events_advanced(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_query_events_advanced(self, test_user: AsyncClient) -> None: """Test advanced event querying with filters.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Query events with multiple filters query_request = { "event_types": [ @@ -171,7 +152,7 @@ async def test_query_events_advanced(self, client: AsyncClient, test_user: Dict[ "sort_order": "desc" } - response = await client.post("/api/v1/events/query", json=query_request) + response = await test_user.post("/api/v1/events/query", json=query_request) assert response.status_code == 200 events_data = response.json() @@ -191,27 +172,19 @@ async def test_query_events_advanced(self, client: AsyncClient, test_user: Dict[ assert t1 >= t2 # Descending order @pytest.mark.asyncio - async def test_get_events_by_correlation_id(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_get_events_by_correlation_id(self, test_user: AsyncClient) -> None: """Test getting events by correlation ID.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Create an execution (which generates correlated events) execution_request = { "script": "print('Test correlation')", "lang": "python", "lang_version": "3.11" } - exec_response = await client.post("/api/v1/execute", json=execution_request) + exec_response = await test_user.post("/api/v1/execute", json=execution_request) assert exec_response.status_code == 200 # Get events for the user to find a correlation ID - user_events_response = await client.get("/api/v1/events/user?limit=10") + user_events_response = await test_user.get("/api/v1/events/user?limit=10") assert user_events_response.status_code == 200 user_events = user_events_response.json() @@ -219,7 +192,7 @@ async def test_get_events_by_correlation_id(self, client: AsyncClient, test_user correlation_id = user_events["events"][0]["correlation_id"] # Get events by correlation ID - response = await client.get(f"/api/v1/events/correlation/{correlation_id}?limit=50") + response = await test_user.get(f"/api/v1/events/correlation/{correlation_id}?limit=50") assert response.status_code == 200 correlated_events = response.json() @@ -231,18 +204,10 @@ async def test_get_events_by_correlation_id(self, client: AsyncClient, test_user assert event.correlation_id == correlation_id @pytest.mark.asyncio - async def test_get_current_request_events(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_get_current_request_events(self, test_user: AsyncClient) -> None: """Test getting events for the current request.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Get current request events (might be empty if no correlation context) - response = await client.get("/api/v1/events/current-request?limit=10") + response = await test_user.get("/api/v1/events/current-request?limit=10") assert response.status_code == 200 events_data = response.json() @@ -253,18 +218,10 @@ async def test_get_current_request_events(self, client: AsyncClient, test_user: assert events_response.total >= 0 @pytest.mark.asyncio - async def test_get_event_statistics(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_get_event_statistics(self, test_user: AsyncClient) -> None: """Test getting event statistics.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Get statistics for last 24 hours - response = await client.get("/api/v1/events/statistics") + response = await test_user.get("/api/v1/events/statistics") assert response.status_code == 200 stats_data = response.json() @@ -281,26 +238,16 @@ async def test_get_event_statistics(self, client: AsyncClient, test_user: Dict[s # Events by hour should have proper structure for hourly_stat in stats.events_by_hour: - # Some implementations return {'_id': hour, 'count': n} - hour_key = "hour" if "hour" in hourly_stat else "_id" - assert hour_key in hourly_stat - assert "count" in hourly_stat - assert isinstance(hourly_stat["count"], int) - assert hourly_stat["count"] >= 0 + # HourlyEventCountSchema has hour: str and count: int + assert isinstance(hourly_stat.hour, str) + assert isinstance(hourly_stat.count, int) + assert hourly_stat.count >= 0 @pytest.mark.asyncio - async def test_get_single_event(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_get_single_event(self, test_user: AsyncClient) -> None: """Test getting a single event by ID.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Get user events to find an event ID - events_response = await client.get("/api/v1/events/user?limit=1") + events_response = await test_user.get("/api/v1/events/user?limit=1") assert events_response.status_code == 200 events_data = events_response.json() @@ -308,7 +255,7 @@ async def test_get_single_event(self, client: AsyncClient, test_user: Dict[str, event_id = events_data["events"][0]["event_id"] # Get single event - response = await client.get(f"/api/v1/events/{event_id}") + response = await test_user.get(f"/api/v1/events/{event_id}") assert response.status_code == 200 event_data = response.json() @@ -320,19 +267,11 @@ async def test_get_single_event(self, client: AsyncClient, test_user: Dict[str, assert event.timestamp is not None @pytest.mark.asyncio - async def test_get_nonexistent_event(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_get_nonexistent_event(self, test_user: AsyncClient) -> None: """Test getting a non-existent event.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Try to get non-existent event fake_event_id = str(uuid4()) - response = await client.get(f"/api/v1/events/{fake_event_id}") + response = await test_user.get(f"/api/v1/events/{fake_event_id}") assert response.status_code == 404 error_data = response.json() @@ -340,48 +279,24 @@ async def test_get_nonexistent_event(self, client: AsyncClient, test_user: Dict[ assert "not found" in error_data["detail"].lower() @pytest.mark.asyncio - async def test_list_event_types(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_list_event_types(self, test_user: AsyncClient) -> None: """Test listing available event types.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # List event types - response = await client.get("/api/v1/events/types/list") + response = await test_user.get("/api/v1/events/types/list") assert response.status_code == 200 event_types = response.json() assert isinstance(event_types, list) - # Should contain common event types - common_types = [ - "execution.requested", - "execution.completed", - "user.logged_in", - "user.registered" - ] - - # At least some common types should be present + # Event types should be non-empty strings for event_type in event_types: assert isinstance(event_type, str) assert len(event_type) > 0 @pytest.mark.asyncio - async def test_publish_custom_event_requires_admin(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_publish_custom_event_requires_admin(self, test_user: AsyncClient) -> None: """Test that publishing custom events requires admin privileges.""" - # Login as regular user - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - - # Try to publish custom event + # Try to publish custom event (logged in as regular user via fixture) publish_request = { "event_type": EventType.SYSTEM_ERROR.value, "payload": { @@ -392,21 +307,13 @@ async def test_publish_custom_event_requires_admin(self, client: AsyncClient, te "correlation_id": str(uuid4()) } - response = await client.post("/api/v1/events/publish", json=publish_request) + response = await test_user.post("/api/v1/events/publish", json=publish_request) assert response.status_code == 403 # Forbidden for non-admin @pytest.mark.asyncio @pytest.mark.kafka - async def test_publish_custom_event_as_admin(self, client: AsyncClient, test_admin: Dict[str, str]) -> None: + async def test_publish_custom_event_as_admin(self, test_admin: AsyncClient) -> None: """Test publishing custom events as admin.""" - # Login as admin - login_data = { - "username": test_admin["username"], - "password": test_admin["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Publish custom event (requires Kafka); skip if not available aggregate_id = str(uuid4()) publish_request = { @@ -424,7 +331,7 @@ async def test_publish_custom_event_as_admin(self, client: AsyncClient, test_adm } } - response = await client.post("/api/v1/events/publish", json=publish_request) + response = await test_admin.post("/api/v1/events/publish", json=publish_request) if response.status_code != 200: pytest.skip("Kafka not available for publishing events") @@ -434,16 +341,8 @@ async def test_publish_custom_event_as_admin(self, client: AsyncClient, test_adm assert publish_response.timestamp is not None @pytest.mark.asyncio - async def test_aggregate_events(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_aggregate_events(self, test_user: AsyncClient) -> None: """Test event aggregation.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Create aggregation pipeline aggregation_request = { "pipeline": [ @@ -454,7 +353,7 @@ async def test_aggregate_events(self, client: AsyncClient, test_user: Dict[str, "limit": 10 } - response = await client.post("/api/v1/events/aggregate", json=aggregation_request) + response = await test_user.post("/api/v1/events/aggregate", json=aggregation_request) assert response.status_code == 200 results = response.json() @@ -469,51 +368,26 @@ async def test_aggregate_events(self, client: AsyncClient, test_user: Dict[str, assert result["count"] >= 0 @pytest.mark.asyncio - async def test_delete_event_requires_admin(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_delete_event_requires_admin(self, test_user: AsyncClient) -> None: """Test that deleting events requires admin privileges.""" - # Login as regular user - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - - # Try to delete an event + # Try to delete an event (logged in as regular user via fixture) fake_event_id = str(uuid4()) - response = await client.delete(f"/api/v1/events/{fake_event_id}") + response = await test_user.delete(f"/api/v1/events/{fake_event_id}") assert response.status_code == 403 # Forbidden for non-admin @pytest.mark.asyncio - async def test_replay_aggregate_events_requires_admin(self, client: AsyncClient, - test_user: Dict[str, str]) -> None: + async def test_replay_aggregate_events_requires_admin(self, test_user: AsyncClient) -> None: """Test that replaying events requires admin privileges.""" - # Login as regular user - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - - # Try to replay events + # Try to replay events (logged in as regular user via fixture) aggregate_id = str(uuid4()) - response = await client.post(f"/api/v1/events/replay/{aggregate_id}?dry_run=true") + response = await test_user.post(f"/api/v1/events/replay/{aggregate_id}?dry_run=true") assert response.status_code == 403 # Forbidden for non-admin @pytest.mark.asyncio - async def test_replay_aggregate_events_dry_run(self, client: AsyncClient, test_admin: Dict[str, str]) -> None: + async def test_replay_aggregate_events_dry_run(self, test_admin: AsyncClient) -> None: """Test replaying events in dry-run mode.""" - # Login as admin - login_data = { - "username": test_admin["username"], - "password": test_admin["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Get an existing aggregate ID from events - events_response = await client.get("/api/v1/events/user?limit=1") + events_response = await test_admin.get("/api/v1/events/user?limit=1") assert events_response.status_code == 200 events_data = events_response.json() @@ -521,7 +395,9 @@ async def test_replay_aggregate_events_dry_run(self, client: AsyncClient, test_a aggregate_id = events_data["events"][0]["aggregate_id"] # Try dry-run replay - response = await client.post(f"/api/v1/events/replay/{aggregate_id}?dry_run=true") + response = await test_admin.post( + f"/api/v1/events/replay/{aggregate_id}?dry_run=true" + ) if response.status_code == 200: replay_data = response.json() @@ -529,7 +405,7 @@ async def test_replay_aggregate_events_dry_run(self, client: AsyncClient, test_a assert replay_response.dry_run is True assert replay_response.aggregate_id == aggregate_id - assert replay_response.event_count >= 0 + assert replay_response.event_count is not None and replay_response.event_count >= 0 if replay_response.event_types: assert isinstance(replay_response.event_types, list) @@ -543,18 +419,10 @@ async def test_replay_aggregate_events_dry_run(self, client: AsyncClient, test_a assert "detail" in error_data @pytest.mark.asyncio - async def test_event_pagination(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_event_pagination(self, test_user: AsyncClient) -> None: """Test event pagination.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Get first page - page1_response = await client.get("/api/v1/events/user?limit=5&skip=0") + page1_response = await test_user.get("/api/v1/events/user?limit=5&skip=0") assert page1_response.status_code == 200 page1_data = page1_response.json() @@ -562,7 +430,7 @@ async def test_event_pagination(self, client: AsyncClient, test_user: Dict[str, # If there are more than 5 events, get second page if page1.total > 5: - page2_response = await client.get("/api/v1/events/user?limit=5&skip=5") + page2_response = await test_user.get("/api/v1/events/user?limit=5&skip=5") assert page2_response.status_code == 200 page2_data = page2_response.json() @@ -581,46 +449,51 @@ async def test_event_pagination(self, client: AsyncClient, test_user: Dict[str, assert len(page1_ids.intersection(page2_ids)) == 0 @pytest.mark.asyncio - async def test_events_isolation_between_users(self, client: AsyncClient, - test_user: Dict[str, str], - test_admin: Dict[str, str]) -> None: + async def test_events_isolation_between_users(self, test_user: AsyncClient, + test_admin: AsyncClient) -> None: """Test that events are properly isolated between users.""" - # Get events as regular user - user_login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - user_login_response = await client.post("/api/v1/auth/login", data=user_login_data) - assert user_login_response.status_code == 200 + # Get each user's user_id from /me endpoint + user_me_response = await test_user.get("/api/v1/auth/me") + assert user_me_response.status_code == 200 + user_id = user_me_response.json()["user_id"] + + admin_me_response = await test_admin.get("/api/v1/auth/me") + assert admin_me_response.status_code == 200 + admin_id = admin_me_response.json()["user_id"] - user_events_response = await client.get("/api/v1/events/user?limit=10") + # Verify the two users are different + assert user_id != admin_id, "Test requires two different users" + + # Get events as regular user + user_events_response = await test_user.get("/api/v1/events/user?limit=10") assert user_events_response.status_code == 200 user_events = user_events_response.json() - user_event_ids = [e["event_id"] for e in user_events["events"]] + user_event_ids = {e["event_id"] for e in user_events["events"]} # Get events as admin (without include_all_users flag) - admin_login_data = { - "username": test_admin["username"], - "password": test_admin["password"] - } - admin_login_response = await client.post("/api/v1/auth/login", data=admin_login_data) - assert admin_login_response.status_code == 200 - - admin_events_response = await client.get("/api/v1/events/user?limit=10") + admin_events_response = await test_admin.get("/api/v1/events/user?limit=10") assert admin_events_response.status_code == 200 admin_events = admin_events_response.json() - admin_event_ids = [e["event_id"] for e in admin_events["events"]] + admin_event_ids = {e["event_id"] for e in admin_events["events"]} - # Events should be different (unless users share some events) - # But user IDs in events should be different + # Verify user events belong to the user for event in user_events["events"]: meta = event.get("metadata") or {} if meta.get("user_id"): - assert meta["user_id"] == test_user.get("user_id", meta["user_id"]) + assert meta["user_id"] == user_id, ( + f"User event has wrong user_id: expected {user_id}, got {meta['user_id']}" + ) + # Verify admin events belong to the admin for event in admin_events["events"]: meta = event.get("metadata") or {} if meta.get("user_id"): - assert meta["user_id"] == test_admin.get("user_id", meta["user_id"]) + assert meta["user_id"] == admin_id, ( + f"Admin event has wrong user_id: expected {admin_id}, got {meta['user_id']}" + ) + + # Verify no overlap in event IDs between users (proper isolation) + overlap = user_event_ids & admin_event_ids + assert not overlap, f"Events leaked between users: {overlap}" diff --git a/backend/tests/integration/test_health_routes.py b/backend/tests/integration/test_health_routes.py index 40105561..15485b8a 100644 --- a/backend/tests/integration/test_health_routes.py +++ b/backend/tests/integration/test_health_routes.py @@ -1,6 +1,5 @@ import asyncio import time -from typing import Dict import pytest from httpx import AsyncClient @@ -48,26 +47,18 @@ async def test_concurrent_liveness_fetch(self, client: AsyncClient) -> None: assert all(r.status_code == 200 for r in responses) @pytest.mark.asyncio - async def test_app_responds_during_load(self, client: AsyncClient, test_user: Dict[str, str]) -> None: - # Login first for creating load - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - + async def test_app_responds_during_load(self, client: AsyncClient, test_user: AsyncClient) -> None: # Create some load with execution requests - async def create_load(): + async def create_load() -> int | None: execution_request = { "script": "print('Load test')", "lang": "python", "lang_version": "3.11" } try: - response = await client.post("/api/v1/execute", json=execution_request) + response = await test_user.post("/api/v1/execute", json=execution_request) return response.status_code - except: + except Exception: return None # Start load generation diff --git a/backend/tests/integration/test_notifications_routes.py b/backend/tests/integration/test_notifications_routes.py index 5e60164f..9cb9764a 100644 --- a/backend/tests/integration/test_notifications_routes.py +++ b/backend/tests/integration/test_notifications_routes.py @@ -1,17 +1,21 @@ -from typing import Dict - import pytest -from httpx import AsyncClient - +from app.domain.enums.notification import ( + NotificationChannel, + NotificationSeverity, + NotificationStatus, +) from app.schemas_pydantic.notification import ( + DeleteNotificationResponse, NotificationListResponse, - NotificationStatus, - NotificationChannel, NotificationSubscription, SubscriptionsResponse, UnreadCountResponse, - DeleteNotificationResponse ) +from app.services.notification_service import NotificationService +from dishka import AsyncContainer +from httpx import AsyncClient + +from tests.helpers.eventually import eventually @pytest.mark.integration @@ -31,18 +35,10 @@ async def test_notifications_require_authentication(self, client: AsyncClient) - for word in ["not authenticated", "unauthorized", "login"]) @pytest.mark.asyncio - async def test_list_user_notifications(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_list_user_notifications(self, test_user: AsyncClient) -> None: """Test listing user's notifications.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # List notifications - response = await client.get("/api/v1/notifications?limit=10&offset=0") + response = await test_user.get("/api/v1/notifications?limit=10&offset=0") assert response.status_code == 200 # Validate response structure @@ -66,19 +62,16 @@ async def test_list_user_notifications(self, client: AsyncClient, test_user: Dic assert n.created_at is not None @pytest.mark.asyncio - async def test_filter_notifications_by_status(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_filter_notifications_by_status(self, test_user: AsyncClient) -> None: """Test filtering notifications by status.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Test different status filters - for status in [NotificationStatus.READ.value, NotificationStatus.DELIVERED.value, NotificationStatus.SKIPPED.value]: - response = await client.get(f"/api/v1/notifications?status={status}&limit=5") + statuses = [ + NotificationStatus.READ.value, + NotificationStatus.DELIVERED.value, + NotificationStatus.SKIPPED.value, + ] + for status in statuses: + response = await test_user.get(f"/api/v1/notifications?status={status}&limit=5") assert response.status_code == 200 notifications_data = response.json() @@ -89,18 +82,10 @@ async def test_filter_notifications_by_status(self, client: AsyncClient, test_us assert notification.status == status @pytest.mark.asyncio - async def test_get_unread_count(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_get_unread_count(self, test_user: AsyncClient) -> None: """Test getting count of unread notifications.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Get unread count - response = await client.get("/api/v1/notifications/unread-count") + response = await test_user.get("/api/v1/notifications/unread-count") assert response.status_code == 200 # Validate response @@ -113,18 +98,10 @@ async def test_get_unread_count(self, client: AsyncClient, test_user: Dict[str, # Note: listing cannot filter 'unread' directly; count is authoritative @pytest.mark.asyncio - async def test_mark_notification_as_read(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_mark_notification_as_read(self, test_user: AsyncClient) -> None: """Test marking a notification as read.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Get an unread notification - notifications_response = await client.get( + notifications_response = await test_user.get( f"/api/v1/notifications?status={NotificationStatus.DELIVERED.value}&limit=1") assert notifications_response.status_code == 200 @@ -133,11 +110,11 @@ async def test_mark_notification_as_read(self, client: AsyncClient, test_user: D notification_id = notifications_data["notifications"][0]["notification_id"] # Mark as read - mark_response = await client.put(f"/api/v1/notifications/{notification_id}/read") + mark_response = await test_user.put(f"/api/v1/notifications/{notification_id}/read") assert mark_response.status_code == 204 # Verify it's now marked as read - updated_response = await client.get("/api/v1/notifications") + updated_response = await test_user.get("/api/v1/notifications") assert updated_response.status_code == 200 updated_data = updated_response.json() @@ -148,20 +125,11 @@ async def test_mark_notification_as_read(self, client: AsyncClient, test_user: D break @pytest.mark.asyncio - async def test_mark_nonexistent_notification_as_read(self, client: AsyncClient, - test_user: Dict[str, str]) -> None: + async def test_mark_nonexistent_notification_as_read(self, test_user: AsyncClient) -> None: """Test marking a non-existent notification as read.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Try to mark non-existent notification as read fake_notification_id = "00000000-0000-0000-0000-000000000000" - response = await client.put(f"/api/v1/notifications/{fake_notification_id}/read") + response = await test_user.put(f"/api/v1/notifications/{fake_notification_id}/read") # Prefer 404; if backend returns 500, treat as unavailable feature if response.status_code == 500: pytest.skip("Backend returns 500 for unknown notification IDs") @@ -172,44 +140,56 @@ async def test_mark_nonexistent_notification_as_read(self, client: AsyncClient, assert "not found" in error_data["detail"].lower() @pytest.mark.asyncio - async def test_mark_all_notifications_as_read(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_mark_all_notifications_as_read( + self, test_user: AsyncClient, scope: AsyncContainer + ) -> None: """Test marking all notifications as read.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 + # Get user_id and create a test notification to ensure we have something to mark + me_response = await test_user.get("/api/v1/auth/me") + assert me_response.status_code == 200 + user_id = me_response.json()["user_id"] + + notification_service = await scope.get(NotificationService) + await notification_service.create_notification( + user_id=user_id, + subject="Test notification", + body="Created for mark-all-read test", + tags=["test"], + severity=NotificationSeverity.LOW, + channel=NotificationChannel.IN_APP, + ) + + # Wait for async delivery to complete (create_notification uses asyncio.create_task) + async def _has_unread() -> None: + resp = await test_user.get("/api/v1/notifications/unread-count") + assert resp.status_code == 200 + assert resp.json()["unread_count"] >= 1 + + await eventually(_has_unread, timeout=5.0, interval=0.1) + + # Get initial unread count (guaranteed >= 1 now) + initial_response = await test_user.get("/api/v1/notifications/unread-count") + assert initial_response.status_code == 200 + initial_count = initial_response.json()["unread_count"] # Mark all as read - mark_all_response = await client.post("/api/v1/notifications/mark-all-read") + mark_all_response = await test_user.post("/api/v1/notifications/mark-all-read") assert mark_all_response.status_code == 204 - # Verify all are now read - # Verify via unread-count only (list endpoint pagination can hide remaining) - unread_response = await client.get("/api/v1/notifications/unread-count") - assert unread_response.status_code == 200 + # Verify strict decrease - no branching needed + final_response = await test_user.get("/api/v1/notifications/unread-count") + assert final_response.status_code == 200 + final_count = final_response.json()["unread_count"] - # Also verify unread count is 0 - count_response = await client.get("/api/v1/notifications/unread-count") - assert count_response.status_code == 200 - count_data = count_response.json() - assert count_data["unread_count"] == 0 + assert final_count < initial_count, ( + f"mark-all-read must decrease unread count: was {initial_count}, now {final_count}" + ) @pytest.mark.asyncio - async def test_get_notification_subscriptions(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_get_notification_subscriptions(self, test_user: AsyncClient) -> None: """Test getting user's notification subscriptions.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Get subscriptions - response = await client.get("/api/v1/notifications/subscriptions") + response = await test_user.get("/api/v1/notifications/subscriptions") assert response.status_code == 200 # Validate response @@ -239,16 +219,8 @@ async def test_get_notification_subscriptions(self, client: AsyncClient, test_us assert subscription.slack_webhook.startswith("http") @pytest.mark.asyncio - async def test_update_notification_subscription(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_update_notification_subscription(self, test_user: AsyncClient) -> None: """Test updating a notification subscription.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Update in_app subscription update_data = { "enabled": True, @@ -257,7 +229,7 @@ async def test_update_notification_subscription(self, client: AsyncClient, test_ "exclude_tags": ["external_alert"] } - response = await client.put("/api/v1/notifications/subscriptions/in_app", json=update_data) + response = await test_user.put("/api/v1/notifications/subscriptions/in_app", json=update_data) assert response.status_code == 200 # Validate response @@ -271,7 +243,7 @@ async def test_update_notification_subscription(self, client: AsyncClient, test_ assert updated_subscription.exclude_tags == update_data["exclude_tags"] # Verify the update persisted - get_response = await client.get("/api/v1/notifications/subscriptions") + get_response = await test_user.get("/api/v1/notifications/subscriptions") assert get_response.status_code == 200 subs_data = get_response.json() @@ -284,16 +256,8 @@ async def test_update_notification_subscription(self, client: AsyncClient, test_ break @pytest.mark.asyncio - async def test_update_webhook_subscription(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_update_webhook_subscription(self, test_user: AsyncClient) -> None: """Test updating webhook subscription with URL.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Update webhook subscription update_data = { "enabled": True, @@ -303,7 +267,7 @@ async def test_update_webhook_subscription(self, client: AsyncClient, test_user: "exclude_tags": [] } - response = await client.put("/api/v1/notifications/subscriptions/webhook", json=update_data) + response = await test_user.put("/api/v1/notifications/subscriptions/webhook", json=update_data) assert response.status_code == 200 # Validate response @@ -316,16 +280,8 @@ async def test_update_webhook_subscription(self, client: AsyncClient, test_user: assert updated_subscription.severities == update_data["severities"] @pytest.mark.asyncio - async def test_update_slack_subscription(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_update_slack_subscription(self, test_user: AsyncClient) -> None: """Test updating Slack subscription with webhook.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Update Slack subscription update_data = { "enabled": True, @@ -335,7 +291,7 @@ async def test_update_slack_subscription(self, client: AsyncClient, test_user: D "exclude_tags": [] } - response = await client.put("/api/v1/notifications/subscriptions/slack", json=update_data) + response = await test_user.put("/api/v1/notifications/subscriptions/slack", json=update_data) # Slack subscription may be disabled by config; 422 indicates validation assert response.status_code in [200, 422] if response.status_code == 422: @@ -351,18 +307,10 @@ async def test_update_slack_subscription(self, client: AsyncClient, test_user: D assert updated_subscription.severities == update_data["severities"] @pytest.mark.asyncio - async def test_delete_notification(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_delete_notification(self, test_user: AsyncClient) -> None: """Test deleting a notification.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Get a notification to delete - notifications_response = await client.get("/api/v1/notifications?limit=1") + notifications_response = await test_user.get("/api/v1/notifications?limit=1") assert notifications_response.status_code == 200 notifications_data = notifications_response.json() @@ -370,7 +318,7 @@ async def test_delete_notification(self, client: AsyncClient, test_user: Dict[st notification_id = notifications_data["notifications"][0]["notification_id"] # Delete the notification - delete_response = await client.delete(f"/api/v1/notifications/{notification_id}") + delete_response = await test_user.delete(f"/api/v1/notifications/{notification_id}") assert delete_response.status_code == 200 # Validate response @@ -379,7 +327,7 @@ async def test_delete_notification(self, client: AsyncClient, test_user: Dict[st assert "deleted" in delete_result.message.lower() # Verify it's deleted - list_response = await client.get("/api/v1/notifications") + list_response = await test_user.get("/api/v1/notifications") assert list_response.status_code == 200 list_data = list_response.json() @@ -388,19 +336,11 @@ async def test_delete_notification(self, client: AsyncClient, test_user: Dict[st assert notification_id not in notification_ids @pytest.mark.asyncio - async def test_delete_nonexistent_notification(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_delete_nonexistent_notification(self, test_user: AsyncClient) -> None: """Test deleting a non-existent notification.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Try to delete non-existent notification fake_notification_id = "00000000-0000-0000-0000-000000000000" - response = await client.delete(f"/api/v1/notifications/{fake_notification_id}") + response = await test_user.delete(f"/api/v1/notifications/{fake_notification_id}") assert response.status_code == 404 error_data = response.json() @@ -408,18 +348,10 @@ async def test_delete_nonexistent_notification(self, client: AsyncClient, test_u assert "not found" in error_data["detail"].lower() @pytest.mark.asyncio - async def test_notification_pagination(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_notification_pagination(self, test_user: AsyncClient) -> None: """Test notification pagination.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Get first page - page1_response = await client.get("/api/v1/notifications?limit=5&offset=0") + page1_response = await test_user.get("/api/v1/notifications?limit=5&offset=0") assert page1_response.status_code == 200 page1_data = page1_response.json() @@ -427,7 +359,7 @@ async def test_notification_pagination(self, client: AsyncClient, test_user: Dic # If there are more than 5 notifications, get second page if page1.total > 5: - page2_response = await client.get("/api/v1/notifications?limit=5&offset=5") + page2_response = await test_user.get("/api/v1/notifications?limit=5&offset=5") assert page2_response.status_code == 200 page2_data = page2_response.json() @@ -444,35 +376,18 @@ async def test_notification_pagination(self, client: AsyncClient, test_user: Dic assert len(page1_ids.intersection(page2_ids)) == 0 @pytest.mark.asyncio - async def test_notifications_isolation_between_users(self, client: AsyncClient, - test_user: Dict[str, str], - test_admin: Dict[str, str]) -> None: + async def test_notifications_isolation_between_users(self, test_user: AsyncClient, + test_admin: AsyncClient) -> None: """Test that notifications are isolated between users.""" - # Login as regular user - user_login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - user_login_response = await client.post("/api/v1/auth/login", data=user_login_data) - assert user_login_response.status_code == 200 - # Get user's notifications - user_notifications_response = await client.get("/api/v1/notifications") + user_notifications_response = await test_user.get("/api/v1/notifications") assert user_notifications_response.status_code == 200 user_notifications_data = user_notifications_response.json() user_notification_ids = [n["notification_id"] for n in user_notifications_data["notifications"]] - # Login as admin - admin_login_data = { - "username": test_admin["username"], - "password": test_admin["password"] - } - admin_login_response = await client.post("/api/v1/auth/login", data=admin_login_data) - assert admin_login_response.status_code == 200 - # Get admin's notifications - admin_notifications_response = await client.get("/api/v1/notifications") + admin_notifications_response = await test_admin.get("/api/v1/notifications") assert admin_notifications_response.status_code == 200 admin_notifications_data = admin_notifications_response.json() @@ -483,21 +398,13 @@ async def test_notifications_isolation_between_users(self, client: AsyncClient, assert len(set(user_notification_ids).intersection(set(admin_notification_ids))) == 0 @pytest.mark.asyncio - async def test_invalid_notification_channel(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_invalid_notification_channel(self, test_user: AsyncClient) -> None: """Test updating subscription with invalid channel.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Try invalid channel update_data = { "enabled": True, "severities": ["medium"] } - response = await client.put("/api/v1/notifications/subscriptions/invalid_channel", json=update_data) + response = await test_user.put("/api/v1/notifications/subscriptions/invalid_channel", json=update_data) assert response.status_code in [400, 404, 422] diff --git a/backend/tests/integration/test_replay_routes.py b/backend/tests/integration/test_replay_routes.py index 1cdf73ec..4cd74755 100644 --- a/backend/tests/integration/test_replay_routes.py +++ b/backend/tests/integration/test_replay_routes.py @@ -1,20 +1,14 @@ -import asyncio -from datetime import datetime, timezone, timedelta -from typing import Dict +from datetime import datetime, timedelta, timezone from uuid import uuid4 import pytest -from httpx import AsyncClient - from app.domain.enums.events import EventType -from app.domain.enums.replay import ReplayStatus, ReplayType, ReplayTarget -from app.schemas_pydantic.replay import ( - ReplayRequest, - ReplayResponse, - SessionSummary, - CleanupResponse -) +from app.domain.enums.replay import ReplayStatus, ReplayTarget, ReplayType +from app.domain.replay.models import ReplayFilter +from app.schemas_pydantic.replay import CleanupResponse, ReplayRequest, ReplayResponse, SessionSummary from app.schemas_pydantic.replay_models import ReplaySession +from httpx import AsyncClient + from tests.helpers.eventually import eventually @@ -23,12 +17,12 @@ class TestReplayRoutes: """Test replay endpoints against real backend.""" @pytest.mark.asyncio - async def test_replay_requires_admin_authentication(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_replay_requires_admin_authentication(self, test_user: AsyncClient) -> None: """Test that replay endpoints require admin authentication.""" - # Already authenticated via test_user fixture + # test_user is authenticated but not admin # Try to access replay endpoints as non-admin - response = await client.get("/api/v1/replay/sessions") + response = await test_user.get("/api/v1/replay/sessions") assert response.status_code == 403 error_data = response.json() @@ -37,22 +31,22 @@ async def test_replay_requires_admin_authentication(self, client: AsyncClient, t for word in ["admin", "forbidden", "denied"]) @pytest.mark.asyncio - async def test_create_replay_session(self, client: AsyncClient, test_admin: Dict[str, str]) -> None: + async def test_create_replay_session(self, test_admin: AsyncClient) -> None: """Test creating a replay session.""" - # Already authenticated via test_admin fixture - # Create replay session replay_request = ReplayRequest( replay_type=ReplayType.QUERY, target=ReplayTarget.KAFKA, - event_types=[EventType.EXECUTION_REQUESTED, EventType.EXECUTION_COMPLETED], - start_time=datetime.now(timezone.utc) - timedelta(days=7), - end_time=datetime.now(timezone.utc), + filter=ReplayFilter( + event_types=[EventType.EXECUTION_REQUESTED, EventType.EXECUTION_COMPLETED], + start_time=datetime.now(timezone.utc) - timedelta(days=7), + end_time=datetime.now(timezone.utc), + ), speed_multiplier=1.0, preserve_timestamps=True, ).model_dump(mode="json") - response = await client.post("/api/v1/replay/sessions", json=replay_request) + response = await test_admin.post("/api/v1/replay/sessions", json=replay_request) assert response.status_code in [200, 422] if response.status_code == 422: return @@ -67,12 +61,10 @@ async def test_create_replay_session(self, client: AsyncClient, test_admin: Dict assert replay_response.message is not None @pytest.mark.asyncio - async def test_list_replay_sessions(self, client: AsyncClient, test_admin: Dict[str, str]) -> None: + async def test_list_replay_sessions(self, test_admin: AsyncClient) -> None: """Test listing replay sessions.""" - # Already authenticated via test_admin fixture - # List replay sessions - response = await client.get("/api/v1/replay/sessions?limit=10") + response = await test_admin.get("/api/v1/replay/sessions?limit=10") assert response.status_code in [200, 404] if response.status_code != 200: return @@ -88,27 +80,27 @@ async def test_list_replay_sessions(self, client: AsyncClient, test_admin: Dict[ assert session_summary.created_at is not None @pytest.mark.asyncio - async def test_get_replay_session_details(self, client: AsyncClient, test_admin: Dict[str, str]) -> None: + async def test_get_replay_session_details(self, test_admin: AsyncClient) -> None: """Test getting detailed information about a replay session.""" - # Already authenticated via test_admin fixture - # Create a session first replay_request = ReplayRequest( replay_type=ReplayType.QUERY, target=ReplayTarget.KAFKA, - event_types=[EventType.USER_LOGGED_IN], - start_time=datetime.now(timezone.utc) - timedelta(hours=24), - end_time=datetime.now(timezone.utc), + filter=ReplayFilter( + event_types=[EventType.USER_LOGGED_IN], + start_time=datetime.now(timezone.utc) - timedelta(hours=24), + end_time=datetime.now(timezone.utc), + ), speed_multiplier=2.0, ).model_dump(mode="json") - create_response = await client.post("/api/v1/replay/sessions", json=replay_request) + create_response = await test_admin.post("/api/v1/replay/sessions", json=replay_request) assert create_response.status_code == 200 session_id = create_response.json()["session_id"] # Get session details - detail_response = await client.get(f"/api/v1/replay/sessions/{session_id}") + detail_response = await test_admin.get(f"/api/v1/replay/sessions/{session_id}") assert detail_response.status_code in [200, 404] if detail_response.status_code != 200: return @@ -121,33 +113,27 @@ async def test_get_replay_session_details(self, client: AsyncClient, test_admin: assert session.created_at is not None @pytest.mark.asyncio - async def test_start_replay_session(self, client: AsyncClient, test_admin: Dict[str, str]) -> None: + async def test_start_replay_session(self, test_admin: AsyncClient) -> None: """Test starting a replay session.""" - # Login as admin - login_data = { - "username": test_admin["username"], - "password": test_admin["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Create a session replay_request = ReplayRequest( replay_type=ReplayType.QUERY, target=ReplayTarget.KAFKA, - event_types=[EventType.SYSTEM_ERROR], - start_time=datetime.now(timezone.utc) - timedelta(hours=1), - end_time=datetime.now(timezone.utc), + filter=ReplayFilter( + event_types=[EventType.SYSTEM_ERROR], + start_time=datetime.now(timezone.utc) - timedelta(hours=1), + end_time=datetime.now(timezone.utc), + ), speed_multiplier=1.0, ).model_dump(mode="json") - create_response = await client.post("/api/v1/replay/sessions", json=replay_request) + create_response = await test_admin.post("/api/v1/replay/sessions", json=replay_request) assert create_response.status_code == 200 session_id = create_response.json()["session_id"] # Start the session - start_response = await client.post(f"/api/v1/replay/sessions/{session_id}/start") + start_response = await test_admin.post(f"/api/v1/replay/sessions/{session_id}/start") assert start_response.status_code in [200, 404] if start_response.status_code != 200: return @@ -160,39 +146,33 @@ async def test_start_replay_session(self, client: AsyncClient, test_admin: Dict[ assert start_result.message is not None @pytest.mark.asyncio - async def test_pause_and_resume_replay_session(self, client: AsyncClient, test_admin: Dict[str, str]) -> None: + async def test_pause_and_resume_replay_session(self, test_admin: AsyncClient) -> None: """Test pausing and resuming a replay session.""" - # Login as admin - login_data = { - "username": test_admin["username"], - "password": test_admin["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Create and start a session replay_request = ReplayRequest( replay_type=ReplayType.QUERY, target=ReplayTarget.KAFKA, - event_types=[EventType.SYSTEM_ERROR], - start_time=datetime.now(timezone.utc) - timedelta(hours=2), - end_time=datetime.now(timezone.utc), + filter=ReplayFilter( + event_types=[EventType.SYSTEM_ERROR], + start_time=datetime.now(timezone.utc) - timedelta(hours=2), + end_time=datetime.now(timezone.utc), + ), speed_multiplier=0.5, ).model_dump(mode="json") - create_response = await client.post("/api/v1/replay/sessions", json=replay_request) + create_response = await test_admin.post("/api/v1/replay/sessions", json=replay_request) assert create_response.status_code == 200 session_id = create_response.json()["session_id"] # Start the session - start_response = await client.post(f"/api/v1/replay/sessions/{session_id}/start") + start_response = await test_admin.post(f"/api/v1/replay/sessions/{session_id}/start") assert start_response.status_code in [200, 404] if start_response.status_code != 200: return # Pause the session - pause_response = await client.post(f"/api/v1/replay/sessions/{session_id}/pause") + pause_response = await test_admin.post(f"/api/v1/replay/sessions/{session_id}/pause") # Could succeed or fail if session already completed or not found assert pause_response.status_code in [200, 400, 404] @@ -205,7 +185,9 @@ async def test_pause_and_resume_replay_session(self, client: AsyncClient, test_a # If paused, try to resume if pause_result.status == "paused": - resume_response = await client.post(f"/api/v1/replay/sessions/{session_id}/resume") + resume_response = await test_admin.post( + f"/api/v1/replay/sessions/{session_id}/resume" + ) assert resume_response.status_code == 200 resume_data = resume_response.json() @@ -215,33 +197,27 @@ async def test_pause_and_resume_replay_session(self, client: AsyncClient, test_a assert resume_result.status in [ReplayStatus.RUNNING, ReplayStatus.COMPLETED] @pytest.mark.asyncio - async def test_cancel_replay_session(self, client: AsyncClient, test_admin: Dict[str, str]) -> None: + async def test_cancel_replay_session(self, test_admin: AsyncClient) -> None: """Test cancelling a replay session.""" - # Login as admin - login_data = { - "username": test_admin["username"], - "password": test_admin["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Create a session replay_request = ReplayRequest( replay_type=ReplayType.QUERY, target=ReplayTarget.KAFKA, - event_types=[EventType.SYSTEM_ERROR], - start_time=datetime.now(timezone.utc) - timedelta(hours=1), - end_time=datetime.now(timezone.utc), + filter=ReplayFilter( + event_types=[EventType.SYSTEM_ERROR], + start_time=datetime.now(timezone.utc) - timedelta(hours=1), + end_time=datetime.now(timezone.utc), + ), speed_multiplier=1.0, ).model_dump(mode="json") - create_response = await client.post("/api/v1/replay/sessions", json=replay_request) + create_response = await test_admin.post("/api/v1/replay/sessions", json=replay_request) assert create_response.status_code == 200 session_id = create_response.json()["session_id"] # Cancel the session - cancel_response = await client.post(f"/api/v1/replay/sessions/{session_id}/cancel") + cancel_response = await test_admin.post(f"/api/v1/replay/sessions/{session_id}/cancel") assert cancel_response.status_code in [200, 404] if cancel_response.status_code != 200: return @@ -254,16 +230,8 @@ async def test_cancel_replay_session(self, client: AsyncClient, test_admin: Dict assert cancel_result.message is not None @pytest.mark.asyncio - async def test_filter_sessions_by_status(self, client: AsyncClient, test_admin: Dict[str, str]) -> None: + async def test_filter_sessions_by_status(self, test_admin: AsyncClient) -> None: """Test filtering replay sessions by status.""" - # Login as admin - login_data = { - "username": test_admin["username"], - "password": test_admin["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Test different status filters for status in [ ReplayStatus.CREATED.value, @@ -272,7 +240,7 @@ async def test_filter_sessions_by_status(self, client: AsyncClient, test_admin: ReplayStatus.FAILED.value, ReplayStatus.CANCELLED.value, ]: - response = await client.get(f"/api/v1/replay/sessions?status={status}&limit=5") + response = await test_admin.get(f"/api/v1/replay/sessions?status={status}&limit=5") assert response.status_code in [200, 404] if response.status_code != 200: continue @@ -286,18 +254,10 @@ async def test_filter_sessions_by_status(self, client: AsyncClient, test_admin: assert session.status == status @pytest.mark.asyncio - async def test_cleanup_old_sessions(self, client: AsyncClient, test_admin: Dict[str, str]) -> None: + async def test_cleanup_old_sessions(self, test_admin: AsyncClient) -> None: """Test cleanup of old replay sessions.""" - # Login as admin - login_data = { - "username": test_admin["username"], - "password": test_admin["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Cleanup sessions older than 24 hours - cleanup_response = await client.post("/api/v1/replay/cleanup?older_than_hours=24") + cleanup_response = await test_admin.post("/api/v1/replay/cleanup?older_than_hours=24") assert cleanup_response.status_code == 200 cleanup_data = cleanup_response.json() @@ -308,19 +268,11 @@ async def test_cleanup_old_sessions(self, client: AsyncClient, test_admin: Dict[ assert cleanup_result.message is not None @pytest.mark.asyncio - async def test_get_nonexistent_session(self, client: AsyncClient, test_admin: Dict[str, str]) -> None: + async def test_get_nonexistent_session(self, test_admin: AsyncClient) -> None: """Test getting a non-existent replay session.""" - # Login as admin - login_data = { - "username": test_admin["username"], - "password": test_admin["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Try to get non-existent session fake_session_id = str(uuid4()) - response = await client.get(f"/api/v1/replay/sessions/{fake_session_id}") + response = await test_admin.get(f"/api/v1/replay/sessions/{fake_session_id}") # Could return 404 or empty result assert response.status_code in [200, 404] @@ -329,33 +281,17 @@ async def test_get_nonexistent_session(self, client: AsyncClient, test_admin: Di assert "detail" in error_data @pytest.mark.asyncio - async def test_start_nonexistent_session(self, client: AsyncClient, test_admin: Dict[str, str]) -> None: + async def test_start_nonexistent_session(self, test_admin: AsyncClient) -> None: """Test starting a non-existent replay session.""" - # Login as admin - login_data = { - "username": test_admin["username"], - "password": test_admin["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Try to start non-existent session fake_session_id = str(uuid4()) - response = await client.post(f"/api/v1/replay/sessions/{fake_session_id}/start") + response = await test_admin.post(f"/api/v1/replay/sessions/{fake_session_id}/start") # Should fail assert response.status_code in [400, 404] @pytest.mark.asyncio - async def test_replay_session_state_transitions(self, client: AsyncClient, test_admin: Dict[str, str]) -> None: + async def test_replay_session_state_transitions(self, test_admin: AsyncClient) -> None: """Test valid state transitions for replay sessions.""" - # Login as admin - login_data = { - "username": test_admin["username"], - "password": test_admin["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Create a session replay_request = { "name": f"State Test Session {uuid4().hex[:8]}", @@ -369,7 +305,7 @@ async def test_replay_session_state_transitions(self, client: AsyncClient, test_ "speed_multiplier": 1.0 } - create_response = await client.post("/api/v1/replay/sessions", json=replay_request) + create_response = await test_admin.post("/api/v1/replay/sessions", json=replay_request) assert create_response.status_code in [200, 422] if create_response.status_code != 200: return @@ -379,28 +315,20 @@ async def test_replay_session_state_transitions(self, client: AsyncClient, test_ assert initial_status == ReplayStatus.CREATED # Can't pause a session that hasn't started - pause_response = await client.post(f"/api/v1/replay/sessions/{session_id}/pause") + pause_response = await test_admin.post(f"/api/v1/replay/sessions/{session_id}/pause") assert pause_response.status_code in [400, 409] # Invalid state transition # Can start from pending - start_response = await client.post(f"/api/v1/replay/sessions/{session_id}/start") + start_response = await test_admin.post(f"/api/v1/replay/sessions/{session_id}/start") assert start_response.status_code == 200 # Can't start again if already running - start_again_response = await client.post(f"/api/v1/replay/sessions/{session_id}/start") + start_again_response = await test_admin.post(f"/api/v1/replay/sessions/{session_id}/start") assert start_again_response.status_code in [200, 400, 409] # Might be idempotent or error @pytest.mark.asyncio - async def test_replay_with_complex_filters(self, client: AsyncClient, test_admin: Dict[str, str]) -> None: + async def test_replay_with_complex_filters(self, test_admin: AsyncClient) -> None: """Test creating replay session with complex filters.""" - # Login as admin - login_data = { - "username": test_admin["username"], - "password": test_admin["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Create session with complex filters replay_request = { "name": f"Complex Filter Session {uuid4().hex[:8]}", @@ -416,7 +344,6 @@ async def test_replay_with_complex_filters(self, client: AsyncClient, test_admin "end_time": datetime.now(timezone.utc).isoformat(), "aggregate_id": str(uuid4()), "correlation_id": str(uuid4()), - "user_id": test_admin.get("user_id"), "service_name": "execution-service" }, "target_topic": "complex-filter-topic", @@ -425,7 +352,7 @@ async def test_replay_with_complex_filters(self, client: AsyncClient, test_admin "batch_size": 100 } - response = await client.post("/api/v1/replay/sessions", json=replay_request) + response = await test_admin.post("/api/v1/replay/sessions", json=replay_request) assert response.status_code in [200, 422] if response.status_code != 200: return @@ -437,16 +364,8 @@ async def test_replay_with_complex_filters(self, client: AsyncClient, test_admin assert replay_response.status in ["created", "pending"] @pytest.mark.asyncio - async def test_replay_session_progress_tracking(self, client: AsyncClient, test_admin: Dict[str, str]) -> None: + async def test_replay_session_progress_tracking(self, test_admin: AsyncClient) -> None: """Test tracking progress of replay sessions.""" - # Login as admin - login_data = { - "username": test_admin["username"], - "password": test_admin["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Create and start a session replay_request = { "name": f"Progress Test Session {uuid4().hex[:8]}", @@ -460,7 +379,7 @@ async def test_replay_session_progress_tracking(self, client: AsyncClient, test_ "speed_multiplier": 10.0 # Fast replay } - create_response = await client.post("/api/v1/replay/sessions", json=replay_request) + create_response = await test_admin.post("/api/v1/replay/sessions", json=replay_request) assert create_response.status_code in [200, 422] if create_response.status_code != 200: return @@ -468,18 +387,18 @@ async def test_replay_session_progress_tracking(self, client: AsyncClient, test_ session_id = create_response.json()["session_id"] # Start the session - await client.post(f"/api/v1/replay/sessions/{session_id}/start") + await test_admin.post(f"/api/v1/replay/sessions/{session_id}/start") # Poll progress without fixed sleeps async def _check_progress_once() -> None: - detail_response = await client.get(f"/api/v1/replay/sessions/{session_id}") + detail_response = await test_admin.get(f"/api/v1/replay/sessions/{session_id}") assert detail_response.status_code == 200 session_data = detail_response.json() session = ReplaySession(**session_data) - if session.events_replayed is not None and session.events_total is not None: - assert 0 <= session.events_replayed <= session.events_total - if session.events_total > 0: - progress = (session.events_replayed / session.events_total) * 100 + if session.replayed_events is not None and session.total_events is not None: + assert 0 <= session.replayed_events <= session.total_events + if session.total_events > 0: + progress = (session.replayed_events / session.total_events) * 100 assert 0.0 <= progress <= 100.0 await eventually(_check_progress_once, timeout=5.0, interval=0.5) diff --git a/backend/tests/integration/test_saga_routes.py b/backend/tests/integration/test_saga_routes.py index b26d7b90..b084dc90 100644 --- a/backend/tests/integration/test_saga_routes.py +++ b/backend/tests/integration/test_saga_routes.py @@ -1,6 +1,5 @@ import asyncio import uuid -from typing import Dict import pytest from app.domain.enums.saga import SagaState @@ -23,15 +22,11 @@ async def test_get_saga_requires_auth(self, client: AsyncClient) -> None: assert "Not authenticated" in response.json()["detail"] @pytest.mark.asyncio - async def test_get_saga_not_found( - self, client: AsyncClient, test_user: Dict[str, str] - ) -> None: + async def test_get_saga_not_found(self, test_user: AsyncClient) -> None: """Test getting non-existent saga returns 404.""" - # Already authenticated via test_user fixture - # Try to get non-existent saga saga_id = str(uuid.uuid4()) - response = await client.get(f"/api/v1/sagas/{saga_id}") + response = await test_user.get(f"/api/v1/sagas/{saga_id}") assert response.status_code == 404 assert "not found" in response.json()["detail"] @@ -45,28 +40,20 @@ async def test_get_execution_sagas_requires_auth( assert response.status_code == 401 @pytest.mark.asyncio - async def test_get_execution_sagas_empty( - self, client: AsyncClient, test_user: Dict[str, str] - ) -> None: + async def test_get_execution_sagas_empty(self, test_user: AsyncClient) -> None: """Test getting sagas for execution with no sagas.""" - # Already authenticated via test_user fixture - # Get sagas for non-existent execution execution_id = str(uuid.uuid4()) - response = await client.get(f"/api/v1/sagas/execution/{execution_id}") + response = await test_user.get(f"/api/v1/sagas/execution/{execution_id}") # Access to a random execution (non-owned) must be forbidden assert response.status_code == 403 @pytest.mark.asyncio - async def test_get_execution_sagas_with_state_filter( - self, client: AsyncClient, test_user: Dict[str, str] - ) -> None: + async def test_get_execution_sagas_with_state_filter(self, test_user: AsyncClient) -> None: """Test getting execution sagas filtered by state.""" - # Already authenticated via test_user fixture - # Get sagas filtered by running state execution_id = str(uuid.uuid4()) - response = await client.get( + response = await test_user.get( f"/api/v1/sagas/execution/{execution_id}", params={"state": SagaState.RUNNING.value} ) @@ -84,14 +71,10 @@ async def test_list_sagas_requires_auth(self, client: AsyncClient) -> None: assert response.status_code == 401 @pytest.mark.asyncio - async def test_list_sagas_paginated( - self, client: AsyncClient, test_user: Dict[str, str] - ) -> None: + async def test_list_sagas_paginated(self, test_user: AsyncClient) -> None: """Test listing sagas with pagination.""" - # Already authenticated via test_user fixture - # List sagas with pagination - response = await client.get( + response = await test_user.get( "/api/v1/sagas/", params={"limit": 10, "offset": 0} ) @@ -103,20 +86,10 @@ async def test_list_sagas_paginated( assert saga_list.total >= 0 @pytest.mark.asyncio - async def test_list_sagas_with_state_filter( - self, client: AsyncClient, test_user: Dict[str, str] - ) -> None: + async def test_list_sagas_with_state_filter(self, test_user: AsyncClient) -> None: """Test listing sagas filtered by state.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # List completed sagas - response = await client.get( + response = await test_user.get( "/api/v1/sagas/", params={"state": SagaState.COMPLETED.value, "limit": 5} ) @@ -129,20 +102,10 @@ async def test_list_sagas_with_state_filter( assert saga.state == SagaState.COMPLETED @pytest.mark.asyncio - async def test_list_sagas_large_limit( - self, client: AsyncClient, test_user: Dict[str, str] - ) -> None: + async def test_list_sagas_large_limit(self, test_user: AsyncClient) -> None: """Test listing sagas with maximum limit.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # List with max limit - response = await client.get( + response = await test_user.get( "/api/v1/sagas/", params={"limit": 1000} ) @@ -152,20 +115,10 @@ async def test_list_sagas_large_limit( assert len(saga_list.sagas) <= 1000 @pytest.mark.asyncio - async def test_list_sagas_invalid_limit( - self, client: AsyncClient, test_user: Dict[str, str] - ) -> None: + async def test_list_sagas_invalid_limit(self, test_user: AsyncClient) -> None: """Test listing sagas with invalid limit.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Try with limit too large - response = await client.get( + response = await test_user.get( "/api/v1/sagas/", params={"limit": 10000} ) @@ -179,56 +132,28 @@ async def test_cancel_saga_requires_auth(self, client: AsyncClient) -> None: assert response.status_code == 401 @pytest.mark.asyncio - async def test_cancel_saga_not_found( - self, client: AsyncClient, test_user: Dict[str, str] - ) -> None: + async def test_cancel_saga_not_found(self, test_user: AsyncClient) -> None: """Test cancelling non-existent saga returns 404.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Try to cancel non-existent saga saga_id = str(uuid.uuid4()) - response = await client.post(f"/api/v1/sagas/{saga_id}/cancel") + response = await test_user.post(f"/api/v1/sagas/{saga_id}/cancel") assert response.status_code == 404 assert "not found" in response.json()["detail"] @pytest.mark.asyncio async def test_saga_access_control( self, - client: AsyncClient, - test_user: Dict[str, str], - another_user: Dict[str, str] + test_user: AsyncClient, + another_user: AsyncClient ) -> None: """Test that users can only access their own sagas.""" # User 1 lists their sagas - login_data1 = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response1 = await client.post("/api/v1/auth/login", data=login_data1) - assert login_response1.status_code == 200 - - response1 = await client.get("/api/v1/sagas/") + response1 = await test_user.get("/api/v1/sagas/") assert response1.status_code == 200 user1_sagas = SagaListResponse(**response1.json()) - # Logout - await client.post("/api/v1/auth/logout") - # User 2 lists their sagas - login_data2 = { - "username": another_user["username"], - "password": another_user["password"] - } - login_response2 = await client.post("/api/v1/auth/login", data=login_data2) - assert login_response2.status_code == 200 - - response2 = await client.get("/api/v1/sagas/") + response2 = await another_user.get("/api/v1/sagas/") assert response2.status_code == 200 user2_sagas = SagaListResponse(**response2.json()) @@ -239,27 +164,17 @@ async def test_saga_access_control( assert isinstance(user2_sagas.sagas, list) @pytest.mark.asyncio - async def test_get_saga_with_details( - self, client: AsyncClient, test_user: Dict[str, str] - ) -> None: + async def test_get_saga_with_details(self, test_user: AsyncClient) -> None: """Test getting saga with all details when it exists.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # First list sagas to potentially find one - list_response = await client.get("/api/v1/sagas/", params={"limit": 1}) + list_response = await test_user.get("/api/v1/sagas/", params={"limit": 1}) assert list_response.status_code == 200 saga_list = SagaListResponse(**list_response.json()) if saga_list.sagas and len(saga_list.sagas) > 0: # Get details of the first saga saga_id = saga_list.sagas[0].saga_id - response = await client.get(f"/api/v1/sagas/{saga_id}") + response = await test_user.get(f"/api/v1/sagas/{saga_id}") # Could be 200 if accessible or 403 if not owned by user assert response.status_code in [200, 403, 404] @@ -270,20 +185,10 @@ async def test_get_saga_with_details( assert saga_status.state in [s.value for s in SagaState] @pytest.mark.asyncio - async def test_list_sagas_with_offset( - self, client: AsyncClient, test_user: Dict[str, str] - ) -> None: + async def test_list_sagas_with_offset(self, test_user: AsyncClient) -> None: """Test listing sagas with offset for pagination.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Get first page - response1 = await client.get( + response1 = await test_user.get( "/api/v1/sagas/", params={"limit": 5, "offset": 0} ) @@ -291,7 +196,7 @@ async def test_list_sagas_with_offset( page1 = SagaListResponse(**response1.json()) # Get second page - response2 = await client.get( + response2 = await test_user.get( "/api/v1/sagas/", params={"limit": 5, "offset": 5} ) @@ -306,20 +211,10 @@ async def test_list_sagas_with_offset( assert len(page1_ids.intersection(page2_ids)) == 0 @pytest.mark.asyncio - async def test_cancel_saga_invalid_state( - self, client: AsyncClient, test_user: Dict[str, str] - ) -> None: + async def test_cancel_saga_invalid_state(self, test_user: AsyncClient) -> None: """Test cancelling a saga in invalid state (if one exists).""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Try to find a completed saga to cancel - response = await client.get( + response = await test_user.get( "/api/v1/sagas/", params={"state": SagaState.COMPLETED.value, "limit": 1} ) @@ -329,29 +224,19 @@ async def test_cancel_saga_invalid_state( if saga_list.sagas and len(saga_list.sagas) > 0: # Try to cancel completed saga (should fail) saga_id = saga_list.sagas[0].saga_id - cancel_response = await client.post(f"/api/v1/sagas/{saga_id}/cancel") + cancel_response = await test_user.post(f"/api/v1/sagas/{saga_id}/cancel") # Should get 400 (invalid state) or 403 (access denied) or 404 assert cancel_response.status_code in [400, 403, 404] @pytest.mark.asyncio - async def test_get_execution_sagas_multiple_states( - self, client: AsyncClient, test_user: Dict[str, str] - ) -> None: + async def test_get_execution_sagas_multiple_states(self, test_user: AsyncClient) -> None: """Test getting execution sagas across different states.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - execution_id = str(uuid.uuid4()) # Test each state filter for state in [SagaState.CREATED, SagaState.RUNNING, SagaState.COMPLETED, SagaState.FAILED, SagaState.CANCELLED]: - response = await client.get( + response = await test_user.get( f"/api/v1/sagas/execution/{execution_id}", params={"state": state.value} ) @@ -366,20 +251,10 @@ async def test_get_execution_sagas_multiple_states( assert saga.state == state @pytest.mark.asyncio - async def test_saga_response_structure( - self, client: AsyncClient, test_user: Dict[str, str] - ) -> None: + async def test_saga_response_structure(self, test_user: AsyncClient) -> None: """Test that saga responses have correct structure.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # List sagas to verify response structure - response = await client.get("/api/v1/sagas/", params={"limit": 1}) + response = await test_user.get("/api/v1/sagas/", params={"limit": 1}) assert response.status_code == 200 saga_list = SagaListResponse(**response.json()) @@ -397,22 +272,12 @@ async def test_saga_response_structure( assert hasattr(saga, "created_at") @pytest.mark.asyncio - async def test_concurrent_saga_access( - self, client: AsyncClient, test_user: Dict[str, str] - ) -> None: + async def test_concurrent_saga_access(self, test_user: AsyncClient) -> None: """Test concurrent access to saga endpoints.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Make multiple concurrent requests tasks = [] for i in range(5): - tasks.append(client.get( + tasks.append(test_user.get( "/api/v1/sagas/", params={"limit": 10, "offset": i * 10} )) diff --git a/backend/tests/integration/test_saved_scripts_routes.py b/backend/tests/integration/test_saved_scripts_routes.py index cc42b39c..2561ad60 100644 --- a/backend/tests/integration/test_saved_scripts_routes.py +++ b/backend/tests/integration/test_saved_scripts_routes.py @@ -1,14 +1,10 @@ from datetime import datetime, timezone -from typing import Dict from uuid import UUID, uuid4 import pytest +from app.schemas_pydantic.saved_script import SavedScriptResponse from httpx import AsyncClient -from app.schemas_pydantic.saved_script import ( - SavedScriptResponse -) - @pytest.mark.integration class TestSavedScripts: @@ -33,7 +29,7 @@ async def test_create_script_requires_authentication(self, client: AsyncClient) for word in ["not authenticated", "unauthorized", "login"]) @pytest.mark.asyncio - async def test_create_and_retrieve_saved_script(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_create_and_retrieve_saved_script(self, test_user: AsyncClient) -> None: """Test creating and retrieving a saved script.""" # Already authenticated via test_user fixture @@ -47,8 +43,8 @@ async def test_create_and_retrieve_saved_script(self, client: AsyncClient, test_ "description": f"Test script created at {datetime.now(timezone.utc).isoformat()}" } - # Create the script - create_response = await client.post("/api/v1/scripts", json=script_data) + # Create the script (include CSRF header for POST request) + create_response = await test_user.post("/api/v1/scripts", json=script_data) assert create_response.status_code in [200, 201] # Validate response structure @@ -77,7 +73,7 @@ async def test_create_and_retrieve_saved_script(self, client: AsyncClient, test_ assert saved_script.updated_at is not None # Now retrieve the script by ID - get_response = await client.get(f"/api/v1/scripts/{saved_script.script_id}") + get_response = await test_user.get(f"/api/v1/scripts/{saved_script.script_id}") assert get_response.status_code == 200 retrieved_data = get_response.json() @@ -89,7 +85,7 @@ async def test_create_and_retrieve_saved_script(self, client: AsyncClient, test_ assert retrieved_script.script == script_data["script"] @pytest.mark.asyncio - async def test_list_user_scripts(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_list_user_scripts(self, test_user: AsyncClient) -> None: """Test listing user's saved scripts.""" # Already authenticated via test_user fixture @@ -120,12 +116,12 @@ async def test_list_user_scripts(self, client: AsyncClient, test_user: Dict[str, created_ids = [] for script_data in scripts_to_create: - create_response = await client.post("/api/v1/scripts", json=script_data) + create_response = await test_user.post("/api/v1/scripts", json=script_data) if create_response.status_code in [200, 201]: created_ids.append(create_response.json()["script_id"]) # List all scripts - list_response = await client.get("/api/v1/scripts") + list_response = await test_user.get("/api/v1/scripts") assert list_response.status_code == 200 scripts_list = list_response.json() @@ -149,7 +145,7 @@ async def test_list_user_scripts(self, client: AsyncClient, test_user: Dict[str, assert created_id in returned_ids @pytest.mark.asyncio - async def test_update_saved_script(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_update_saved_script(self, test_user: AsyncClient) -> None: """Test updating a saved script.""" # Already authenticated via test_user fixture @@ -163,7 +159,7 @@ async def test_update_saved_script(self, client: AsyncClient, test_user: Dict[st "description": "Original description" } - create_response = await client.post("/api/v1/scripts", json=original_data) + create_response = await test_user.post("/api/v1/scripts", json=original_data) assert create_response.status_code in [200, 201] created_script = create_response.json() @@ -179,7 +175,7 @@ async def test_update_saved_script(self, client: AsyncClient, test_user: Dict[st "description": "Updated description with more details" } - update_response = await client.put(f"/api/v1/scripts/{script_id}", json=updated_data) + update_response = await test_user.put(f"/api/v1/scripts/{script_id}", json=updated_data) assert update_response.status_code == 200 updated_script_data = update_response.json() @@ -202,7 +198,7 @@ async def test_update_saved_script(self, client: AsyncClient, test_user: Dict[st assert updated_script.updated_at > updated_script.created_at @pytest.mark.asyncio - async def test_delete_saved_script(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_delete_saved_script(self, test_user: AsyncClient) -> None: """Test deleting a saved script.""" # Already authenticated via test_user fixture @@ -216,17 +212,17 @@ async def test_delete_saved_script(self, client: AsyncClient, test_user: Dict[st "description": "This script will be deleted" } - create_response = await client.post("/api/v1/scripts", json=script_data) + create_response = await test_user.post("/api/v1/scripts", json=script_data) assert create_response.status_code in [200, 201] script_id = create_response.json()["script_id"] # Delete the script - delete_response = await client.delete(f"/api/v1/scripts/{script_id}") + delete_response = await test_user.delete(f"/api/v1/scripts/{script_id}") assert delete_response.status_code in [200, 204] # Verify it's deleted by trying to get it - get_response = await client.get(f"/api/v1/scripts/{script_id}") + get_response = await test_user.get(f"/api/v1/scripts/{script_id}") assert get_response.status_code in [404, 403] if get_response.status_code == 404: @@ -234,17 +230,9 @@ async def test_delete_saved_script(self, client: AsyncClient, test_user: Dict[st assert "detail" in error_data @pytest.mark.asyncio - async def test_cannot_access_other_users_scripts(self, client: AsyncClient, test_user: Dict[str, str], - test_admin: Dict[str, str]) -> None: + async def test_cannot_access_other_users_scripts(self, test_user: AsyncClient, + test_admin: AsyncClient) -> None: """Test that users cannot access scripts created by other users.""" - # Create a script as regular user - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - unique_id = str(uuid4())[:8] user_script_data = { "name": f"User Private Script {unique_id}", @@ -254,27 +242,19 @@ async def test_cannot_access_other_users_scripts(self, client: AsyncClient, test "description": "Should only be visible to creating user" } - create_response = await client.post("/api/v1/scripts", json=user_script_data) + create_response = await test_user.post("/api/v1/scripts", json=user_script_data) assert create_response.status_code in [200, 201] user_script_id = create_response.json()["script_id"] - # Now login as admin - admin_login_data = { - "username": test_admin["username"], - "password": test_admin["password"] - } - admin_login_response = await client.post("/api/v1/auth/login", data=admin_login_data) - assert admin_login_response.status_code == 200 - # Try to access the user's script as admin # This should fail unless admin has special permissions - get_response = await client.get(f"/api/v1/scripts/{user_script_id}") + get_response = await test_admin.get(f"/api/v1/scripts/{user_script_id}") # Should be forbidden or not found assert get_response.status_code in [403, 404] # List scripts as admin - should not include user's script - list_response = await client.get("/api/v1/scripts") + list_response = await test_admin.get("/api/v1/scripts") assert list_response.status_code == 200 admin_scripts = list_response.json() @@ -283,16 +263,8 @@ async def test_cannot_access_other_users_scripts(self, client: AsyncClient, test assert user_script_id not in admin_script_ids @pytest.mark.asyncio - async def test_script_with_invalid_language(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_script_with_invalid_language(self, test_user: AsyncClient) -> None: """Test that invalid language/version combinations are handled.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - unique_id = str(uuid4())[:8] # Try invalid language @@ -303,7 +275,7 @@ async def test_script_with_invalid_language(self, client: AsyncClient, test_user "lang_version": "1.0" } - response = await client.post("/api/v1/scripts", json=invalid_lang_data) + response = await test_user.post("/api/v1/scripts", json=invalid_lang_data) # Backend may accept arbitrary lang values; accept any outcome assert response.status_code in [200, 201, 400, 422] @@ -315,21 +287,13 @@ async def test_script_with_invalid_language(self, client: AsyncClient, test_user "lang_version": "2.7" # Python 2 likely not supported } - response = await client.post("/api/v1/scripts", json=unsupported_version_data) + response = await test_user.post("/api/v1/scripts", json=unsupported_version_data) # Might accept but warn, or reject assert response.status_code in [200, 201, 400, 422] @pytest.mark.asyncio - async def test_script_name_constraints(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_script_name_constraints(self, test_user: AsyncClient) -> None: """Test script name validation and constraints.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - # Test empty name empty_name_data = { "name": "", @@ -338,7 +302,7 @@ async def test_script_name_constraints(self, client: AsyncClient, test_user: Dic "lang_version": "3.11" } - response = await client.post("/api/v1/scripts", json=empty_name_data) + response = await test_user.post("/api/v1/scripts", json=empty_name_data) assert response.status_code in [200, 201, 400, 422] # Test very long name @@ -349,23 +313,15 @@ async def test_script_name_constraints(self, client: AsyncClient, test_user: Dic "lang_version": "3.11" } - response = await client.post("/api/v1/scripts", json=long_name_data) + response = await test_user.post("/api/v1/scripts", json=long_name_data) # Should either accept or reject based on max length if response.status_code in [400, 422]: error_data = response.json() assert "detail" in error_data @pytest.mark.asyncio - async def test_script_content_size_limits(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_script_content_size_limits(self, test_user: AsyncClient) -> None: """Test script content size limits.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - unique_id = str(uuid4())[:8] # Test reasonably large script (should succeed) @@ -377,7 +333,7 @@ async def test_script_content_size_limits(self, client: AsyncClient, test_user: "lang_version": "3.11" } - response = await client.post("/api/v1/scripts", json=large_script_data) + response = await test_user.post("/api/v1/scripts", json=large_script_data) assert response.status_code in [200, 201] # Test excessively large script (should fail) @@ -389,23 +345,15 @@ async def test_script_content_size_limits(self, client: AsyncClient, test_user: "lang_version": "3.11" } - response = await client.post("/api/v1/scripts", json=huge_script_data) + response = await test_user.post("/api/v1/scripts", json=huge_script_data) # If backend returns 500 for oversized payload, skip as environment-specific if response.status_code >= 500: pytest.skip("Backend returned 5xx for oversized script upload") assert response.status_code in [200, 201, 400, 413, 422] @pytest.mark.asyncio - async def test_update_nonexistent_script(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_update_nonexistent_script(self, test_user: AsyncClient) -> None: """Test updating a non-existent script.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - fake_script_id = "00000000-0000-0000-0000-000000000000" update_data = { @@ -415,7 +363,7 @@ async def test_update_nonexistent_script(self, client: AsyncClient, test_user: D "lang_version": "3.11" } - response = await client.put(f"/api/v1/scripts/{fake_script_id}", json=update_data) + response = await test_user.put(f"/api/v1/scripts/{fake_script_id}", json=update_data) # Non-existent script must return 404/403 (no server error) assert response.status_code in [404, 403] @@ -423,33 +371,17 @@ async def test_update_nonexistent_script(self, client: AsyncClient, test_user: D assert "detail" in error_data @pytest.mark.asyncio - async def test_delete_nonexistent_script(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_delete_nonexistent_script(self, test_user: AsyncClient) -> None: """Test deleting a non-existent script.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - fake_script_id = "00000000-0000-0000-0000-000000000000" - response = await client.delete(f"/api/v1/scripts/{fake_script_id}") + response = await test_user.delete(f"/api/v1/scripts/{fake_script_id}") # Could be 404 (not found) or 204 (idempotent delete) assert response.status_code in [404, 403, 204] @pytest.mark.asyncio - async def test_scripts_persist_across_sessions(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_scripts_persist_across_sessions(self, test_user: AsyncClient) -> None: """Test that scripts persist across login sessions.""" - # First session - create script - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_response = await client.post("/api/v1/auth/login", data=login_data) - assert login_response.status_code == 200 - unique_id = str(uuid4())[:8] script_data = { "name": f"Persistent Script {unique_id}", @@ -459,21 +391,33 @@ async def test_scripts_persist_across_sessions(self, client: AsyncClient, test_u "description": "Testing persistence" } - create_response = await client.post("/api/v1/scripts", json=script_data) + create_response = await test_user.post("/api/v1/scripts", json=script_data) assert create_response.status_code in [200, 201] script_id = create_response.json()["script_id"] - # Logout - logout_response = await client.post("/api/v1/auth/logout") + # Get username before logout so we can re-login + me_response = await test_user.get("/api/v1/auth/me") + assert me_response.status_code == 200 + username = me_response.json()["username"] + + # Logout - this clears cookies via Set-Cookie response + logout_response = await test_user.post("/api/v1/auth/logout") assert logout_response.status_code == 200 - # Second session - retrieve script - login_response2 = await client.post("/api/v1/auth/login", data=login_data) - assert login_response2.status_code == 200 + # Re-login to get fresh authentication + login_response = await test_user.post( + "/api/v1/auth/login", + data={"username": username, "password": "TestPass123!"}, + ) + assert login_response.status_code == 200 + + # Update CSRF header from new session + csrf_token = login_response.json().get("csrf_token", "") + test_user.headers["X-CSRF-Token"] = csrf_token - # Script should still exist - get_response = await client.get(f"/api/v1/scripts/{script_id}") + # Script should still exist after logout/login cycle + get_response = await test_user.get(f"/api/v1/scripts/{script_id}") assert get_response.status_code == 200 retrieved_script = SavedScriptResponse(**get_response.json()) diff --git a/backend/tests/integration/test_sse_routes.py b/backend/tests/integration/test_sse_routes.py index ace4bc48..e7c2ff8f 100644 --- a/backend/tests/integration/test_sse_routes.py +++ b/backend/tests/integration/test_sse_routes.py @@ -1,22 +1,18 @@ import asyncio import json -from typing import Dict +from contextlib import aclosing +from typing import Any from uuid import uuid4 import pytest -from httpx import AsyncClient - from app.domain.enums.notification import NotificationSeverity, NotificationStatus -from app.schemas_pydantic.sse import RedisNotificationMessage, SSEHealthResponse -from app.infrastructure.kafka.events.pod import PodCreatedEvent from app.infrastructure.kafka.events.metadata import AvroEventMetadata +from app.infrastructure.kafka.events.pod import PodCreatedEvent +from app.schemas_pydantic.sse import RedisNotificationMessage, SSEHealthResponse from app.services.sse.redis_bus import SSERedisBus from app.services.sse.sse_service import SSEService -from tests.helpers.eventually import eventually - - -# Note: httpx with ASGITransport doesn't support SSE streaming -# We test SSE functionality directly through the service, not HTTP +from dishka import AsyncContainer +from httpx import AsyncClient @pytest.mark.integration @@ -38,169 +34,135 @@ async def test_sse_requires_authentication(self, client: AsyncClient) -> None: assert r.status_code == 401 @pytest.mark.asyncio - async def test_sse_health_status(self, client: AsyncClient, test_user: Dict[str, str]) -> None: - r = await client.get("/api/v1/events/health") + async def test_sse_health_status(self, test_user: AsyncClient) -> None: + r = await test_user.get("/api/v1/events/health") assert r.status_code == 200 health = SSEHealthResponse(**r.json()) assert health.status in ("healthy", "degraded", "unhealthy", "draining") assert isinstance(health.active_connections, int) and health.active_connections >= 0 @pytest.mark.asyncio - async def test_notification_stream_service(self, scope, test_user: Dict[str, str]) -> None: # type: ignore[valid-type] + async def test_notification_stream_service(self, scope: AsyncContainer, test_user: AsyncClient) -> None: """Test SSE notification stream directly through service (httpx doesn't support SSE streaming).""" sse_service: SSEService = await scope.get(SSEService) bus: SSERedisBus = await scope.get(SSERedisBus) user_id = f"user-{uuid4().hex[:8]}" - - # Create notification stream generator - stream_gen = sse_service.create_notification_stream(user_id) - - # Collect events with timeout - events = [] - async def collect_events(): - async for event in stream_gen: - if "data" in event: - data = json.loads(event["data"]) - events.append(data) - if data.get("event_type") == "notification" and data.get("subject") == "Hello": - break - - # Start collecting events - collect_task = asyncio.create_task(collect_events()) - - # Wait until the initial 'connected' event is received - async def _connected() -> None: - assert len(events) > 0 and events[0].get("event_type") == "connected" - await eventually(_connected, timeout=2.0, interval=0.05) - - # Publish a notification - notification = RedisNotificationMessage( - notification_id=f"notif-{uuid4().hex[:8]}", - severity=NotificationSeverity.MEDIUM, - status=NotificationStatus.PENDING, - tags=[], - subject="Hello", - body="World", - action_url="", - created_at="2024-01-01T00:00:00Z", - ) - await bus.publish_notification(user_id, notification) - - # Wait for collection to complete - try: - await asyncio.wait_for(collect_task, timeout=2.0) - except asyncio.TimeoutError: - collect_task.cancel() - - # Verify we got notification - notif_events = [e for e in events if e.get("event_type") == "notification" and e.get("subject") == "Hello"] - assert len(notif_events) > 0 + + events: list[dict[str, Any]] = [] + notification_received = False + + async with aclosing(sse_service.create_notification_stream(user_id)) as stream: + try: + async with asyncio.timeout(3.0): + async for event in stream: + if "data" not in event: + continue + data = json.loads(event["data"]) + events.append(data) + + # Wait for "subscribed" event - Redis subscription is now ready + if data.get("event_type") == "subscribed": + notification = RedisNotificationMessage( + notification_id=f"notif-{uuid4().hex[:8]}", + severity=NotificationSeverity.MEDIUM, + status=NotificationStatus.PENDING, + tags=[], + subject="Hello", + body="World", + action_url="", + created_at="2024-01-01T00:00:00Z", + ) + await bus.publish_notification(user_id, notification) + + # Stop when we receive the notification + if data.get("event_type") == "notification" and data.get("subject") == "Hello": + notification_received = True + break + except TimeoutError: + pass + + assert notification_received, f"Expected notification, got events: {events}" @pytest.mark.asyncio - async def test_execution_event_stream_service(self, scope, test_user: Dict[str, str]) -> None: # type: ignore[valid-type] + async def test_execution_event_stream_service(self, scope: AsyncContainer, test_user: AsyncClient) -> None: """Test SSE execution stream directly through service (httpx doesn't support SSE streaming).""" exec_id = f"e-{uuid4().hex[:8]}" - user_id = "test-user-id" - + user_id = f"user-{uuid4().hex[:8]}" + sse_service: SSEService = await scope.get(SSEService) bus: SSERedisBus = await scope.get(SSERedisBus) - - # Create execution stream generator - stream_gen = sse_service.create_execution_stream(exec_id, user_id) - - # Collect events - events = [] - async def collect_events(): - async for event in stream_gen: - if "data" in event: - data = json.loads(event["data"]) - events.append(data) - if data.get("event_type") == "pod_created" or data.get("type") == "pod_created": - break - - # Start collecting - collect_task = asyncio.create_task(collect_events()) - - # Wait until the initial 'connected' event is received - async def _connected() -> None: - assert len(events) > 0 and events[0].get("event_type") == "connected" - await eventually(_connected, timeout=2.0, interval=0.05) - - # Publish pod event - ev = PodCreatedEvent( - execution_id=exec_id, - pod_name=f"executor-{exec_id}", - namespace="default", - metadata=AvroEventMetadata(service_name="tests", service_version="1"), - ) - await bus.publish_event(exec_id, ev) - - # Wait for collection - try: - await asyncio.wait_for(collect_task, timeout=2.0) - except asyncio.TimeoutError: - collect_task.cancel() - - # Verify pod event received - pod_events = [e for e in events if e.get("event_type") == "pod_created" or e.get("type") == "pod_created"] - assert len(pod_events) > 0 + + events: list[dict[str, Any]] = [] + pod_event_received = False + + async with aclosing(sse_service.create_execution_stream(exec_id, user_id)) as stream: + try: + async with asyncio.timeout(3.0): + async for event in stream: + if "data" not in event: + continue + data = json.loads(event["data"]) + events.append(data) + + # Wait for "subscribed" event - Redis subscription is now ready + if data.get("event_type") == "subscribed": + ev = PodCreatedEvent( + execution_id=exec_id, + pod_name=f"executor-{exec_id}", + namespace="default", + metadata=AvroEventMetadata(service_name="tests", service_version="1"), + ) + await bus.publish_event(exec_id, ev) + + # Stop when we receive the pod event + if data.get("event_type") == "pod_created": + pod_event_received = True + break + except TimeoutError: + pass + + assert pod_event_received, f"Expected pod_created event, got events: {events}" @pytest.mark.asyncio async def test_sse_route_requires_auth(self, client: AsyncClient) -> None: """Test that SSE routes require authentication (HTTP-level test only).""" - # Test notification stream requires auth r = await client.get("/api/v1/events/notifications/stream") assert r.status_code == 401 - - # Test execution stream requires auth + exec_id = str(uuid4()) r = await client.get(f"/api/v1/events/executions/{exec_id}") assert r.status_code == 401 @pytest.mark.asyncio - async def test_sse_endpoint_returns_correct_headers(self, client: AsyncClient, test_user: Dict[str, str]) -> None: - task = asyncio.create_task(client.get("/api/v1/events/notifications/stream")) - - async def _tick() -> None: - return None - await eventually(_tick, timeout=0.1, interval=0.01) - - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - - r = await client.get("/api/v1/events/health") + async def test_sse_endpoint_returns_correct_headers(self, test_user: AsyncClient) -> None: + """Test SSE health endpoint works (streaming tests done via service).""" + r = await test_user.get("/api/v1/events/health") assert r.status_code == 200 assert isinstance(r.json(), dict) @pytest.mark.asyncio - async def test_multiple_concurrent_sse_service_connections(self, scope, test_user: Dict[str, str]) -> None: # type: ignore[valid-type] + async def test_multiple_concurrent_sse_service_connections( + self, scope: AsyncContainer, test_user: AsyncClient + ) -> None: """Test multiple concurrent SSE connections through the service.""" sse_service: SSEService = await scope.get(SSEService) - + async def create_and_verify_stream(user_id: str) -> bool: - stream_gen = sse_service.create_notification_stream(user_id) - try: - async for event in stream_gen: + async with aclosing(sse_service.create_notification_stream(user_id)) as stream: + async for event in stream: if "data" in event: data = json.loads(event["data"]) if data.get("event_type") == "connected": return True - break # Only check first event - except Exception: - return False + break return False - - # Create multiple concurrent connections + results = await asyncio.gather( create_and_verify_stream("user1"), create_and_verify_stream("user2"), create_and_verify_stream("user3"), - return_exceptions=True + return_exceptions=True, ) - - # At least 2 should succeed + successful = sum(1 for r in results if r is True) assert successful >= 2 diff --git a/backend/tests/integration/test_user_settings_routes.py b/backend/tests/integration/test_user_settings_routes.py index c6378351..a780cd25 100644 --- a/backend/tests/integration/test_user_settings_routes.py +++ b/backend/tests/integration/test_user_settings_routes.py @@ -1,84 +1,41 @@ -import asyncio from datetime import datetime, timezone -from typing import Dict -from uuid import uuid4 +from typing import TypedDict import pytest -import pytest_asyncio +from app.schemas_pydantic.user_settings import SettingsHistoryResponse, UserSettings from httpx import AsyncClient - -from app.schemas_pydantic.user_settings import ( - UserSettings, - SettingsHistoryResponse -) from tests.helpers.eventually import eventually -# Force these tests to run sequentially on a single worker to avoid state conflicts -pytestmark = pytest.mark.xdist_group(name="user_settings") +class _NotificationSettings(TypedDict): + execution_completed: bool + execution_failed: bool + system_updates: bool + security_alerts: bool + channels: list[str] + + +class _EditorSettings(TypedDict): + theme: str + font_size: int + tab_size: int + use_tabs: bool + word_wrap: bool + show_line_numbers: bool -@pytest_asyncio.fixture -async def test_user(client: AsyncClient) -> Dict[str, str]: - """Create a fresh user for each test.""" - uid = uuid4().hex[:8] - username = f"test_user_{uid}" - email = f"{username}@example.com" - password = "TestPass123!" - - # Register the user - await client.post("/api/v1/auth/register", json={ - "username": username, - "email": email, - "password": password, - "role": "user" - }) - - # Login to get CSRF token - login_resp = await client.post("/api/v1/auth/login", data={ - "username": username, - "password": password - }) - csrf = login_resp.json().get("csrf_token", "") - - return { - "username": username, - "email": email, - "password": password, - "csrf_token": csrf, - "headers": {"X-CSRF-Token": csrf} - } - - -@pytest_asyncio.fixture -async def test_user2(client: AsyncClient) -> Dict[str, str]: - """Create a second fresh user for isolation tests.""" - uid = uuid4().hex[:8] - username = f"test_user2_{uid}" - email = f"{username}@example.com" - password = "TestPass123!" - - # Register the user - await client.post("/api/v1/auth/register", json={ - "username": username, - "email": email, - "password": password, - "role": "user" - }) - - # Login to get CSRF token - login_resp = await client.post("/api/v1/auth/login", data={ - "username": username, - "password": password - }) - csrf = login_resp.json().get("csrf_token", "") - - return { - "username": username, - "email": email, - "password": password, - "csrf_token": csrf, - "headers": {"X-CSRF-Token": csrf} - } + +class _UpdateSettingsData(TypedDict, total=False): + theme: str + timezone: str + date_format: str + time_format: str + notifications: _NotificationSettings + editor: _EditorSettings + custom_settings: dict[str, str] + + +# Force these tests to run sequentially on a single worker to avoid state conflicts +pytestmark = pytest.mark.xdist_group(name="user_settings") @pytest.mark.integration @@ -98,12 +55,12 @@ async def test_user_settings_require_authentication(self, client: AsyncClient) - for word in ["not authenticated", "unauthorized", "login"]) @pytest.mark.asyncio - async def test_get_user_settings(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_get_user_settings(self, test_user: AsyncClient) -> None: """Test getting user settings.""" # Already authenticated via test_user fixture # Get user settings - response = await client.get("/api/v1/user/settings/") + response = await test_user.get("/api/v1/user/settings/") assert response.status_code == 200 # Validate response structure @@ -125,11 +82,12 @@ async def test_get_user_settings(self, client: AsyncClient, test_user: Dict[str, assert isinstance(settings.notifications.system_updates, bool) assert isinstance(settings.notifications.security_alerts, bool) - # Verify editor settings + # Verify editor settings assert settings.editor is not None assert isinstance(settings.editor.font_size, int) assert 8 <= settings.editor.font_size <= 32 - assert settings.editor.theme in ["auto", "one-dark", "monokai", "github", "dracula", "solarized", "vs", "vscode"] + assert settings.editor.theme in ["auto", "one-dark", "monokai", "github", "dracula", "solarized", "vs", + "vscode"] assert isinstance(settings.editor.tab_size, int) assert settings.editor.tab_size in [2, 4, 8] assert isinstance(settings.editor.word_wrap, bool) @@ -144,17 +102,17 @@ async def test_get_user_settings(self, client: AsyncClient, test_user: Dict[str, assert isinstance(settings.custom_settings, dict) @pytest.mark.asyncio - async def test_update_user_settings(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_update_user_settings(self, test_user: AsyncClient) -> None: """Test updating user settings.""" # Already authenticated via test_user fixture # Get current settings to preserve original values - original_response = await client.get("/api/v1/user/settings/") + original_response = await test_user.get("/api/v1/user/settings/") assert original_response.status_code == 200 original_settings = original_response.json() # Update settings - update_data = { + update_data: _UpdateSettingsData = { "theme": "dark" if original_settings["theme"] == "light" else "light", "timezone": "America/New_York" if original_settings["timezone"] != "America/New_York" else "UTC", "date_format": "MM/DD/YYYY", @@ -176,7 +134,7 @@ async def test_update_user_settings(self, client: AsyncClient, test_user: Dict[s } } - response = await client.put("/api/v1/user/settings/", json=update_data) + response = await test_user.put("/api/v1/user/settings/", json=update_data) if response.status_code != 200: pytest.fail(f"Status: {response.status_code}, Body: {response.json()}, Data: {update_data}") assert response.status_code == 200 @@ -204,12 +162,12 @@ async def test_update_user_settings(self, client: AsyncClient, test_user: Dict[s assert updated_settings.editor.show_line_numbers == update_data["editor"]["show_line_numbers"] @pytest.mark.asyncio - async def test_update_theme_only(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_update_theme_only(self, test_user: AsyncClient) -> None: """Test updating only the theme setting.""" # Already authenticated via test_user fixture # Get current theme - original_response = await client.get("/api/v1/user/settings/") + original_response = await test_user.get("/api/v1/user/settings/") assert original_response.status_code == 200 original_theme = original_response.json()["theme"] @@ -219,7 +177,7 @@ async def test_update_theme_only(self, client: AsyncClient, test_user: Dict[str, "theme": new_theme } - response = await client.put("/api/v1/user/settings/theme", json=theme_update) + response = await test_user.put("/api/v1/user/settings/theme", json=theme_update) assert response.status_code == 200 # Validate updated settings @@ -233,7 +191,7 @@ async def test_update_theme_only(self, client: AsyncClient, test_user: Dict[str, assert updated_settings.timezone == original_response.json()["timezone"] @pytest.mark.asyncio - async def test_update_notification_settings_only(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_update_notification_settings_only(self, test_user: AsyncClient) -> None: """Test updating only notification settings.""" # Already authenticated via test_user fixture @@ -246,7 +204,7 @@ async def test_update_notification_settings_only(self, client: AsyncClient, test "channels": ["in_app"] } - response = await client.put("/api/v1/user/settings/notifications", json=notification_update) + response = await test_user.put("/api/v1/user/settings/notifications", json=notification_update) if response.status_code >= 500: pytest.skip("Notification settings update not available in this environment") assert response.status_code == 200 @@ -260,7 +218,7 @@ async def test_update_notification_settings_only(self, client: AsyncClient, test assert "in_app" in [str(c) for c in updated_settings.notifications.channels] @pytest.mark.asyncio - async def test_update_editor_settings_only(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_update_editor_settings_only(self, test_user: AsyncClient) -> None: """Test updating only editor settings.""" # Already authenticated via test_user fixture @@ -274,7 +232,7 @@ async def test_update_editor_settings_only(self, client: AsyncClient, test_user: "show_line_numbers": True } - response = await client.put("/api/v1/user/settings/editor", json=editor_update) + response = await test_user.put("/api/v1/user/settings/editor", json=editor_update) if response.status_code >= 500: pytest.skip("Editor settings update not available in this environment") assert response.status_code == 200 @@ -288,7 +246,7 @@ async def test_update_editor_settings_only(self, client: AsyncClient, test_user: assert updated_settings.editor.show_line_numbers == editor_update["show_line_numbers"] @pytest.mark.asyncio - async def test_update_custom_setting(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_update_custom_setting(self, test_user: AsyncClient) -> None: """Test updating a custom setting.""" # Update custom settings via main settings endpoint custom_key = "custom_preference" @@ -299,7 +257,7 @@ async def test_update_custom_setting(self, client: AsyncClient, test_user: Dict[ } } - response = await client.put("/api/v1/user/settings/", json=update_data) + response = await test_user.put("/api/v1/user/settings/", json=update_data) assert response.status_code == 200 # Validate updated settings @@ -308,24 +266,16 @@ async def test_update_custom_setting(self, client: AsyncClient, test_user: Dict[ assert updated_settings.custom_settings[custom_key] == custom_value @pytest.mark.asyncio - async def test_get_settings_history(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_get_settings_history(self, test_user: AsyncClient) -> None: """Test getting settings change history.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_resp = await client.post("/api/v1/auth/login", data=login_data) - assert login_resp.status_code == 200 - # Make some changes to build history (theme change) theme_update = {"theme": "dark"} - response = await client.put("/api/v1/user/settings/theme", json=theme_update) + response = await test_user.put("/api/v1/user/settings/theme", json=theme_update) if response.status_code >= 500: pytest.skip("Settings history not available in this environment") # Get history - history_response = await client.get("/api/v1/user/settings/history") + history_response = await test_user.get("/api/v1/user/settings/history") if history_response.status_code >= 500: pytest.skip("Settings history endpoint not available in this environment") assert history_response.status_code == 200 @@ -339,28 +289,21 @@ async def test_get_settings_history(self, client: AsyncClient, test_user: Dict[s assert entry.timestamp is not None @pytest.mark.asyncio - async def test_restore_settings_to_previous_point(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_restore_settings_to_previous_point(self, test_user: AsyncClient) -> None: """Test restoring settings to a previous point in time.""" - # Login first - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - await client.post("/api/v1/auth/login", data=login_data) - # Get original settings - original_resp = await client.get("/api/v1/user/settings/") + original_resp = await test_user.get("/api/v1/user/settings/") assert original_resp.status_code == 200 original_theme = original_resp.json()["theme"] # Make a change new_theme = "dark" if original_theme != "dark" else "light" - await client.put("/api/v1/user/settings/theme", json={"theme": new_theme}) + await test_user.put("/api/v1/user/settings/theme", json={"theme": new_theme}) # Ensure restore point is distinct by checking time monotonicity prev = datetime.now(timezone.utc) - async def _tick(): + async def _tick() -> None: now = datetime.now(timezone.utc) assert (now - prev).total_seconds() >= 0 @@ -371,11 +314,11 @@ async def _tick(): # Make another change second_theme = "auto" if new_theme != "auto" else "system" - await client.put("/api/v1/user/settings/theme", json={"theme": second_theme}) + await test_user.put("/api/v1/user/settings/theme", json={"theme": second_theme}) # Try to restore to the restore point restore_data = {"timestamp": restore_point} - restore_resp = await client.post("/api/v1/user/settings/restore", json=restore_data) + restore_resp = await test_user.post("/api/v1/user/settings/restore", json=restore_data) # Skip if restore functionality not available if restore_resp.status_code >= 500: @@ -383,26 +326,26 @@ async def _tick(): # If successful, verify the theme was restored if restore_resp.status_code == 200: - current_resp = await client.get("/api/v1/user/settings/") + current_resp = await test_user.get("/api/v1/user/settings/") # Since restore might not work exactly as expected in test environment, # just verify we get valid settings back assert current_resp.status_code == 200 @pytest.mark.asyncio - async def test_invalid_theme_value(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_invalid_theme_value(self, test_user: AsyncClient) -> None: """Test that invalid theme values are rejected.""" # Already authenticated via test_user fixture # Try to update with invalid theme invalid_theme = {"theme": "invalid_theme"} - response = await client.put("/api/v1/user/settings/theme", json=invalid_theme) + response = await test_user.put("/api/v1/user/settings/theme", json=invalid_theme) if response.status_code >= 500: pytest.skip("Theme validation not available in this environment") assert response.status_code in [400, 422] @pytest.mark.asyncio - async def test_invalid_editor_settings(self, client: AsyncClient, test_user: Dict[str, str]) -> None: + async def test_invalid_editor_settings(self, test_user: AsyncClient) -> None: """Test that invalid editor settings are rejected.""" # Already authenticated via test_user fixture @@ -416,44 +359,27 @@ async def test_invalid_editor_settings(self, client: AsyncClient, test_user: Dic "show_line_numbers": True } - response = await client.put("/api/v1/user/settings/editor", json=invalid_editor) + response = await test_user.put("/api/v1/user/settings/editor", json=invalid_editor) if response.status_code >= 500: pytest.skip("Editor validation not available in this environment") assert response.status_code in [400, 422] @pytest.mark.asyncio - async def test_settings_isolation_between_users(self, client: AsyncClient, - test_user: Dict[str, str], - test_user2: Dict[str, str]) -> None: + async def test_settings_isolation_between_users(self, + test_user: AsyncClient, + another_user: AsyncClient) -> None: """Test that settings are isolated between users.""" - # Login as first user - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - await client.post("/api/v1/auth/login", data=login_data) - # Update first user's settings user1_update = { "theme": "dark", "timezone": "America/New_York" } - response = await client.put("/api/v1/user/settings/", json=user1_update) + response = await test_user.put("/api/v1/user/settings/", json=user1_update) assert response.status_code == 200 - # Log out - await client.post("/api/v1/auth/logout") - - # Login as second user - login_data = { - "username": test_user2["username"], - "password": test_user2["password"] - } - await client.post("/api/v1/auth/login", data=login_data) - # Get second user's settings - response = await client.get("/api/v1/user/settings/") + response = await another_user.get("/api/v1/user/settings/") assert response.status_code == 200 user2_settings = response.json() @@ -463,48 +389,36 @@ async def test_settings_isolation_between_users(self, client: AsyncClient, "timezone"] @pytest.mark.asyncio - async def test_settings_persistence(self, client: AsyncClient, test_user: Dict[str, str]) -> None: - """Test that settings persist across login sessions.""" - # Already authenticated via test_user fixture - + async def test_settings_persistence(self, test_user: AsyncClient) -> None: + """Test that settings persist after being saved.""" # Update settings - update_data = { + editor_settings: _EditorSettings = { + "theme": "github", + "font_size": 18, + "tab_size": 8, + "use_tabs": True, + "word_wrap": False, + "show_line_numbers": False + } + update_data: _UpdateSettingsData = { "theme": "dark", "timezone": "Europe/London", - "editor": { - "theme": "github", - "font_size": 18, - "tab_size": 8, - "use_tabs": True, - "word_wrap": False, - "show_line_numbers": False - } + "editor": editor_settings } - response = await client.put("/api/v1/user/settings/", json=update_data) + response = await test_user.put("/api/v1/user/settings/", json=update_data) assert response.status_code == 200 - # Log out - await client.post("/api/v1/auth/logout") - - # Log back in as same user - login_data = { - "username": test_user["username"], - "password": test_user["password"] - } - login_resp = await client.post("/api/v1/auth/login", data=login_data) - assert login_resp.status_code == 200 - - # Get settings again - response = await client.get("/api/v1/user/settings/") + # Get settings again to verify persistence + response = await test_user.get("/api/v1/user/settings/") assert response.status_code == 200 persisted_settings = UserSettings(**response.json()) # Verify settings persisted assert persisted_settings.theme == update_data["theme"] assert persisted_settings.timezone == update_data["timezone"] - assert persisted_settings.editor.theme == update_data["editor"]["theme"] - assert persisted_settings.editor.font_size == update_data["editor"]["font_size"] - assert persisted_settings.editor.tab_size == update_data["editor"]["tab_size"] - assert persisted_settings.editor.word_wrap == update_data["editor"]["word_wrap"] - assert persisted_settings.editor.show_line_numbers == update_data["editor"]["show_line_numbers"] + assert persisted_settings.editor.theme == editor_settings["theme"] + assert persisted_settings.editor.font_size == editor_settings["font_size"] + assert persisted_settings.editor.tab_size == editor_settings["tab_size"] + assert persisted_settings.editor.word_wrap == editor_settings["word_wrap"] + assert persisted_settings.editor.show_line_numbers == editor_settings["show_line_numbers"] diff --git a/backend/tests/load/cli.py b/backend/tests/load/cli.py index e672617d..807f5777 100644 --- a/backend/tests/load/cli.py +++ b/backend/tests/load/cli.py @@ -75,7 +75,7 @@ def main(argv: list[str] | None = None) -> int: if args.base_url: cfg.base_url = args.base_url if args.mode: - cfg.mode = args.mode # type: ignore[assignment] + cfg.mode = args.mode if args.clients is not None: cfg.clients = args.clients if args.concurrency is not None: diff --git a/backend/tests/load/config.py b/backend/tests/load/config.py index a5cf208a..1f29fc42 100644 --- a/backend/tests/load/config.py +++ b/backend/tests/load/config.py @@ -1,43 +1,45 @@ from __future__ import annotations -import os -from dataclasses import dataclass, field from typing import Literal +from pydantic import Field +from pydantic_settings import BaseSettings, SettingsConfigDict + Mode = Literal["monkey", "user", "both"] -@dataclass(slots=True) -class LoadConfig: - base_url: str = field(default_factory=lambda: os.getenv("LOAD_BASE_URL", "https://[::1]:443")) - api_prefix: str = field(default_factory=lambda: os.getenv("LOAD_API_PREFIX", "/api/v1")) - verify_tls: bool = field(default_factory=lambda: os.getenv("LOAD_VERIFY_TLS", "false").lower() in ("1", "true", "yes")) - generate_plots: bool = field(default=False) +class LoadConfig(BaseSettings): + model_config = SettingsConfigDict(env_prefix="LOAD_", case_sensitive=False) + + base_url: str = "https://[::1]:443" + api_prefix: str = "/api/v1" + verify_tls: bool = False + generate_plots: bool = False # Clients and workload - mode: Mode = field(default_factory=lambda: os.getenv("LOAD_MODE", "both")) - clients: int = int(os.getenv("LOAD_CLIENTS", "25")) - concurrency: int = int(os.getenv("LOAD_CONCURRENCY", "10")) + mode: Mode = "both" + clients: int = 25 + concurrency: int = 10 # Default run duration ~3 minutes - duration_seconds: int = int(os.getenv("LOAD_DURATION", "180")) - ramp_up_seconds: int = int(os.getenv("LOAD_RAMP", "5")) + duration_seconds: int = Field(default=180, validation_alias="LOAD_DURATION") + ramp_up_seconds: int = Field(default=5, validation_alias="LOAD_RAMP") # User pool (for user-mode) - auto_register_users: bool = field(default_factory=lambda: os.getenv("LOAD_AUTO_REGISTER", "true").lower() in ("1","true","yes")) - user_prefix: str = os.getenv("LOAD_USER_PREFIX", "loaduser") - user_domain: str = os.getenv("LOAD_USER_DOMAIN", "example.com") - user_password: str = os.getenv("LOAD_USER_PASSWORD", "testpass123!") + auto_register_users: bool = Field(default=True, validation_alias="LOAD_AUTO_REGISTER") + user_prefix: str = "loaduser" + user_domain: str = "example.com" + user_password: str = "testpass123!" # Endpoint toggles - enable_sse: bool = field(default_factory=lambda: os.getenv("LOAD_ENABLE_SSE", "true").lower() in ("1","true","yes")) - enable_saved_scripts: bool = field(default_factory=lambda: os.getenv("LOAD_ENABLE_SCRIPTS", "true").lower() in ("1","true","yes")) - enable_user_settings: bool = field(default_factory=lambda: os.getenv("LOAD_ENABLE_SETTINGS", "true").lower() in ("1","true","yes")) - enable_notifications: bool = field(default_factory=lambda: os.getenv("LOAD_ENABLE_NOTIFICATIONS", "true").lower() in ("1","true","yes")) + enable_sse: bool = True + enable_saved_scripts: bool = Field(default=True, validation_alias="LOAD_ENABLE_SCRIPTS") + enable_user_settings: bool = Field(default=True, validation_alias="LOAD_ENABLE_SETTINGS") + enable_notifications: bool = True # Reporting # Default to tests/load/out relative to current working directory - output_dir: str = field(default_factory=lambda: os.getenv("LOAD_OUTPUT_DIR", "tests/load/out")) + output_dir: str = "tests/load/out" def api(self, path: str) -> str: prefix = self.api_prefix.rstrip("/") diff --git a/backend/tests/load/http_client.py b/backend/tests/load/http_client.py index 94d3d4c4..87c53b8a 100644 --- a/backend/tests/load/http_client.py +++ b/backend/tests/load/http_client.py @@ -67,7 +67,7 @@ async def login(self, username: str, password: str) -> bool: r = await self._request("POST", url, data=httpx.QueryParams(data), headers=headers) if r.status_code == 200: # Extract csrf cookie (not httpOnly) for subsequent writes - for cookie in self.client.cookies.jar: # type: ignore[attr-defined] + for cookie in self.client.cookies.jar: if cookie.name == "csrf_token": self.csrf_token = cookie.value break @@ -107,7 +107,7 @@ async def sse_execution(self, execution_id: str, max_seconds: float = 10.0) -> T # Use a separate streaming client to avoid interfering with normal client timeouts async with httpx.AsyncClient(verify=self.cfg.verify_tls, timeout=None) as s: # Reuse cookies for auth - s.cookies = self.client.cookies.copy() + s.cookies.update(self.client.cookies) t0 = time.perf_counter() try: async with s.stream("GET", url) as resp: diff --git a/backend/tests/load/monkey_runner.py b/backend/tests/load/monkey_runner.py index ece0b9f6..21e14e97 100644 --- a/backend/tests/load/monkey_runner.py +++ b/backend/tests/load/monkey_runner.py @@ -8,7 +8,7 @@ from typing import Any from .config import LoadConfig -from .http_client import APIClient +from .http_client import APIClient, APIUser from .stats import StatsCollector from .strategies import json_value @@ -77,11 +77,12 @@ async def one_client(i: int) -> None: # Half of clients attempt to login/register first if random.random() < 0.5: uname = f"monkey_{_rand(6)}" - await c.register(user := type("U", (), { - "username": uname, - "email": f"{uname}@{cfg.user_domain}", - "password": cfg.user_password - })) + user = APIUser( + username=uname, + email=f"{uname}@{cfg.user_domain}", + password=cfg.user_password + ) + await c.register(user) await c.login(uname, cfg.user_password) # Run until deadline diff --git a/backend/tests/load/plot_report.py b/backend/tests/load/plot_report.py index 54c5c365..b415e15e 100644 --- a/backend/tests/load/plot_report.py +++ b/backend/tests/load/plot_report.py @@ -3,14 +3,38 @@ import argparse import json from pathlib import Path -from typing import Any, Dict, List, Tuple +from typing import Dict, List, Tuple, TypedDict import matplotlib.pyplot as plt -def _load_report(path: str | Path) -> Dict[str, Any]: +class LatencyStats(TypedDict, total=False): + p50: int + p90: int + p99: int + + +class EndpointData(TypedDict, total=False): + count: int + errors: int + latency_ms_success: LatencyStats + + +class TimelineData(TypedDict, total=False): + seconds: List[int] + rps: List[int] + eps: List[int] + + +class ReportDict(TypedDict, total=False): + timeline: TimelineData + endpoints: Dict[str, EndpointData] + + +def _load_report(path: str | Path) -> ReportDict: with open(path, "r", encoding="utf-8") as f: - return json.load(f) + result: ReportDict = json.load(f) + return result def _ensure_out_dir(path: str | Path) -> Path: @@ -19,8 +43,8 @@ def _ensure_out_dir(path: str | Path) -> Path: return p -def plot_timeline(report: Dict[str, Any], out_dir: Path) -> Path: - tl = report.get("timeline", {}) +def plot_timeline(report: ReportDict, out_dir: Path) -> Path: + tl: TimelineData = report.get("timeline", {}) seconds: List[int] = tl.get("seconds", []) rps: List[int] = tl.get("rps", []) eps: List[int] = tl.get("eps", []) @@ -44,22 +68,23 @@ def plot_timeline(report: Dict[str, Any], out_dir: Path) -> Path: return out_path -def _top_endpoints(report: Dict[str, Any], top_n: int = 10) -> List[Tuple[str, Dict[str, Any]]]: - eps: Dict[str, Any] = report.get("endpoints", {}) +def _top_endpoints(report: ReportDict, top_n: int = 10) -> List[Tuple[str, EndpointData]]: + eps: Dict[str, EndpointData] = report.get("endpoints", {}) items = list(eps.items()) items.sort(key=lambda kv: kv[1].get("count", 0), reverse=True) return items[:top_n] -def plot_endpoint_latency(report: Dict[str, Any], out_dir: Path, top_n: int = 10) -> Path: +def plot_endpoint_latency(report: ReportDict, out_dir: Path, top_n: int = 10) -> Path: data = _top_endpoints(report, top_n) if not data: return out_dir / "endpoint_latency.png" labels = [k for k, _ in data] - p50 = [v.get("latency_ms_success", {}).get("p50", 0) for _, v in data] - p90 = [v.get("latency_ms_success", {}).get("p90", 0) for _, v in data] - p99 = [v.get("latency_ms_success", {}).get("p99", 0) for _, v in data] + empty_latency: LatencyStats = {} + p50 = [v.get("latency_ms_success", empty_latency).get("p50", 0) for _, v in data] + p90 = [v.get("latency_ms_success", empty_latency).get("p90", 0) for _, v in data] + p99 = [v.get("latency_ms_success", empty_latency).get("p99", 0) for _, v in data] x = range(len(labels)) width = 0.25 @@ -81,7 +106,7 @@ def plot_endpoint_latency(report: Dict[str, Any], out_dir: Path, top_n: int = 10 return out_path -def plot_endpoint_throughput(report: Dict[str, Any], out_dir: Path, top_n: int = 10) -> Path: +def plot_endpoint_throughput(report: ReportDict, out_dir: Path, top_n: int = 10) -> Path: data = _top_endpoints(report, top_n) if not data: return out_dir / "endpoint_throughput.png" diff --git a/backend/tests/load/strategies.py b/backend/tests/load/strategies.py index 283473bf..ba7e34a6 100644 --- a/backend/tests/load/strategies.py +++ b/backend/tests/load/strategies.py @@ -1,11 +1,13 @@ from __future__ import annotations from datetime import datetime, timedelta -from typing import Any from hypothesis import strategies as st +# Type alias for JSON values +type JsonValue = None | bool | int | float | str | list[JsonValue] | dict[str, JsonValue] + # Generic JSON strategies (bounded sizes to keep payloads realistic) json_scalar = st.one_of( st.none(), @@ -15,7 +17,7 @@ st.text(min_size=0, max_size=256), ) -json_value: st.SearchStrategy[Any] +json_value: st.SearchStrategy[JsonValue] json_value = st.recursive( json_scalar, lambda children: st.one_of( @@ -48,8 +50,8 @@ severity = st.sampled_from(["info", "warning", "error", "critical"]) # common values label_key = st.text(min_size=1, max_size=24) label_val = st.text(min_size=0, max_size=64) -labels = st.dictionaries(label_key, label_val, max_size=8) -annotations = st.dictionaries(label_key, label_val, max_size=8) +label_dict = st.dictionaries(label_key, label_val, max_size=8) +annotation_dict = st.dictionaries(label_key, label_val, max_size=8) def _iso_time() -> st.SearchStrategy[str]: base = datetime(2024, 1, 1) @@ -60,8 +62,8 @@ def _iso_time() -> st.SearchStrategy[str]: alert = st.fixed_dictionaries( { "status": st.sampled_from(["firing", "resolved"]), - "labels": labels, - "annotations": annotations, + "labels": label_dict, + "annotations": annotation_dict, "startsAt": _iso_time(), "endsAt": _iso_time(), "generatorURL": st.text(min_size=0, max_size=120), @@ -75,9 +77,9 @@ def _iso_time() -> st.SearchStrategy[str]: "status": st.sampled_from(["firing", "resolved"]), "alerts": st.lists(alert, min_size=1, max_size=5), "groupKey": st.text(min_size=0, max_size=64), - "groupLabels": labels, - "commonLabels": labels, - "commonAnnotations": annotations, + "groupLabels": label_dict, + "commonLabels": label_dict, + "commonAnnotations": annotation_dict, "externalURL": st.text(min_size=0, max_size=120), "version": st.text(min_size=1, max_size=16), } diff --git a/backend/tests/load/user_runner.py b/backend/tests/load/user_runner.py index 1c441bd2..a3780404 100644 --- a/backend/tests/load/user_runner.py +++ b/backend/tests/load/user_runner.py @@ -2,6 +2,7 @@ import asyncio import random +from collections.abc import Awaitable from dataclasses import dataclass from typing import Callable @@ -14,7 +15,7 @@ class UserTask: name: str weight: int - fn: Callable[[APIClient], asyncio.Future] + fn: Callable[[APIClient], Awaitable[None]] async def _flow_execute_and_get_result(c: APIClient) -> None: @@ -85,7 +86,7 @@ async def _flow_settings_and_notifications(c: APIClient) -> None: async def run_user_swarm(cfg: LoadConfig, stats: StatsCollector, clients: int) -> None: - tasks: list[asyncio.Task] = [] + tasks: list[asyncio.Task[None]] = [] sem = asyncio.Semaphore(cfg.concurrency) deadline = time.time() + max(1, cfg.duration_seconds) diff --git a/backend/tests/unit/conftest.py b/backend/tests/unit/conftest.py index e89e4163..517ae021 100644 --- a/backend/tests/unit/conftest.py +++ b/backend/tests/unit/conftest.py @@ -3,24 +3,67 @@ Unit tests should NOT access real infrastructure (DB, Redis, HTTP). These fixtures raise errors to catch accidental usage. """ +import logging +from collections.abc import Generator +from typing import NoReturn + import pytest +from app.core.metrics.connections import ConnectionMetrics +from app.core.metrics.context import MetricsContext +from app.core.metrics.coordinator import CoordinatorMetrics +from app.core.metrics.database import DatabaseMetrics +from app.core.metrics.dlq import DLQMetrics +from app.core.metrics.events import EventMetrics +from app.core.metrics.execution import ExecutionMetrics +from app.core.metrics.health import HealthMetrics +from app.core.metrics.kubernetes import KubernetesMetrics +from app.core.metrics.notifications import NotificationMetrics +from app.core.metrics.rate_limit import RateLimitMetrics +from app.core.metrics.replay import ReplayMetrics +from app.core.metrics.security import SecurityMetrics +from app.settings import Settings + +_unit_test_logger = logging.getLogger("test.unit") + + +@pytest.fixture(scope="session", autouse=True) +def init_metrics_for_unit_tests(test_settings: Settings) -> Generator[None, None, None]: + """Initialize all metrics context for unit tests.""" + MetricsContext.initialize_all( + _unit_test_logger, + connection=ConnectionMetrics(test_settings), + coordinator=CoordinatorMetrics(test_settings), + database=DatabaseMetrics(test_settings), + dlq=DLQMetrics(test_settings), + event=EventMetrics(test_settings), + execution=ExecutionMetrics(test_settings), + health=HealthMetrics(test_settings), + kubernetes=KubernetesMetrics(test_settings), + notification=NotificationMetrics(test_settings), + rate_limit=RateLimitMetrics(test_settings), + replay=ReplayMetrics(test_settings), + security=SecurityMetrics(test_settings), + ) + yield + MetricsContext.reset_all(_unit_test_logger) + @pytest.fixture -def db(): +def db() -> NoReturn: raise RuntimeError("Unit tests should not access DB - use mocks or move to integration/") @pytest.fixture -def redis_client(): +def redis_client() -> NoReturn: raise RuntimeError("Unit tests should not access Redis - use mocks or move to integration/") @pytest.fixture -def client(): +def client() -> NoReturn: raise RuntimeError("Unit tests should not use HTTP client - use mocks or move to integration/") @pytest.fixture -def app(): +def app() -> NoReturn: raise RuntimeError("Unit tests should not use full app - use mocks or move to integration/") diff --git a/backend/tests/unit/core/metrics/test_base_metrics.py b/backend/tests/unit/core/metrics/test_base_metrics.py index f8a6ab3e..24a36601 100644 --- a/backend/tests/unit/core/metrics/test_base_metrics.py +++ b/backend/tests/unit/core/metrics/test_base_metrics.py @@ -1,29 +1,30 @@ import pytest from app.core.metrics.base import BaseMetrics +from app.settings import Settings pytestmark = pytest.mark.unit class DummyMetrics(BaseMetrics): - def __init__(self) -> None: + def __init__(self, settings: Settings) -> None: self.created = False - super().__init__(meter_name="dummy") - + super().__init__(settings, meter_name="dummy") + def _create_instruments(self) -> None: # noqa: D401 self.created = True -def test_base_metrics_initializes_meter_and_instruments() -> None: +def test_base_metrics_initializes_meter_and_instruments(test_settings: Settings) -> None: """Test that BaseMetrics initializes properly with no-op metrics.""" # Create DummyMetrics instance - will use NoOpMeterProvider automatically - m = DummyMetrics() - + m = DummyMetrics(test_settings) + # Verify that the BaseMetrics init was called and instruments were created assert m.created is True assert m._meter is not None # Meter exists (will be NoOpMeter) - + # close is no-op m.close() diff --git a/backend/tests/unit/core/metrics/test_connections_and_coordinator_metrics.py b/backend/tests/unit/core/metrics/test_connections_and_coordinator_metrics.py index 1103bb98..0610913d 100644 --- a/backend/tests/unit/core/metrics/test_connections_and_coordinator_metrics.py +++ b/backend/tests/unit/core/metrics/test_connections_and_coordinator_metrics.py @@ -2,14 +2,14 @@ from app.core.metrics.connections import ConnectionMetrics from app.core.metrics.coordinator import CoordinatorMetrics +from app.settings import Settings pytestmark = pytest.mark.unit -def test_connection_metrics_methods() -> None: +def test_connection_metrics_methods(test_settings: Settings) -> None: """Test ConnectionMetrics methods with no-op metrics.""" - # Create ConnectionMetrics instance - will use NoOpMeterProvider automatically - m = ConnectionMetrics() + m = ConnectionMetrics(test_settings) m.increment_sse_connections("/events") m.decrement_sse_connections("/events") m.record_sse_message_sent("/events", "etype") @@ -22,10 +22,9 @@ def test_connection_metrics_methods() -> None: m.update_event_bus_subscribers(3, "*") -def test_coordinator_metrics_methods() -> None: +def test_coordinator_metrics_methods(test_settings: Settings) -> None: """Test CoordinatorMetrics methods with no-op metrics.""" - # Create CoordinatorMetrics instance - will use NoOpMeterProvider automatically - m = CoordinatorMetrics() + m = CoordinatorMetrics(test_settings) m.record_coordinator_processing_time(0.1) m.record_scheduling_duration(0.2) m.update_active_executions_gauge(2) diff --git a/backend/tests/unit/core/metrics/test_database_and_dlq_metrics.py b/backend/tests/unit/core/metrics/test_database_and_dlq_metrics.py index 1b8d8072..d0be021d 100644 --- a/backend/tests/unit/core/metrics/test_database_and_dlq_metrics.py +++ b/backend/tests/unit/core/metrics/test_database_and_dlq_metrics.py @@ -2,14 +2,14 @@ from app.core.metrics.database import DatabaseMetrics from app.core.metrics.dlq import DLQMetrics +from app.settings import Settings pytestmark = pytest.mark.unit -def test_database_metrics_methods() -> None: +def test_database_metrics_methods(test_settings: Settings) -> None: """Test DatabaseMetrics methods with no-op metrics.""" - # Create DatabaseMetrics instance - will use NoOpMeterProvider automatically - m = DatabaseMetrics() + m = DatabaseMetrics(test_settings) m.record_mongodb_operation("insert", "ok") m.record_mongodb_query_duration(0.1, "find") m.record_event_store_duration(0.2, "insert", "events") @@ -28,10 +28,9 @@ def test_database_metrics_methods() -> None: m.record_database_connection_error("timeout") -def test_dlq_metrics_methods() -> None: +def test_dlq_metrics_methods(test_settings: Settings) -> None: """Test DLQMetrics methods with no-op metrics.""" - # Create DLQMetrics instance - will use NoOpMeterProvider automatically - m = DLQMetrics() + m = DLQMetrics(test_settings) m.record_dlq_message_received("topic", "etype") m.record_dlq_message_retried("topic", "etype", "success") m.record_dlq_message_discarded("topic", "etype", "bad") diff --git a/backend/tests/unit/core/metrics/test_execution_and_events_metrics.py b/backend/tests/unit/core/metrics/test_execution_and_events_metrics.py index 9f008a66..7fce126b 100644 --- a/backend/tests/unit/core/metrics/test_execution_and_events_metrics.py +++ b/backend/tests/unit/core/metrics/test_execution_and_events_metrics.py @@ -1,43 +1,48 @@ - - import pytest from app.core.metrics.execution import ExecutionMetrics from app.core.metrics.events import EventMetrics from app.domain.enums.execution import ExecutionStatus +from app.settings import Settings pytestmark = pytest.mark.unit -def test_execution_metrics_methods() -> None: +def test_execution_metrics_methods(test_settings: Settings) -> None: """Test with no-op metrics.""" - - m = ExecutionMetrics() + m = ExecutionMetrics(test_settings) m.record_script_execution(ExecutionStatus.QUEUED, "python-3.11") m.record_execution_duration(0.5, "python-3.11") - m.increment_active_executions(); m.decrement_active_executions() + m.increment_active_executions() + m.decrement_active_executions() m.record_memory_usage(123.4, "python-3.11") m.record_error("timeout") - m.update_queue_depth(1); m.update_queue_depth(-1) + m.update_queue_depth(1) + m.update_queue_depth(-1) m.record_queue_wait_time(0.1, "python-3.11") - m.record_execution_assigned(); m.record_execution_queued(); m.record_execution_scheduled("ok") - m.update_cpu_available(100.0); m.update_memory_available(512.0); m.update_gpu_available(1) + m.record_execution_assigned() + m.record_execution_queued() + m.record_execution_scheduled("ok") + m.update_cpu_available(100.0) + m.update_memory_available(512.0) + m.update_gpu_available(1) m.update_allocations_active(2) -def test_event_metrics_methods() -> None: +def test_event_metrics_methods(test_settings: Settings) -> None: """Test with no-op metrics.""" - - m = EventMetrics() + m = EventMetrics(test_settings) m.record_event_published("execution.requested", None) m.record_event_processing_duration(0.05, "execution.requested") m.record_pod_event_published("pod.running") m.record_event_replay_operation("prepare", "success") m.update_event_buffer_size(3) - m.record_event_buffer_dropped(); m.record_event_buffer_processed() + m.record_event_buffer_dropped() + m.record_event_buffer_processed() m.record_event_buffer_latency(0.2) - m.set_event_buffer_backpressure(True); m.set_event_buffer_backpressure(False) + m.set_event_buffer_backpressure(True) + m.set_event_buffer_backpressure(False) m.record_event_buffer_memory_usage(12.3) m.record_event_stored("execution.requested", "events") m.record_events_processing_failed("topic", "etype", "group", "error") @@ -45,8 +50,12 @@ def test_event_metrics_methods() -> None: m.record_event_store_failed("etype", "fail") m.record_event_query_duration(0.2, "by_type", "events") m.record_processing_duration(0.3, "topic", "etype", "group") - m.record_kafka_message_produced("t"); m.record_kafka_message_consumed("t", "g") + m.record_kafka_message_produced("t") + m.record_kafka_message_consumed("t", "g") m.record_kafka_consumer_lag(10, "t", "g", 0) - m.record_kafka_production_error("t", "e"); m.record_kafka_consumption_error("t", "g", "e") - m.update_event_bus_queue_size(1, "default"); m.set_event_bus_queue_size(5, "default"); m.set_event_bus_queue_size(2, "default") + m.record_kafka_production_error("t", "e") + m.record_kafka_consumption_error("t", "g", "e") + m.update_event_bus_queue_size(1, "default") + m.set_event_bus_queue_size(5, "default") + m.set_event_bus_queue_size(2, "default") diff --git a/backend/tests/unit/core/metrics/test_health_and_rate_limit_metrics.py b/backend/tests/unit/core/metrics/test_health_and_rate_limit_metrics.py index ff97c429..710ce31d 100644 --- a/backend/tests/unit/core/metrics/test_health_and_rate_limit_metrics.py +++ b/backend/tests/unit/core/metrics/test_health_and_rate_limit_metrics.py @@ -1,20 +1,20 @@ import pytest from app.core.metrics.health import HealthMetrics +from app.settings import Settings pytestmark = pytest.mark.unit -def test_health_metrics_methods() -> None: +def test_health_metrics_methods(test_settings: Settings) -> None: """Test with no-op metrics.""" - - m = HealthMetrics() + m = HealthMetrics(test_settings) m.record_health_check_duration(0.1, "liveness", "basic") m.record_health_check_failure("readiness", "db", "timeout") m.update_health_check_status(1, "liveness", "basic") m.record_health_status("svc", "healthy") m.record_service_health_score("svc", 95.0) - m.update_liveness_status(True, "app"); + m.update_liveness_status(True, "app") m.update_readiness_status(False, "app") m.record_dependency_health("mongo", True, 0.2) m.record_health_check_timeout("readiness", "db") diff --git a/backend/tests/unit/core/metrics/test_kubernetes_and_notifications_metrics.py b/backend/tests/unit/core/metrics/test_kubernetes_and_notifications_metrics.py index 5fbdcc73..dda78599 100644 --- a/backend/tests/unit/core/metrics/test_kubernetes_and_notifications_metrics.py +++ b/backend/tests/unit/core/metrics/test_kubernetes_and_notifications_metrics.py @@ -1,24 +1,27 @@ - - import pytest from app.core.metrics.kubernetes import KubernetesMetrics from app.core.metrics.notifications import NotificationMetrics +from app.settings import Settings pytestmark = pytest.mark.unit -def test_kubernetes_metrics_methods() -> None: +def test_kubernetes_metrics_methods(test_settings: Settings) -> None: """Test with no-op metrics.""" - - m = KubernetesMetrics() + m = KubernetesMetrics(test_settings) m.record_pod_creation_failure("quota") - m.record_pod_created("success", "python"); m.record_pod_creation_duration(0.4, "python") - m.update_active_pod_creations(2); m.increment_active_pod_creations(); m.decrement_active_pod_creations() + m.record_pod_created("success", "python") + m.record_pod_creation_duration(0.4, "python") + m.update_active_pod_creations(2) + m.increment_active_pod_creations() + m.decrement_active_pod_creations() m.record_config_map_created("ok") - m.record_k8s_pod_created("success", "python"); m.record_k8s_pod_creation_duration(0.3, "python") - m.record_k8s_config_map_created("ok"); m.record_k8s_network_policy_created("ok") + m.record_k8s_pod_created("success", "python") + m.record_k8s_pod_creation_duration(0.3, "python") + m.record_k8s_config_map_created("ok") + m.record_k8s_network_policy_created("ok") m.update_k8s_active_creations(1) m.increment_pod_monitor_watch_reconnects() m.record_pod_monitor_event_processing_duration(0.2, "ADDED") @@ -34,10 +37,9 @@ def test_kubernetes_metrics_methods() -> None: m.record_pods_per_node("node1", 7) -def test_notification_metrics_methods() -> None: +def test_notification_metrics_methods(test_settings: Settings) -> None: """Test with no-op metrics.""" - - m = NotificationMetrics() + m = NotificationMetrics(test_settings) m.record_notification_sent("welcome", channel="email", severity="high") m.record_notification_failed("welcome", "smtp_error", channel="email") m.record_notification_delivery_time(0.5, "welcome", channel="email") diff --git a/backend/tests/unit/core/metrics/test_metrics_classes.py b/backend/tests/unit/core/metrics/test_metrics_classes.py index e0e02ef3..fc620462 100644 --- a/backend/tests/unit/core/metrics/test_metrics_classes.py +++ b/backend/tests/unit/core/metrics/test_metrics_classes.py @@ -1,4 +1,7 @@ +import pytest + from app.core.metrics.connections import ConnectionMetrics +from app.domain.enums.execution import ExecutionStatus from app.core.metrics.coordinator import CoordinatorMetrics from app.core.metrics.database import DatabaseMetrics from app.core.metrics.dlq import DLQMetrics @@ -10,12 +13,14 @@ from app.core.metrics.rate_limit import RateLimitMetrics from app.core.metrics.replay import ReplayMetrics from app.core.metrics.security import SecurityMetrics +from app.settings import Settings + +pytestmark = pytest.mark.unit -def test_connection_metrics_smoke(): +def test_connection_metrics_smoke(test_settings: Settings) -> None: """Test ConnectionMetrics smoke test with no-op metrics.""" - # Create ConnectionMetrics instance - will use NoOpMeterProvider automatically - m = ConnectionMetrics() + m = ConnectionMetrics(test_settings) m.increment_sse_connections("exec") m.decrement_sse_connections("exec") m.record_sse_message_sent("exec", "evt") @@ -25,10 +30,9 @@ def test_connection_metrics_smoke(): m.update_event_bus_subscribers(3, "*") -def test_event_metrics_smoke(): +def test_event_metrics_smoke(test_settings: Settings) -> None: """Test EventMetrics smoke test with no-op metrics.""" - # Create EventMetrics instance - will use NoOpMeterProvider automatically - m = EventMetrics() + m = EventMetrics(test_settings) m.record_event_published("execution.requested") m.record_event_processing_duration(0.01, "execution.requested") m.record_pod_event_published("pod.created") @@ -54,16 +58,15 @@ def test_event_metrics_smoke(): m.set_event_bus_queue_size(5) -def test_other_metrics_classes_smoke(): +def test_other_metrics_classes_smoke(test_settings: Settings) -> None: """Test other metrics classes smoke test with no-op metrics.""" - # Create metrics instances - will use NoOpMeterProvider automatically - CoordinatorMetrics().record_coordinator_processing_time(0.01) - DatabaseMetrics().record_mongodb_operation("read", "ok") - DLQMetrics().record_dlq_message_received("topic", "type") - ExecutionMetrics().record_script_execution("QUEUED", "python") - HealthMetrics().record_health_check_duration(0.001, "liveness", "basic") - KubernetesMetrics().record_k8s_pod_created("success", "python") - NotificationMetrics().record_notification_sent("welcome", channel="email") - RateLimitMetrics().requests_total.add(1) - ReplayMetrics().record_session_created("by_id", "kafka") - SecurityMetrics().record_security_event("scan", severity="low") + CoordinatorMetrics(test_settings).record_coordinator_processing_time(0.01) + DatabaseMetrics(test_settings).record_mongodb_operation("read", "ok") + DLQMetrics(test_settings).record_dlq_message_received("topic", "type") + ExecutionMetrics(test_settings).record_script_execution(ExecutionStatus.QUEUED, "python") + HealthMetrics(test_settings).record_health_check_duration(0.001, "liveness", "basic") + KubernetesMetrics(test_settings).record_k8s_pod_created("success", "python") + NotificationMetrics(test_settings).record_notification_sent("welcome", channel="email") + RateLimitMetrics(test_settings).requests_total.add(1) + ReplayMetrics(test_settings).record_session_created("by_id", "kafka") + SecurityMetrics(test_settings).record_security_event("scan", severity="low") diff --git a/backend/tests/unit/core/metrics/test_metrics_context.py b/backend/tests/unit/core/metrics/test_metrics_context.py index c73001a9..c5cf6e50 100644 --- a/backend/tests/unit/core/metrics/test_metrics_context.py +++ b/backend/tests/unit/core/metrics/test_metrics_context.py @@ -1,26 +1,25 @@ import logging +import pytest + from app.core.metrics.context import ( - MetricsContext, get_connection_metrics, get_coordinator_metrics, ) _test_logger = logging.getLogger("test.core.metrics.context") +pytestmark = pytest.mark.unit + -def test_metrics_context_lazy_and_reset() -> None: - """Test metrics context lazy loading and reset with no-op metrics.""" - # Get metrics instances - will use NoOpMeterProvider automatically +def test_metrics_context_returns_initialized_metrics() -> None: + """Test metrics context returns initialized metrics from session fixture.""" + # Metrics are initialized by the session-scoped fixture in conftest.py c1 = get_connection_metrics() c2 = get_connection_metrics() assert c1 is c2 # same instance per context d1 = get_coordinator_metrics() - MetricsContext.reset_all(_test_logger) - # after reset, new instances are created lazily - c3 = get_connection_metrics() - assert c3 is not c1 d2 = get_coordinator_metrics() - assert d2 is not d1 + assert d1 is d2 diff --git a/backend/tests/unit/core/metrics/test_replay_and_security_metrics.py b/backend/tests/unit/core/metrics/test_replay_and_security_metrics.py index 6e03f057..5c3cf4f1 100644 --- a/backend/tests/unit/core/metrics/test_replay_and_security_metrics.py +++ b/backend/tests/unit/core/metrics/test_replay_and_security_metrics.py @@ -1,19 +1,20 @@ - import pytest from app.core.metrics.replay import ReplayMetrics from app.core.metrics.security import SecurityMetrics +from app.settings import Settings pytestmark = pytest.mark.unit -def test_replay_metrics_methods() -> None: +def test_replay_metrics_methods(test_settings: Settings) -> None: """Test ReplayMetrics methods with no-op metrics.""" - # Create ReplayMetrics instance - will use NoOpMeterProvider automatically - m = ReplayMetrics() + m = ReplayMetrics(test_settings) m.record_session_created("by_id", "kafka") - m.update_active_replays(2); m.increment_active_replays(); m.decrement_active_replays() + m.update_active_replays(2) + m.increment_active_replays() + m.decrement_active_replays() m.record_events_replayed("by_id", "etype", "success", 3) m.record_event_replayed("by_id", "etype", "failed") m.record_replay_duration(2.0, "by_id", total_events=4) @@ -21,20 +22,21 @@ def test_replay_metrics_methods() -> None: m.record_replay_error("timeout", "by_id") m.record_status_change("s1", "running", "completed") m.update_sessions_by_status("running", -1) - m.record_replay_by_target("kafka", True); m.record_replay_by_target("kafka", False) + m.record_replay_by_target("kafka", True) + m.record_replay_by_target("kafka", False) m.record_speed_multiplier(2.0, "by_id") m.record_delay_applied(0.05) m.record_batch_size(10, "by_id") m.record_events_filtered("type", 5) m.record_filter_effectiveness(5, 10, "type") m.record_replay_memory_usage(123.0, "s1") - m.update_replay_queue_size("s1", 10); m.update_replay_queue_size("s1", 4) + m.update_replay_queue_size("s1", 10) + m.update_replay_queue_size("s1", 4) -def test_security_metrics_methods() -> None: +def test_security_metrics_methods(test_settings: Settings) -> None: """Test SecurityMetrics methods with no-op metrics.""" - # Create SecurityMetrics instance - will use NoOpMeterProvider automatically - m = SecurityMetrics() + m = SecurityMetrics(test_settings) m.record_security_event("scan_started", severity="high", source="scanner") m.record_security_violation("csrf", user_id="u1", ip_address="127.0.0.1") m.record_authentication_attempt("password", False, user_id="u1", duration_seconds=0.2) diff --git a/backend/tests/unit/core/test_adaptive_sampling.py b/backend/tests/unit/core/test_adaptive_sampling.py index 1929de38..9250822c 100644 --- a/backend/tests/unit/core/test_adaptive_sampling.py +++ b/backend/tests/unit/core/test_adaptive_sampling.py @@ -1,9 +1,10 @@ import time -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest from app.core.adaptive_sampling import AdaptiveSampler, create_adaptive_sampler +from app.settings import Settings def test_is_error_variants() -> None: @@ -67,11 +68,11 @@ def test_get_description_and_factory(monkeypatch: pytest.MonkeyPatch) -> None: assert "AdaptiveSampler(" in desc s._running = False - class S: - TRACING_SAMPLING_RATE = 0.2 + mock_settings = MagicMock(spec=Settings) + mock_settings.TRACING_SAMPLING_RATE = 0.2 monkeypatch.setenv("TRACING_SAMPLING_RATE", "0.2") # create_adaptive_sampler pulls settings via get_settings; just ensure it constructs - sampler = create_adaptive_sampler(S()) + sampler = create_adaptive_sampler(mock_settings) sampler._running = False diff --git a/backend/tests/unit/core/test_csrf.py b/backend/tests/unit/core/test_csrf.py index 9ef0b506..4c303cfd 100644 --- a/backend/tests/unit/core/test_csrf.py +++ b/backend/tests/unit/core/test_csrf.py @@ -1,10 +1,11 @@ import pytest +from app.core.security import SecurityService +from app.settings import Settings from starlette.requests import Request -from app.core.security import validate_csrf_token, security_service - -def make_request(method: str, path: str, headers: dict[str, str] | None = None, cookies: dict[str, str] | None = None) -> Request: +def make_request(method: str, path: str, headers: dict[str, str] | None = None, + cookies: dict[str, str] | None = None) -> Request: headers = headers or {} if cookies: cookie_header = "; ".join(f"{k}={v}" for k, v in cookies.items()) @@ -18,18 +19,21 @@ def make_request(method: str, path: str, headers: dict[str, str] | None = None, return Request(scope) -def test_csrf_skips_on_get() -> None: +def test_csrf_skips_on_get(test_settings: Settings) -> None: + security_service = SecurityService(test_settings) req = make_request("GET", "/api/v1/anything") - assert validate_csrf_token(req) == "skip" + assert security_service.validate_csrf_from_request(req) == "skip" -def test_csrf_missing_header_raises_when_authenticated() -> None: +def test_csrf_missing_header_raises_when_authenticated(test_settings: Settings) -> None: + security_service = SecurityService(test_settings) req = make_request("POST", "/api/v1/items", cookies={"access_token": "tok", "csrf_token": "abc"}) with pytest.raises(Exception): - validate_csrf_token(req) + security_service.validate_csrf_from_request(req) -def test_csrf_valid_tokens() -> None: +def test_csrf_valid_tokens(test_settings: Settings) -> None: + security_service = SecurityService(test_settings) token = security_service.generate_csrf_token() req = make_request( "POST", @@ -37,4 +41,4 @@ def test_csrf_valid_tokens() -> None: headers={"X-CSRF-Token": token}, cookies={"access_token": "tok", "csrf_token": token}, ) - assert validate_csrf_token(req) == token + assert security_service.validate_csrf_from_request(req) == token diff --git a/backend/tests/unit/core/test_logging_and_correlation.py b/backend/tests/unit/core/test_logging_and_correlation.py index bad1385f..f535ab9f 100644 --- a/backend/tests/unit/core/test_logging_and_correlation.py +++ b/backend/tests/unit/core/test_logging_and_correlation.py @@ -38,14 +38,16 @@ def capture_log(formatter: logging.Formatter, msg: str, extra: dict[str, Any] | string_io.close() if output: - return json.loads(output) + result: dict[str, Any] = json.loads(output) + return result # Fallback: create and format record manually lr = logging.LogRecord("t", logging.INFO, __file__, 1, msg, (), None, None) # Apply the filter manually correlation_filter.filter(lr) s = formatter.format(lr) - return json.loads(s) + fallback_result: dict[str, Any] = json.loads(s) + return fallback_result def test_json_formatter_sanitizes_tokens(monkeypatch: pytest.MonkeyPatch) -> None: @@ -83,6 +85,6 @@ async def ping(request: Request) -> JSONResponse: assert "X-Correlation-ID" in r.headers -def test_setup_logger_returns_logger(): +def test_setup_logger_returns_logger() -> None: lg = setup_logger(log_level="INFO") assert hasattr(lg, "info") diff --git a/backend/tests/unit/core/test_security.py b/backend/tests/unit/core/test_security.py index a3c475c3..cb0ed703 100644 --- a/backend/tests/unit/core/test_security.py +++ b/backend/tests/unit/core/test_security.py @@ -8,16 +8,16 @@ from app.core.security import SecurityService from app.domain.enums.user import UserRole - +from app.settings import Settings class TestPasswordHashing: """Test password hashing functionality.""" @pytest.fixture - def security_svc(self) -> SecurityService: + def security_svc(self, test_settings: Settings) -> SecurityService: """Create SecurityService instance.""" - return SecurityService() + return SecurityService(test_settings) def test_password_hash_creates_different_hash(self, security_svc: SecurityService) -> None: """Test that password hashing creates unique hashes.""" @@ -72,9 +72,9 @@ class TestSecurityService: """Test SecurityService functionality.""" @pytest.fixture - def security_service(self) -> SecurityService: - """Create SecurityService instance using real settings from env.""" - return SecurityService() + def security_service(self, test_settings: Settings) -> SecurityService: + """Create SecurityService instance using test settings.""" + return SecurityService(test_settings) def test_create_access_token_basic( self, @@ -222,11 +222,11 @@ def test_decode_token_missing_username( ) -> None: """Test decoding token without username.""" # Create token without 'sub' field - data = {"user_id": str(uuid4())} + data: dict[str, str | datetime] = {"user_id": str(uuid4())} expire = datetime.now(timezone.utc) + timedelta(minutes=15) to_encode = data.copy() - to_encode.update({"exp": expire}) + to_encode["exp"] = expire token = jwt.encode( to_encode, @@ -283,9 +283,9 @@ def test_token_has_only_expected_claims(self, security_service: SecurityService) assert decoded["role"] == UserRole.USER.value assert "extra_field" in decoded # Claims are carried as provided - def test_password_context_configuration(self) -> None: + def test_password_context_configuration(self, test_settings: Settings) -> None: """Test password context is properly configured.""" - svc = SecurityService() + svc = SecurityService(test_settings) password = "test_password" hashed = svc.get_password_hash(password) assert svc.verify_password(password, hashed) diff --git a/backend/tests/unit/events/core/test_producer.py b/backend/tests/unit/events/core/test_producer.py new file mode 100644 index 00000000..ba825dee --- /dev/null +++ b/backend/tests/unit/events/core/test_producer.py @@ -0,0 +1,22 @@ +import json +import logging + +import pytest +from app.events.core import ProducerMetrics, UnifiedProducer + +pytestmark = pytest.mark.unit + +_test_logger = logging.getLogger("test.events.core.producer") + + +def test_producer_handle_stats_path() -> None: + """Directly run stats parsing to cover branch logic; avoid relying on timing.""" + m = ProducerMetrics() + p = object.__new__(UnifiedProducer) # bypass __init__ safely for method call + # Inject required attributes for _handle_stats (including logger for exception handler) + p._metrics = m + p._stats_callback = None + p.logger = _test_logger + payload = json.dumps({"msg_cnt": 1, "topics": {"t": {"partitions": {"0": {"msgq_cnt": 2, "rtt": {"avg": 5}}}}}}) + UnifiedProducer._handle_stats(p, payload) + assert m.queue_size == 1 and m.avg_latency_ms > 0 diff --git a/backend/tests/unit/events/test_event_dispatcher.py b/backend/tests/unit/events/test_event_dispatcher.py index 28f7c92d..5933df45 100644 --- a/backend/tests/unit/events/test_event_dispatcher.py +++ b/backend/tests/unit/events/test_event_dispatcher.py @@ -8,7 +8,7 @@ _test_logger = logging.getLogger("test.events.event_dispatcher") -def make_event(): +def make_event() -> BaseEvent: return make_execution_requested_event(execution_id="e1") @@ -51,7 +51,7 @@ async def handler(_: BaseEvent) -> None: # Dispatch event with no handlers (different type) # Reuse base event but fake type by replacing value e = make_event() - e.event_type = EventType.EXECUTION_FAILED # type: ignore[attr-defined] + e.event_type = EventType.EXECUTION_FAILED await disp.dispatch(e) metrics = disp.get_metrics() diff --git a/backend/tests/unit/events/test_schema_registry_manager.py b/backend/tests/unit/events/test_schema_registry_manager.py index 77562a2e..5b8ddd1e 100644 --- a/backend/tests/unit/events/test_schema_registry_manager.py +++ b/backend/tests/unit/events/test_schema_registry_manager.py @@ -1,13 +1,15 @@ import logging import pytest + from app.events.schema.schema_registry import SchemaRegistryManager from app.infrastructure.kafka.events.execution import ExecutionRequestedEvent +from app.settings import Settings _test_logger = logging.getLogger("test.events.schema_registry_manager") -def test_deserialize_json_execution_requested(test_settings) -> None: # type: ignore[valid-type] +def test_deserialize_json_execution_requested(test_settings: Settings) -> None: m = SchemaRegistryManager(test_settings, logger=_test_logger) data = { "event_type": "execution_requested", @@ -32,7 +34,7 @@ def test_deserialize_json_execution_requested(test_settings) -> None: # type: i assert ev.language == "python" -def test_deserialize_json_missing_type_raises(test_settings) -> None: # type: ignore[valid-type] +def test_deserialize_json_missing_type_raises(test_settings: Settings) -> None: m = SchemaRegistryManager(test_settings, logger=_test_logger) with pytest.raises(ValueError): m.deserialize_json({}) diff --git a/backend/tests/unit/schemas_pydantic/test_events_schemas.py b/backend/tests/unit/schemas_pydantic/test_events_schemas.py index 30ef50c2..d055a488 100644 --- a/backend/tests/unit/schemas_pydantic/test_events_schemas.py +++ b/backend/tests/unit/schemas_pydantic/test_events_schemas.py @@ -4,7 +4,7 @@ from app.domain.enums.common import SortOrder -def test_event_filter_request_sort_validator_accepts_allowed_fields(): +def test_event_filter_request_sort_validator_accepts_allowed_fields() -> None: req = EventFilterRequest(sort_by="timestamp", sort_order=SortOrder.DESC) assert req.sort_by == "timestamp" @@ -13,6 +13,6 @@ def test_event_filter_request_sort_validator_accepts_allowed_fields(): assert req2.sort_by == field -def test_event_filter_request_sort_validator_rejects_invalid(): +def test_event_filter_request_sort_validator_rejects_invalid() -> None: with pytest.raises(ValueError): EventFilterRequest(sort_by="not-a-field") diff --git a/backend/tests/unit/schemas_pydantic/test_execution_schemas.py b/backend/tests/unit/schemas_pydantic/test_execution_schemas.py index 38e59401..70c48bab 100644 --- a/backend/tests/unit/schemas_pydantic/test_execution_schemas.py +++ b/backend/tests/unit/schemas_pydantic/test_execution_schemas.py @@ -5,18 +5,18 @@ from app.schemas_pydantic.execution import ExecutionRequest -def test_execution_request_valid_supported_runtime(): +def test_execution_request_valid_supported_runtime() -> None: req = ExecutionRequest(script="print('ok')", lang="python", lang_version="3.11") assert req.lang == "python" and req.lang_version == "3.11" -def test_execution_request_unsupported_language_raises(): +def test_execution_request_unsupported_language_raises() -> None: with pytest.raises(ValueError) as e: ExecutionRequest(script="print(1)", lang="rust", lang_version="1.0") assert "Language 'rust' not supported" in str(e.value) -def test_execution_request_unsupported_version_raises(): +def test_execution_request_unsupported_version_raises() -> None: with pytest.raises(ValueError) as e: ExecutionRequest(script="print(1)", lang="python", lang_version="9.9") assert "Version '9.9' not supported for python" in str(e.value) diff --git a/backend/tests/unit/schemas_pydantic/test_notification_schemas.py b/backend/tests/unit/schemas_pydantic/test_notification_schemas.py index 14b304bc..dd274180 100644 --- a/backend/tests/unit/schemas_pydantic/test_notification_schemas.py +++ b/backend/tests/unit/schemas_pydantic/test_notification_schemas.py @@ -6,7 +6,7 @@ from app.schemas_pydantic.notification import Notification, NotificationBatch -def test_notification_scheduled_for_must_be_future(): +def test_notification_scheduled_for_must_be_future() -> None: n = Notification( user_id="u1", channel=NotificationChannel.IN_APP, @@ -28,7 +28,7 @@ def test_notification_scheduled_for_must_be_future(): ) -def test_notification_batch_validation_limits(): +def test_notification_batch_validation_limits() -> None: n1 = Notification(user_id="u1", channel=NotificationChannel.IN_APP, subject="a", body="b") ok = NotificationBatch(notifications=[n1]) assert ok.processed_count == 0 diff --git a/backend/tests/unit/services/coordinator/test_queue_manager.py b/backend/tests/unit/services/coordinator/test_queue_manager.py index e3151a16..a43f81ca 100644 --- a/backend/tests/unit/services/coordinator/test_queue_manager.py +++ b/backend/tests/unit/services/coordinator/test_queue_manager.py @@ -2,18 +2,21 @@ import pytest +from app.infrastructure.kafka.events.execution import ExecutionRequestedEvent from app.services.coordinator.queue_manager import QueueManager, QueuePriority from tests.helpers import make_execution_requested_event _test_logger = logging.getLogger("test.services.coordinator.queue_manager") +pytestmark = pytest.mark.unit -def ev(execution_id: str, priority: int = QueuePriority.NORMAL.value): + +def ev(execution_id: str, priority: int = QueuePriority.NORMAL.value) -> ExecutionRequestedEvent: return make_execution_requested_event(execution_id=execution_id, priority=priority) @pytest.mark.asyncio -async def test_requeue_execution_increments_priority(): +async def test_requeue_execution_increments_priority() -> None: qm = QueueManager(max_queue_size=10, logger=_test_logger) await qm.start() # Use NORMAL priority which can be incremented to LOW @@ -26,7 +29,7 @@ async def test_requeue_execution_increments_priority(): @pytest.mark.asyncio -async def test_queue_stats_empty_and_after_add(): +async def test_queue_stats_empty_and_after_add() -> None: qm = QueueManager(max_queue_size=5, logger=_test_logger) await qm.start() stats0 = await qm.get_queue_stats() diff --git a/backend/tests/unit/services/idempotency/test_middleware.py b/backend/tests/unit/services/idempotency/test_middleware.py index c4b19acf..475e75ac 100644 --- a/backend/tests/unit/services/idempotency/test_middleware.py +++ b/backend/tests/unit/services/idempotency/test_middleware.py @@ -1,6 +1,5 @@ -import asyncio import logging -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, MagicMock import pytest from app.infrastructure.kafka.events.base import BaseEvent @@ -11,8 +10,6 @@ IdempotentConsumerWrapper, ) from app.domain.idempotency import IdempotencyStatus -from app.domain.enums.events import EventType -from app.domain.enums.kafka import KafkaTopic _test_logger = logging.getLogger("test.services.idempotency.middleware") @@ -22,24 +19,24 @@ class TestIdempotentEventHandler: @pytest.fixture - def mock_idempotency_manager(self): + def mock_idempotency_manager(self) -> AsyncMock: return AsyncMock(spec=IdempotencyManager) @pytest.fixture - def mock_handler(self): + def mock_handler(self) -> AsyncMock: handler = AsyncMock() handler.__name__ = "test_handler" return handler @pytest.fixture - def event(self): + def event(self) -> MagicMock: event = MagicMock(spec=BaseEvent) event.event_type = "test.event" event.event_id = "event-123" return event @pytest.fixture - def idempotent_event_handler(self, mock_handler, mock_idempotency_manager): + def idempotent_event_handler(self, mock_handler: AsyncMock, mock_idempotency_manager: AsyncMock) -> IdempotentEventHandler: return IdempotentEventHandler( handler=mock_handler, idempotency_manager=mock_idempotency_manager, @@ -50,7 +47,7 @@ def idempotent_event_handler(self, mock_handler, mock_idempotency_manager): ) @pytest.mark.asyncio - async def test_call_with_fields(self, mock_handler, mock_idempotency_manager, event): + async def test_call_with_fields(self, mock_handler: AsyncMock, mock_idempotency_manager: AsyncMock, event: MagicMock) -> None: # Setup with specific fields fields = {"field1", "field2"} @@ -83,7 +80,7 @@ async def test_call_with_fields(self, mock_handler, mock_idempotency_manager, ev ) @pytest.mark.asyncio - async def test_call_handler_exception(self, idempotent_event_handler, mock_idempotency_manager, mock_handler, event): + async def test_call_handler_exception(self, idempotent_event_handler: IdempotentEventHandler, mock_idempotency_manager: AsyncMock, mock_handler: AsyncMock, event: MagicMock) -> None: # Setup: Handler raises exception idempotency_result = IdempotencyResult( is_duplicate=False, diff --git a/backend/tests/unit/services/pod_monitor/test_event_mapper.py b/backend/tests/unit/services/pod_monitor/test_event_mapper.py index 48a36d4b..ccce2787 100644 --- a/backend/tests/unit/services/pod_monitor/test_event_mapper.py +++ b/backend/tests/unit/services/pod_monitor/test_event_mapper.py @@ -1,20 +1,27 @@ import json import logging + import pytest from app.domain.enums.storage import ExecutionErrorType +from app.infrastructure.kafka.events.execution import ( + ExecutionCompletedEvent, + ExecutionFailedEvent, + ExecutionTimeoutEvent, +) from app.infrastructure.kafka.events.metadata import AvroEventMetadata +from app.infrastructure.kafka.events.pod import PodRunningEvent from app.services.pod_monitor.event_mapper import PodContext, PodEventMapper from tests.helpers.k8s_fakes import ( - Meta, - Terminated, - Waiting, - State, ContainerStatus, + FakeApi, + Meta, + Pod, Spec, + State, Status, - Pod, - FakeApi, + Terminated, + Waiting, ) @@ -33,8 +40,12 @@ def test_pending_running_and_succeeded_mapping() -> None: # Pending -> scheduled (set execution-id label and PodScheduled condition) pend = Pod("p", "Pending") pend.metadata.labels = {"execution-id": "e1"} + class Cond: - def __init__(self, t, s): self.type=t; self.status=s + def __init__(self, t: str, s: str) -> None: + self.type = t + self.status = s + pend.status.conditions = [Cond("PodScheduled", "True")] pend.spec.node_name = "n" evts = pem.map_pod_event(pend, "ADDED") @@ -50,6 +61,7 @@ def __init__(self, t, s): self.type=t; self.status=s print(f"Events returned: {[e.event_type.value for e in evts]}") assert any(e.event_type.value == "pod_running" for e in evts) pr = [e for e in evts if e.event_type.value == "pod_running"][0] + assert isinstance(pr, PodRunningEvent) statuses = json.loads(pr.container_statuses) assert any("waiting" in s["state"] for s in statuses) and any("terminated" in s["state"] for s in statuses) @@ -59,6 +71,7 @@ def __init__(self, t, s): self.type=t; self.status=s suc.metadata.labels = {"execution-id": "e1"} evts = pem.map_pod_event(suc, "MODIFIED") comp = [e for e in evts if e.event_type.value == "execution_completed"][0] + assert isinstance(comp, ExecutionCompletedEvent) assert comp.exit_code == 0 and comp.stdout == "ok" @@ -70,6 +83,7 @@ def test_failed_timeout_and_deleted() -> None: pod_to = Pod("p", "Failed", cs=[ContainerStatus(State(terminated=Terminated(137)))], reason="DeadlineExceeded", adl=5) pod_to.metadata.labels = {"execution-id": "e1"} ev = pem.map_pod_event(pod_to, "MODIFIED")[0] + assert isinstance(ev, ExecutionTimeoutEvent) assert ev.event_type.value == "execution_timeout" and ev.timeout_seconds == 5 # Failed: terminated exit_code nonzero, message used as stderr, error type defaults to SCRIPT_ERROR @@ -78,6 +92,7 @@ def test_failed_timeout_and_deleted() -> None: pod_fail = Pod("p2", "Failed", cs=[ContainerStatus(State(terminated=Terminated(2, message="boom")))]) pod_fail.metadata.labels = {"execution-id": "e2"} evf = pem_no_logs.map_pod_event(pod_fail, "MODIFIED")[0] + assert isinstance(evf, ExecutionFailedEvent) assert evf.event_type.value == "execution_failed" and evf.error_type in {ExecutionErrorType.SCRIPT_ERROR} # Deleted -> terminated when container terminated present (exit code 0 returns completed for DELETED) @@ -119,7 +134,9 @@ def test_extract_id_and_metadata_priority_and_duplicates() -> None: def test_scheduled_requires_condition() -> None: class Cond: - def __init__(self, t, s): self.type=t; self.status=s + def __init__(self, t: str, s: str) -> None: + self.type = t + self.status = s pem = PodEventMapper(k8s_api=FakeApi(""), logger=_test_logger) pod = Pod("p", "Pending") @@ -134,12 +151,13 @@ def __init__(self, t, s): self.type=t; self.status=s assert pem._map_scheduled(_ctx(pod)) is not None -def test_parse_and_log_paths_and_analyze_failure_variants(caplog) -> None: +def test_parse_and_log_paths_and_analyze_failure_variants(caplog: pytest.LogCaptureFixture) -> None: # _parse_executor_output line-by-line line_json = '{"stdout":"x","stderr":"","exit_code":3,"resource_usage":{}}' pem = PodEventMapper(k8s_api=FakeApi("junk\n" + line_json), logger=_test_logger) pod = Pod("p", "Succeeded", cs=[ContainerStatus(State(terminated=Terminated(0)))]) logs = pem._extract_logs(pod) + assert logs is not None assert logs.exit_code == 3 and logs.stdout == "x" # _extract_logs: no api -> returns None @@ -148,11 +166,16 @@ def test_parse_and_log_paths_and_analyze_failure_variants(caplog) -> None: # _extract_logs exceptions -> 404/400/generic branches, all return None class _API404(FakeApi): - def read_namespaced_pod_log(self, *a, **k): raise Exception("404 Not Found") + def read_namespaced_pod_log(self, name: str, namespace: str, tail_lines: int = 10000) -> str: # noqa: ARG002 + raise Exception("404 Not Found") + class _API400(FakeApi): - def read_namespaced_pod_log(self, *a, **k): raise Exception("400 Bad Request") + def read_namespaced_pod_log(self, name: str, namespace: str, tail_lines: int = 10000) -> str: # noqa: ARG002 + raise Exception("400 Bad Request") + class _APIGen(FakeApi): - def read_namespaced_pod_log(self, *a, **k): raise Exception("boom") + def read_namespaced_pod_log(self, name: str, namespace: str, tail_lines: int = 10000) -> str: # noqa: ARG002 + raise Exception("boom") pem404 = PodEventMapper(k8s_api=_API404(""), logger=_test_logger) assert pem404._extract_logs(pod) is None diff --git a/backend/tests/unit/services/pod_monitor/test_monitor.py b/backend/tests/unit/services/pod_monitor/test_monitor.py index 1e6d5081..1520eb12 100644 --- a/backend/tests/unit/services/pod_monitor/test_monitor.py +++ b/backend/tests/unit/services/pod_monitor/test_monitor.py @@ -1,166 +1,190 @@ import asyncio import logging import types +from typing import Any from unittest.mock import MagicMock import pytest +from app.core import k8s_clients as k8s_clients_module from app.core.k8s_clients import K8sClients +from app.db.repositories.event_repository import EventRepository +from app.domain.events import Event +from app.domain.execution.models import ResourceUsageDomain +from app.events.core import UnifiedProducer +from app.infrastructure.kafka.events.base import BaseEvent +from app.infrastructure.kafka.events.execution import ExecutionCompletedEvent, ExecutionStartedEvent +from app.infrastructure.kafka.events.metadata import AvroEventMetadata +from app.services.kafka_event_service import KafkaEventService from app.services.pod_monitor.config import PodMonitorConfig -from app.services.pod_monitor.monitor import PodMonitor, create_pod_monitor - -from tests.helpers.k8s_fakes import FakeApi, make_pod, make_watch +from app.services.pod_monitor.event_mapper import PodEventMapper +from app.services.pod_monitor.monitor import ( + MonitorState, + PodEvent, + PodMonitor, + ReconciliationResult, + WatchEventType, + create_pod_monitor, +) +from app.settings import Settings +from kubernetes.client.rest import ApiException +from tests.helpers.k8s_fakes import ( + FakeApi, + FakeV1Api, + FakeWatch, + FakeWatchStream, + make_k8s_clients, + make_pod, + make_watch, +) pytestmark = pytest.mark.unit -# Test logger for all tests _test_logger = logging.getLogger("test.pod_monitor") -# ===== Shared stubs for k8s mocking ===== +# ===== Test doubles for KafkaEventService dependencies ===== -class _Cfg: - host = "https://k8s" - ssl_ca_cert = None +class FakeEventRepository(EventRepository): + """In-memory event repository for testing.""" + def __init__(self) -> None: + super().__init__(_test_logger) + self.stored_events: list[Event] = [] -class _K8sConfig: - def load_incluster_config(self): - pass + async def store_event(self, event: Event) -> str: + self.stored_events.append(event) + return event.event_id - def load_kube_config(self, config_file=None): - pass # noqa: ARG002 +class FakeUnifiedProducer(UnifiedProducer): + """Fake producer that captures events without Kafka.""" -class _Conf: - @staticmethod - def get_default_copy(): - return _Cfg() + def __init__(self) -> None: + # Don't call super().__init__ - we don't need real Kafka + self.produced_events: list[tuple[BaseEvent, str | None]] = [] + self.logger = _test_logger + async def produce( + self, event_to_produce: BaseEvent, key: str | None = None, headers: dict[str, str] | None = None + ) -> None: + self.produced_events.append((event_to_produce, key)) -class _ApiClient: - def __init__(self, cfg): - pass # noqa: ARG002 + async def aclose(self) -> None: + pass -class _Core: - def __init__(self, api): - pass # noqa: ARG002 +def create_test_kafka_event_service() -> tuple[KafkaEventService, FakeUnifiedProducer]: + """Create real KafkaEventService with fake dependencies for testing.""" + fake_producer = FakeUnifiedProducer() + fake_repo = FakeEventRepository() + settings = Settings() # Uses defaults/env vars - def get_api_resources(self): - return None + service = KafkaEventService( + event_repository=fake_repo, + kafka_producer=fake_producer, + settings=settings, + logger=_test_logger, + ) + return service, fake_producer -class _Watch: - def __init__(self): - pass +# ===== Helpers to create test instances with pure DI ===== - def stop(self): - pass +class SpyMapper: + """Spy event mapper that tracks clear_cache calls.""" -class _SpyMapper: def __init__(self) -> None: self.cleared = False def clear_cache(self) -> None: self.cleared = True + def map_pod_event(self, pod: Any, event_type: WatchEventType) -> list[Any]: # noqa: ARG002 + return [] -class _StubV1: - def get_api_resources(self): - return None - - -class _StubWatch: - def stop(self): - return None - - -class _FakeKafkaEventService: - """Fake KafkaEventService for testing.""" - def __init__(self): - self.published_events = [] - - async def publish_base_event(self, event, key=None): - self.published_events.append((event, key)) - return event.event_id if hasattr(event, "event_id") else "fake-id" +def make_k8s_clients_di( + events: list[dict[str, Any]] | None = None, + resource_version: str = "rv1", + pods: list[Any] | None = None, +) -> K8sClients: + """Create K8sClients for DI with fakes.""" + v1, watch = make_k8s_clients(events=events, resource_version=resource_version, pods=pods) + return K8sClients( + api_client=MagicMock(), + v1=v1, + apps_v1=MagicMock(), + networking_v1=MagicMock(), + watch=watch, + ) -def _patch_k8s(monkeypatch, k8s_config=None, conf=None, api_client=None, core=None, watch=None): - """Helper to patch k8s modules with defaults or custom stubs.""" - monkeypatch.setattr("app.services.pod_monitor.monitor.k8s_config", k8s_config or _K8sConfig()) - monkeypatch.setattr("app.services.pod_monitor.monitor.k8s_client.Configuration", conf or _Conf) - monkeypatch.setattr("app.services.pod_monitor.monitor.k8s_client.ApiClient", api_client or _ApiClient) - monkeypatch.setattr("app.services.pod_monitor.monitor.k8s_client.CoreV1Api", core or _Core) - monkeypatch.setattr("app.services.pod_monitor.monitor.watch", types.SimpleNamespace(Watch=watch or _Watch)) +def make_pod_monitor( + config: PodMonitorConfig | None = None, + kafka_service: KafkaEventService | None = None, + k8s_clients: K8sClients | None = None, + event_mapper: PodEventMapper | None = None, +) -> PodMonitor: + """Create PodMonitor with sensible test defaults.""" + cfg = config or PodMonitorConfig() + clients = k8s_clients or make_k8s_clients_di() + mapper = event_mapper or PodEventMapper(logger=_test_logger, k8s_api=FakeApi("{}")) + service = kafka_service or create_test_kafka_event_service()[0] + return PodMonitor( + config=cfg, + kafka_event_service=service, + logger=_test_logger, + k8s_clients=clients, + event_mapper=mapper, + ) # ===== Tests ===== @pytest.mark.asyncio -async def test_start_and_stop_lifecycle(monkeypatch) -> None: +async def test_start_and_stop_lifecycle() -> None: cfg = PodMonitorConfig() cfg.enable_state_reconciliation = False - pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger) - pm._initialize_kubernetes_client = lambda: None - spy = _SpyMapper() - pm._event_mapper = spy - pm._v1 = _StubV1() - pm._watch = _StubWatch() + spy = SpyMapper() + pm = make_pod_monitor(config=cfg, event_mapper=spy) # type: ignore[arg-type] - async def _quick_watch(): + # Replace _watch_pods to avoid real watch loop + async def _quick_watch() -> None: return None - pm._watch_pods = _quick_watch + pm._watch_pods = _quick_watch # type: ignore[method-assign] await pm.__aenter__() - assert pm.state.name == "RUNNING" + assert pm.state == MonitorState.RUNNING await pm.aclose() - assert pm.state.name == "STOPPED" and spy.cleared is True - - -def test_initialize_kubernetes_client_paths(monkeypatch) -> None: - cfg = PodMonitorConfig() - _patch_k8s(monkeypatch) - - pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger) - pm._initialize_kubernetes_client() - assert pm._v1 is not None and pm._watch is not None + assert pm.state.value == MonitorState.STOPPED.value + assert spy.cleared is True @pytest.mark.asyncio -async def test_watch_pod_events_flow_and_publish(monkeypatch) -> None: +async def test_watch_pod_events_flow_and_publish() -> None: cfg = PodMonitorConfig() cfg.enable_state_reconciliation = False - pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger) - - from app.services.pod_monitor.event_mapper import PodEventMapper as PEM - - pm._event_mapper = PEM(k8s_api=FakeApi("{}"), logger=_test_logger) - - class V1: - def list_namespaced_pod(self, **kwargs): # noqa: ARG002 - return None - - pm._v1 = V1() pod = make_pod(name="p", phase="Succeeded", labels={"execution-id": "e1"}, term_exit=0, resource_version="rv1") - pm._watch = make_watch([{"type": "MODIFIED", "object": pod}], resource_version="rv2") + k8s_clients = make_k8s_clients_di(events=[{"type": "MODIFIED", "object": pod}], resource_version="rv2") + + pm = make_pod_monitor(config=cfg, k8s_clients=k8s_clients) + pm._state = MonitorState.RUNNING - pm._state = pm.state.__class__.RUNNING await pm._watch_pod_events() assert pm._last_resource_version == "rv2" @pytest.mark.asyncio -async def test_process_raw_event_invalid_and_handle_watch_error(monkeypatch) -> None: +async def test_process_raw_event_invalid_and_handle_watch_error() -> None: cfg = PodMonitorConfig() - pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger) + pm = make_pod_monitor(config=cfg) await pm._process_raw_event({}) @@ -178,7 +202,7 @@ async def test_get_status() -> None: cfg.label_selector = "app=test" cfg.enable_state_reconciliation = True - pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger) + pm = make_pod_monitor(config=cfg) pm._tracked_pods = {"pod1", "pod2"} pm._reconnect_attempts = 3 pm._last_resource_version = "v123" @@ -194,36 +218,32 @@ async def test_get_status() -> None: @pytest.mark.asyncio -async def test_reconciliation_loop_and_state(monkeypatch) -> None: +async def test_reconciliation_loop_and_state() -> None: cfg = PodMonitorConfig() cfg.enable_state_reconciliation = True - cfg.reconcile_interval_seconds = 0.01 + cfg.reconcile_interval_seconds = 0 # sleep(0) yields control immediately - pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger) - pm._state = pm.state.__class__.RUNNING + pm = make_pod_monitor(config=cfg) + pm._state = MonitorState.RUNNING - reconcile_called = [] + reconcile_called: list[bool] = [] - async def mock_reconcile(): + async def mock_reconcile() -> ReconciliationResult: reconcile_called.append(True) - from app.services.pod_monitor.monitor import ReconciliationResult - return ReconciliationResult(missing_pods={"p1"}, extra_pods={"p2"}, duration_seconds=0.1, success=True) - pm._reconcile_state = mock_reconcile - evt = asyncio.Event() - async def wrapped_reconcile(): + async def wrapped_reconcile() -> ReconciliationResult: res = await mock_reconcile() evt.set() return res - pm._reconcile_state = wrapped_reconcile + pm._reconcile_state = wrapped_reconcile # type: ignore[method-assign] task = asyncio.create_task(pm._reconciliation_loop()) await asyncio.wait_for(evt.wait(), timeout=1.0) - pm._state = pm.state.__class__.STOPPED + pm._state = MonitorState.STOPPED task.cancel() with pytest.raises(asyncio.CancelledError): await task @@ -232,30 +252,24 @@ async def wrapped_reconcile(): @pytest.mark.asyncio -async def test_reconcile_state_success(monkeypatch) -> None: +async def test_reconcile_state_success() -> None: cfg = PodMonitorConfig() cfg.namespace = "test" cfg.label_selector = "app=test" - pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger) - - def sync_list(namespace, label_selector): # noqa: ARG002 - return types.SimpleNamespace( - items=[ - make_pod(name="pod1", phase="Running", resource_version="v1"), - make_pod(name="pod2", phase="Running", resource_version="v1"), - ] - ) + pod1 = make_pod(name="pod1", phase="Running", resource_version="v1") + pod2 = make_pod(name="pod2", phase="Running", resource_version="v1") + k8s_clients = make_k8s_clients_di(pods=[pod1, pod2]) - pm._v1 = types.SimpleNamespace(list_namespaced_pod=sync_list) + pm = make_pod_monitor(config=cfg, k8s_clients=k8s_clients) pm._tracked_pods = {"pod2", "pod3"} - processed = [] + processed: list[str] = [] - async def mock_process(event): + async def mock_process(event: PodEvent) -> None: processed.append(event.pod.metadata.name) - pm._process_pod_event = mock_process + pm._process_pod_event = mock_process # type: ignore[method-assign] result = await pm._reconcile_state() @@ -266,44 +280,38 @@ async def mock_process(event): assert "pod3" not in pm._tracked_pods -@pytest.mark.asyncio -async def test_reconcile_state_no_v1_api() -> None: - cfg = PodMonitorConfig() - pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger) - pm._v1 = None - - result = await pm._reconcile_state() - assert result.success is False - assert result.error == "K8s API not initialized" - - @pytest.mark.asyncio async def test_reconcile_state_exception() -> None: cfg = PodMonitorConfig() - pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger) - class FailV1: - def list_namespaced_pod(self, *a, **k): + class FailV1(FakeV1Api): + def list_namespaced_pod(self, namespace: str, label_selector: str) -> Any: raise RuntimeError("API error") - pm._v1 = FailV1() + fail_v1 = FailV1() + k8s_clients = K8sClients( + api_client=MagicMock(), + v1=fail_v1, + apps_v1=MagicMock(), + networking_v1=MagicMock(), + watch=make_watch([]), + ) + + pm = make_pod_monitor(config=cfg, k8s_clients=k8s_clients) result = await pm._reconcile_state() assert result.success is False + assert result.error is not None assert "API error" in result.error @pytest.mark.asyncio async def test_process_pod_event_full_flow() -> None: - from app.services.pod_monitor.monitor import PodEvent, WatchEventType - cfg = PodMonitorConfig() cfg.ignored_pod_phases = ["Unknown"] - pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger) - class MockMapper: - def map_pod_event(self, pod, event_type): + def map_pod_event(self, pod: Any, event_type: WatchEventType) -> list[Any]: # noqa: ARG002 class Event: event_type = types.SimpleNamespace(value="test_event") metadata = types.SimpleNamespace(correlation_id=None) @@ -311,14 +319,17 @@ class Event: return [Event()] - pm._event_mapper = MockMapper() + def clear_cache(self) -> None: + pass - published = [] + pm = make_pod_monitor(config=cfg, event_mapper=MockMapper()) # type: ignore[arg-type] - async def mock_publish(event, pod): + published: list[Any] = [] + + async def mock_publish(event: Any, pod: Any) -> None: # noqa: ARG001 published.append(event) - pm._publish_event = mock_publish + pm._publish_event = mock_publish # type: ignore[method-assign] event = PodEvent( event_type=WatchEventType.ADDED, @@ -354,70 +365,82 @@ async def mock_publish(event, pod): @pytest.mark.asyncio async def test_process_pod_event_exception_handling() -> None: - from app.services.pod_monitor.monitor import PodEvent, WatchEventType - cfg = PodMonitorConfig() - pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger) class FailMapper: - def map_pod_event(self, pod, event_type): + def map_pod_event(self, pod: Any, event_type: WatchEventType) -> list[Any]: raise RuntimeError("Mapping failed") - pm._event_mapper = FailMapper() + def clear_cache(self) -> None: + pass + + pm = make_pod_monitor(config=cfg, event_mapper=FailMapper()) # type: ignore[arg-type] event = PodEvent( - event_type=WatchEventType.ADDED, pod=make_pod(name="fail-pod", phase="Pending"), resource_version=None + event_type=WatchEventType.ADDED, + pod=make_pod(name="fail-pod", phase="Pending"), + resource_version=None, ) + # Should not raise - errors are caught and logged await pm._process_pod_event(event) @pytest.mark.asyncio async def test_publish_event_full_flow() -> None: - from app.domain.enums.events import EventType - cfg = PodMonitorConfig() - fake_service = _FakeKafkaEventService() - pm = PodMonitor(cfg, kafka_event_service=fake_service, logger=_test_logger) + service, fake_producer = create_test_kafka_event_service() + pm = make_pod_monitor(config=cfg, kafka_service=service) - class Event: - event_type = EventType.EXECUTION_COMPLETED - metadata = types.SimpleNamespace(correlation_id=None) - aggregate_id = "exec1" - execution_id = "exec1" - event_id = "evt-123" + event = ExecutionCompletedEvent( + execution_id="exec1", + aggregate_id="exec1", + exit_code=0, + resource_usage=ResourceUsageDomain(), + metadata=AvroEventMetadata(service_name="test", service_version="1.0"), + ) pod = make_pod(name="test-pod", phase="Succeeded", labels={"execution-id": "exec1"}) - await pm._publish_event(Event(), pod) + await pm._publish_event(event, pod) - assert len(fake_service.published_events) == 1 - assert fake_service.published_events[0][1] == "exec1" + assert len(fake_producer.produced_events) == 1 + assert fake_producer.produced_events[0][1] == "exec1" @pytest.mark.asyncio async def test_publish_event_exception_handling() -> None: - from app.domain.enums.events import EventType - cfg = PodMonitorConfig() - class FailingKafkaEventService: - async def publish_base_event(self, event, key=None): + class FailingProducer(FakeUnifiedProducer): + async def produce( + self, event_to_produce: BaseEvent, key: str | None = None, headers: dict[str, str] | None = None + ) -> None: raise RuntimeError("Publish failed") - pm = PodMonitor(cfg, kafka_event_service=FailingKafkaEventService(), logger=_test_logger) + # Create service with failing producer + failing_producer = FailingProducer() + fake_repo = FakeEventRepository() + failing_service = KafkaEventService( + event_repository=fake_repo, + kafka_producer=failing_producer, + settings=Settings(), + logger=_test_logger, + ) + + pm = make_pod_monitor(config=cfg, kafka_service=failing_service) - class Event: - event_type = EventType.EXECUTION_STARTED - metadata = types.SimpleNamespace(correlation_id=None) - aggregate_id = None - execution_id = None + event = ExecutionStartedEvent( + execution_id="exec1", + pod_name="test-pod", + metadata=AvroEventMetadata(service_name="test", service_version="1.0"), + ) - class Pod: - metadata = None - status = None + # Use pod with no metadata to exercise edge case + pod = make_pod(name="no-meta-pod", phase="Pending") + pod.metadata = None # type: ignore[assignment] # Should not raise - errors are caught and logged - await pm._publish_event(Event(), Pod()) + await pm._publish_event(event, pod) @pytest.mark.asyncio @@ -425,57 +448,55 @@ async def test_handle_watch_error_max_attempts() -> None: cfg = PodMonitorConfig() cfg.max_reconnect_attempts = 2 - pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger) - pm._state = pm.state.__class__.RUNNING + pm = make_pod_monitor(config=cfg) + pm._state = MonitorState.RUNNING pm._reconnect_attempts = 2 await pm._handle_watch_error() - assert pm._state == pm.state.__class__.STOPPING + assert pm._state == MonitorState.STOPPING @pytest.mark.asyncio -async def test_watch_pods_main_loop(monkeypatch) -> None: +async def test_watch_pods_main_loop() -> None: cfg = PodMonitorConfig() - pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger) - pm._state = pm.state.__class__.RUNNING + pm = make_pod_monitor(config=cfg) + pm._state = MonitorState.RUNNING - watch_count = [] + watch_count: list[int] = [] - async def mock_watch(): + async def mock_watch() -> None: watch_count.append(1) if len(watch_count) > 2: - pm._state = pm.state.__class__.STOPPED + pm._state = MonitorState.STOPPED - async def mock_handle_error(): + async def mock_handle_error() -> None: pass - pm._watch_pod_events = mock_watch - pm._handle_watch_error = mock_handle_error + pm._watch_pod_events = mock_watch # type: ignore[method-assign] + pm._handle_watch_error = mock_handle_error # type: ignore[method-assign] await pm._watch_pods() assert len(watch_count) > 2 @pytest.mark.asyncio -async def test_watch_pods_api_exception(monkeypatch) -> None: - from kubernetes.client.rest import ApiException - +async def test_watch_pods_api_exception() -> None: cfg = PodMonitorConfig() - pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger) - pm._state = pm.state.__class__.RUNNING + pm = make_pod_monitor(config=cfg) + pm._state = MonitorState.RUNNING - async def mock_watch(): + async def mock_watch() -> None: raise ApiException(status=410) - error_handled = [] + error_handled: list[bool] = [] - async def mock_handle(): + async def mock_handle() -> None: error_handled.append(True) - pm._state = pm.state.__class__.STOPPED + pm._state = MonitorState.STOPPED - pm._watch_pod_events = mock_watch - pm._handle_watch_error = mock_handle + pm._watch_pod_events = mock_watch # type: ignore[method-assign] + pm._handle_watch_error = mock_handle # type: ignore[method-assign] await pm._watch_pods() @@ -484,77 +505,100 @@ async def mock_handle(): @pytest.mark.asyncio -async def test_watch_pods_generic_exception(monkeypatch) -> None: +async def test_watch_pods_generic_exception() -> None: cfg = PodMonitorConfig() - pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger) - pm._state = pm.state.__class__.RUNNING + pm = make_pod_monitor(config=cfg) + pm._state = MonitorState.RUNNING - async def mock_watch(): + async def mock_watch() -> None: raise RuntimeError("Unexpected error") - error_handled = [] + error_handled: list[bool] = [] - async def mock_handle(): + async def mock_handle() -> None: error_handled.append(True) - pm._state = pm.state.__class__.STOPPED + pm._state = MonitorState.STOPPED - pm._watch_pod_events = mock_watch - pm._handle_watch_error = mock_handle + pm._watch_pod_events = mock_watch # type: ignore[method-assign] + pm._handle_watch_error = mock_handle # type: ignore[method-assign] await pm._watch_pods() assert len(error_handled) > 0 @pytest.mark.asyncio -async def test_create_pod_monitor_context_manager(monkeypatch) -> None: - _patch_k8s(monkeypatch) +async def test_create_pod_monitor_context_manager(monkeypatch: pytest.MonkeyPatch) -> None: + """Test create_pod_monitor factory with auto-created dependencies.""" + # Mock create_k8s_clients to avoid real K8s connection + mock_v1 = FakeV1Api() + mock_watch = make_watch([]) + mock_clients = K8sClients( + api_client=MagicMock(), + v1=mock_v1, + apps_v1=MagicMock(), + networking_v1=MagicMock(), + watch=mock_watch, + ) + + def mock_create_clients( + logger: logging.Logger, # noqa: ARG001 + kubeconfig_path: str | None = None, # noqa: ARG001 + in_cluster: bool | None = None, # noqa: ARG001 + ) -> K8sClients: + return mock_clients + + monkeypatch.setattr(k8s_clients_module, "create_k8s_clients", mock_create_clients) + monkeypatch.setattr("app.services.pod_monitor.monitor.create_k8s_clients", mock_create_clients) cfg = PodMonitorConfig() cfg.enable_state_reconciliation = False - fake_service = _FakeKafkaEventService() + service, _ = create_test_kafka_event_service() - async with create_pod_monitor(cfg, fake_service, _test_logger) as monitor: - assert monitor.state == monitor.state.__class__.RUNNING + # Use the actual create_pod_monitor which will use our mocked create_k8s_clients + async with create_pod_monitor(cfg, service, _test_logger) as monitor: + assert monitor.state == MonitorState.RUNNING - assert monitor.state == monitor.state.__class__.STOPPED + assert monitor.state.value == MonitorState.STOPPED.value @pytest.mark.asyncio -async def test_create_pod_monitor_with_injected_k8s_clients(monkeypatch) -> None: +async def test_create_pod_monitor_with_injected_k8s_clients() -> None: """Test create_pod_monitor with injected K8sClients (DI path).""" - _patch_k8s(monkeypatch) - cfg = PodMonitorConfig() cfg.enable_state_reconciliation = False - fake_service = _FakeKafkaEventService() + service, _ = create_test_kafka_event_service() - mock_v1 = MagicMock() - mock_v1.get_api_resources.return_value = None + mock_v1 = FakeV1Api() + mock_watch = make_watch([]) mock_k8s_clients = K8sClients( api_client=MagicMock(), v1=mock_v1, apps_v1=MagicMock(), networking_v1=MagicMock(), + watch=mock_watch, ) - async with create_pod_monitor(cfg, fake_service, _test_logger, k8s_clients=mock_k8s_clients) as monitor: - assert monitor.state == monitor.state.__class__.RUNNING + async with create_pod_monitor( + cfg, service, _test_logger, k8s_clients=mock_k8s_clients + ) as monitor: + assert monitor.state == MonitorState.RUNNING assert monitor._clients is mock_k8s_clients assert monitor._v1 is mock_v1 - assert monitor.state == monitor.state.__class__.STOPPED + assert monitor.state.value == MonitorState.STOPPED.value @pytest.mark.asyncio async def test_start_already_running() -> None: """Test idempotent start via __aenter__.""" cfg = PodMonitorConfig() - pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger) + pm = make_pod_monitor(config=cfg) + # Simulate already started state pm._lifecycle_started = True - pm._state = pm.state.__class__.RUNNING + pm._state = MonitorState.RUNNING # Should be idempotent - just return self await pm.__aenter__() @@ -564,8 +608,8 @@ async def test_start_already_running() -> None: async def test_stop_already_stopped() -> None: """Test idempotent stop via aclose().""" cfg = PodMonitorConfig() - pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger) - pm._state = pm.state.__class__.STOPPED + pm = make_pod_monitor(config=cfg) + pm._state = MonitorState.STOPPED # Not started, so aclose should be a no-op await pm.aclose() @@ -575,27 +619,26 @@ async def test_stop_already_stopped() -> None: async def test_stop_with_tasks() -> None: """Test cleanup of tasks on aclose().""" cfg = PodMonitorConfig() - pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger) - pm._state = pm.state.__class__.RUNNING - pm._lifecycle_started = True # Simulate started state + pm = make_pod_monitor(config=cfg) + pm._state = MonitorState.RUNNING + pm._lifecycle_started = True - async def dummy_task(): + async def dummy_task() -> None: await asyncio.Event().wait() pm._watch_task = asyncio.create_task(dummy_task()) pm._reconcile_task = asyncio.create_task(dummy_task()) - pm._watch = _StubWatch() pm._tracked_pods = {"pod1"} await pm.aclose() - assert pm._state == pm.state.__class__.STOPPED + assert pm._state == MonitorState.STOPPED assert len(pm._tracked_pods) == 0 def test_update_resource_version() -> None: cfg = PodMonitorConfig() - pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger) + pm = make_pod_monitor(config=cfg) class Stream: _stop_event = types.SimpleNamespace(resource_version="v123") @@ -612,14 +655,14 @@ class BadStream: @pytest.mark.asyncio async def test_process_raw_event_with_metadata() -> None: cfg = PodMonitorConfig() - pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger) + pm = make_pod_monitor(config=cfg) - processed = [] + processed: list[PodEvent] = [] - async def mock_process(event): + async def mock_process(event: PodEvent) -> None: processed.append(event) - pm._process_pod_event = mock_process + pm._process_pod_event = mock_process # type: ignore[method-assign] raw_event = { "type": "ADDED", @@ -637,140 +680,56 @@ async def mock_process(event): assert processed[1].resource_version is None -def test_initialize_kubernetes_client_in_cluster(monkeypatch) -> None: - cfg = PodMonitorConfig() - cfg.in_cluster = True - - load_incluster_called = [] - - class TrackingK8sConfig: - def load_incluster_config(self): - load_incluster_called.append(True) - - def load_kube_config(self, config_file=None): - pass # noqa: ARG002 - - _patch_k8s(monkeypatch, k8s_config=TrackingK8sConfig()) - - pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger) - pm._initialize_kubernetes_client() - - assert len(load_incluster_called) == 1 - - -def test_initialize_kubernetes_client_with_kubeconfig_path(monkeypatch) -> None: - cfg = PodMonitorConfig() - cfg.in_cluster = False - cfg.kubeconfig_path = "/custom/kubeconfig" - - load_kube_called_with = [] - - class TrackingK8sConfig: - def load_incluster_config(self): - pass - - def load_kube_config(self, config_file=None): - load_kube_called_with.append(config_file) - - class ConfWithCert: - @staticmethod - def get_default_copy(): - return types.SimpleNamespace(host="https://k8s", ssl_ca_cert="cert") - - _patch_k8s(monkeypatch, k8s_config=TrackingK8sConfig(), conf=ConfWithCert) - - pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger) - pm._initialize_kubernetes_client() - - assert load_kube_called_with == ["/custom/kubeconfig"] - - -def test_initialize_kubernetes_client_exception(monkeypatch) -> None: - cfg = PodMonitorConfig() - - class FailingK8sConfig: - def load_kube_config(self, config_file=None): - raise Exception("K8s config error") - - monkeypatch.setattr("app.services.pod_monitor.monitor.k8s_config", FailingK8sConfig()) - - pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger) - - with pytest.raises(Exception) as exc_info: - pm._initialize_kubernetes_client() - - assert "K8s config error" in str(exc_info.value) - - @pytest.mark.asyncio -async def test_watch_pods_api_exception_other_status(monkeypatch) -> None: - from kubernetes.client.rest import ApiException - +async def test_watch_pods_api_exception_other_status() -> None: cfg = PodMonitorConfig() - pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger) - pm._state = pm.state.__class__.RUNNING + pm = make_pod_monitor(config=cfg) + pm._state = MonitorState.RUNNING - async def mock_watch(): + async def mock_watch() -> None: raise ApiException(status=500) - error_handled = [] + error_handled: list[bool] = [] - async def mock_handle(): + async def mock_handle() -> None: error_handled.append(True) - pm._state = pm.state.__class__.STOPPED + pm._state = MonitorState.STOPPED - pm._watch_pod_events = mock_watch - pm._handle_watch_error = mock_handle + pm._watch_pod_events = mock_watch # type: ignore[method-assign] + pm._handle_watch_error = mock_handle # type: ignore[method-assign] await pm._watch_pods() assert len(error_handled) > 0 -@pytest.mark.asyncio -async def test_watch_pod_events_no_watch_or_v1() -> None: - cfg = PodMonitorConfig() - pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger) - - pm._watch = None - pm._v1 = _StubV1() - - with pytest.raises(RuntimeError) as exc_info: - await pm._watch_pod_events() - - assert "Watch or API not initialized" in str(exc_info.value) - - pm._watch = _StubWatch() - pm._v1 = None - - with pytest.raises(RuntimeError) as exc_info: - await pm._watch_pod_events() - - assert "Watch or API not initialized" in str(exc_info.value) - - @pytest.mark.asyncio async def test_watch_pod_events_with_field_selector() -> None: cfg = PodMonitorConfig() cfg.field_selector = "status.phase=Running" cfg.enable_state_reconciliation = False - pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger) + watch_kwargs: list[dict[str, Any]] = [] - watch_kwargs = [] - - class V1: - def list_namespaced_pod(self, **kwargs): - watch_kwargs.append(kwargs) + class TrackingV1(FakeV1Api): + def list_namespaced_pod(self, namespace: str, label_selector: str) -> Any: + watch_kwargs.append({"namespace": namespace, "label_selector": label_selector}) return None - class Watch: - def stream(self, func, **kwargs): + class TrackingWatch(FakeWatch): + def stream(self, func: Any, **kwargs: Any) -> FakeWatchStream: watch_kwargs.append(kwargs) - return [] + return FakeWatchStream([], "rv1") + + k8s_clients = K8sClients( + api_client=MagicMock(), + v1=TrackingV1(), + apps_v1=MagicMock(), + networking_v1=MagicMock(), + watch=TrackingWatch([], "rv1"), + ) - pm._v1 = V1() - pm._watch = Watch() - pm._state = pm.state.__class__.RUNNING + pm = make_pod_monitor(config=cfg, k8s_clients=k8s_clients) + pm._state = MonitorState.RUNNING await pm._watch_pod_events() @@ -781,22 +740,22 @@ def stream(self, func, **kwargs): async def test_reconciliation_loop_exception() -> None: cfg = PodMonitorConfig() cfg.enable_state_reconciliation = True - cfg.reconcile_interval_seconds = 0.01 + cfg.reconcile_interval_seconds = 0 # sleep(0) yields control immediately - pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger) - pm._state = pm.state.__class__.RUNNING + pm = make_pod_monitor(config=cfg) + pm._state = MonitorState.RUNNING hit = asyncio.Event() - async def raising(): + async def raising() -> ReconciliationResult: hit.set() raise RuntimeError("Reconcile error") - pm._reconcile_state = raising + pm._reconcile_state = raising # type: ignore[method-assign] task = asyncio.create_task(pm._reconciliation_loop()) await asyncio.wait_for(hit.wait(), timeout=1.0) - pm._state = pm.state.__class__.STOPPED + pm._state = MonitorState.STOPPED task.cancel() with pytest.raises(asyncio.CancelledError): await task @@ -807,19 +766,16 @@ async def test_start_with_reconciliation() -> None: cfg = PodMonitorConfig() cfg.enable_state_reconciliation = True - pm = PodMonitor(cfg, kafka_event_service=_FakeKafkaEventService(), logger=_test_logger) - pm._initialize_kubernetes_client = lambda: None - pm._v1 = _StubV1() - pm._watch = _StubWatch() + pm = make_pod_monitor(config=cfg) - async def mock_watch(): + async def mock_watch() -> None: return None - async def mock_reconcile(): + async def mock_reconcile() -> None: return None - pm._watch_pods = mock_watch - pm._reconciliation_loop = mock_reconcile + pm._watch_pods = mock_watch # type: ignore[method-assign] + pm._reconciliation_loop = mock_reconcile # type: ignore[method-assign] await pm.__aenter__() assert pm._watch_task is not None diff --git a/backend/tests/unit/services/result_processor/test_processor.py b/backend/tests/unit/services/result_processor/test_processor.py index 26ef9fdd..79410f7e 100644 --- a/backend/tests/unit/services/result_processor/test_processor.py +++ b/backend/tests/unit/services/result_processor/test_processor.py @@ -12,7 +12,7 @@ class TestResultProcessorConfig: - def test_default_values(self): + def test_default_values(self) -> None: config = ResultProcessorConfig() assert config.consumer_group == GroupId.RESULT_PROCESSOR assert KafkaTopic.EXECUTION_COMPLETED in config.topics @@ -22,13 +22,13 @@ def test_default_values(self): assert config.batch_size == 10 assert config.processing_timeout == 300 - def test_custom_values(self): + def test_custom_values(self) -> None: config = ResultProcessorConfig(batch_size=20, processing_timeout=600) assert config.batch_size == 20 assert config.processing_timeout == 600 -def test_create_dispatcher_registers_handlers(): +def test_create_dispatcher_registers_handlers() -> None: rp = ResultProcessor( execution_repo=MagicMock(), producer=MagicMock(), diff --git a/backend/tests/unit/services/saga/test_execution_saga_steps.py b/backend/tests/unit/services/saga/test_execution_saga_steps.py index ee57f431..982bbc6c 100644 --- a/backend/tests/unit/services/saga/test_execution_saga_steps.py +++ b/backend/tests/unit/services/saga/test_execution_saga_steps.py @@ -1,23 +1,27 @@ import pytest -from app.domain.saga import DomainResourceAllocation +from app.db.repositories.resource_allocation_repository import ResourceAllocationRepository +from app.domain.saga import DomainResourceAllocation, DomainResourceAllocationCreate +from app.events.core import UnifiedProducer +from app.infrastructure.kafka.events import BaseEvent +from app.infrastructure.kafka.events.execution import ExecutionRequestedEvent from app.services.saga.execution_saga import ( - ValidateExecutionStep, AllocateResourcesStep, - QueueExecutionStep, CreatePodStep, + DeletePodCompensation, + ExecutionSaga, MonitorExecutionStep, + QueueExecutionStep, ReleaseResourcesCompensation, - DeletePodCompensation, + ValidateExecutionStep, ) from app.services.saga.saga_step import SagaContext from tests.helpers import make_execution_requested_event - pytestmark = pytest.mark.unit -def _req(timeout: int = 30, script: str = "print('x')"): +def _req(timeout: int = 30, script: str = "print('x')") -> ExecutionRequestedEvent: return make_execution_requested_event(execution_id="e1", script=script, timeout_seconds=timeout) @@ -39,16 +43,18 @@ async def test_validate_execution_step_success_and_failures() -> None: assert ok3 is False and ctx3.error is not None -class _FakeAllocRepo: +class _FakeAllocRepo(ResourceAllocationRepository): + """Fake ResourceAllocationRepository for testing.""" + def __init__(self, active: int = 0, alloc_id: str = "alloc-1") -> None: self.active = active self.alloc_id = alloc_id self.released: list[str] = [] - async def count_active(self, language: str) -> int: # noqa: ARG002 + async def count_active(self, language: str) -> int: return self.active - async def create_allocation(self, create_data) -> DomainResourceAllocation: # noqa: ARG002 + async def create_allocation(self, create_data: DomainResourceAllocationCreate) -> DomainResourceAllocation: return DomainResourceAllocation( allocation_id=self.alloc_id, execution_id=create_data.execution_id, @@ -59,8 +65,9 @@ async def create_allocation(self, create_data) -> DomainResourceAllocation: # n memory_limit=create_data.memory_limit, ) - async def release_allocation(self, allocation_id: str) -> None: + async def release_allocation(self, allocation_id: str) -> bool: self.released.append(allocation_id) + return True @pytest.mark.asyncio @@ -94,19 +101,23 @@ async def test_queue_and_monitor_steps() -> None: assert ctx.get("monitoring_active") is True # Force exceptions to exercise except paths - class _Ctx(SagaContext): - def set(self, key, value): # type: ignore[override] + class _BadCtx(SagaContext): + def set(self, key: str, value: object) -> None: raise RuntimeError("boom") - bad = _Ctx("s", "e") + + bad = _BadCtx("s", "e") assert await QueueExecutionStep().execute(bad, _req()) is False assert await MonitorExecutionStep().execute(bad, _req()) is False -class _FakeProducer: +class _FakeProducer(UnifiedProducer): + """Fake UnifiedProducer for testing.""" + def __init__(self) -> None: - self.events: list[object] = [] + self.events: list[BaseEvent] = [] - async def produce(self, event_to_produce, key: str | None = None): # noqa: ARG002 + async def produce(self, event_to_produce: BaseEvent, key: str | None = None, + headers: dict[str, str] | None = None) -> None: self.events.append(event_to_produce) @@ -180,16 +191,14 @@ async def test_delete_pod_compensation_variants() -> None: def test_execution_saga_bind_and_get_steps_sets_flags_and_types() -> None: # Dummy subclasses to satisfy isinstance checks without real deps - from app.events.core import UnifiedProducer - from app.db.repositories.resource_allocation_repository import ResourceAllocationRepository - class DummyProd(UnifiedProducer): - def __init__(self): pass # type: ignore[no-untyped-def] + def __init__(self) -> None: + pass # Skip parent __init__ class DummyAlloc(ResourceAllocationRepository): - def __init__(self): pass # type: ignore[no-untyped-def] + def __init__(self) -> None: + pass # Skip parent __init__ - from app.services.saga.execution_saga import ExecutionSaga, CreatePodStep s = ExecutionSaga() s.bind_dependencies(producer=DummyProd(), alloc_repo=DummyAlloc(), publish_commands=True) steps = s.get_steps() diff --git a/backend/tests/unit/services/saga/test_saga_comprehensive.py b/backend/tests/unit/services/saga/test_saga_comprehensive.py index e746164b..027ec634 100644 --- a/backend/tests/unit/services/saga/test_saga_comprehensive.py +++ b/backend/tests/unit/services/saga/test_saga_comprehensive.py @@ -10,9 +10,11 @@ from app.domain.enums.events import EventType from app.domain.enums.saga import SagaState from app.domain.saga.models import Saga -from tests.helpers import make_execution_requested_event +from app.infrastructure.kafka.events.base import BaseEvent +from app.infrastructure.kafka.events.execution import ExecutionRequestedEvent from app.services.saga.execution_saga import ExecutionSaga from app.services.saga.saga_step import CompensationStep, SagaContext, SagaStep +from tests.helpers import make_execution_requested_event pytestmark = pytest.mark.unit @@ -23,19 +25,19 @@ async def compensate(self, context: SagaContext) -> bool: # noqa: ARG002 return True -class _Step(SagaStep): - def __init__(self, name: str, ok: bool = True): +class _Step(SagaStep[BaseEvent]): + def __init__(self, name: str, ok: bool = True) -> None: super().__init__(name) self._ok = ok - async def execute(self, context: SagaContext, event) -> bool: # noqa: ARG002 + async def execute(self, context: SagaContext, event: BaseEvent) -> bool: # noqa: ARG002 return self._ok - def get_compensation(self): + def get_compensation(self) -> CompensationStep: return _NoopComp(f"{self.name}-comp") -def _req_event(): +def _req_event() -> ExecutionRequestedEvent: return make_execution_requested_event(execution_id="e1", script="print('x')") diff --git a/backend/tests/unit/services/saga/test_saga_orchestrator_unit.py b/backend/tests/unit/services/saga/test_saga_orchestrator_unit.py index 75fb2e25..0621caf8 100644 --- a/backend/tests/unit/services/saga/test_saga_orchestrator_unit.py +++ b/backend/tests/unit/services/saga/test_saga_orchestrator_unit.py @@ -1,31 +1,57 @@ import logging +from datetime import datetime, timezone +from typing import ClassVar +from unittest.mock import MagicMock import pytest +from pydantic import Field + +from app.db.repositories.resource_allocation_repository import ResourceAllocationRepository +from app.db.repositories.saga_repository import SagaRepository from app.domain.enums.events import EventType +from app.domain.enums.kafka import KafkaTopic from app.domain.enums.saga import SagaState from app.domain.saga.models import Saga, SagaConfig +from app.events.core import UnifiedProducer +from app.events.event_store import EventStore +from app.infrastructure.kafka.events import BaseEvent +from app.infrastructure.kafka.events.metadata import AvroEventMetadata +from app.services.idempotency.idempotency_manager import IdempotencyManager from app.services.saga.base_saga import BaseSaga from app.services.saga.saga_orchestrator import SagaOrchestrator -from app.services.saga.saga_step import SagaStep +from app.services.saga.saga_step import CompensationStep, SagaContext, SagaStep +from app.events.schema.schema_registry import SchemaRegistryManager +from app.settings import Settings pytestmark = pytest.mark.unit _test_logger = logging.getLogger("test.services.saga.orchestrator") -class _Evt: - def __init__(self, et: EventType, execution_id: str): - self.event_type = et - self.execution_id = execution_id - self.event_id = "evid" +class _FakeEvent(BaseEvent): + """Fake event for testing that extends BaseEvent. + + Note: event_type has no default to avoid polluting the global event type mapping + (which is built from BaseEvent subclasses with default event_type values). + """ + + event_type: EventType # No default - set explicitly in _make_event() + execution_id: str = "" + topic: ClassVar[KafkaTopic] = KafkaTopic.EXECUTION_EVENTS + metadata: AvroEventMetadata = Field(default_factory=lambda: AvroEventMetadata( + service_name="test", service_version="1.0", correlation_id="test-corr-id" + )) + timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + +class _FakeRepo(SagaRepository): + """Fake SagaRepository for testing.""" -class _Repo: def __init__(self) -> None: self.saved: list[Saga] = [] self.existing: dict[tuple[str, str], Saga] = {} - async def get_saga_by_execution_and_name(self, execution_id: str, saga_name: str): # noqa: ARG002 + async def get_saga_by_execution_and_name(self, execution_id: str, saga_name: str) -> Saga | None: return self.existing.get((execution_id, saga_name)) async def upsert_saga(self, saga: Saga) -> bool: @@ -33,71 +59,111 @@ async def upsert_saga(self, saga: Saga) -> bool: return True -class _Prod: - async def produce(self, event_to_produce, key=None): # noqa: ARG002 +class _FakeProd(UnifiedProducer): + """Fake UnifiedProducer for testing.""" + + def __init__(self) -> None: + pass # Skip parent __init__ + + async def produce(self, event_to_produce: BaseEvent, key: str | None = None, headers: dict[str, str] | None = None) -> None: return None -class _Idem: - async def close(self): +class _FakeIdem(IdempotencyManager): + """Fake IdempotencyManager for testing.""" + + def __init__(self) -> None: + pass # Skip parent __init__ + + async def close(self) -> None: return None -class _Store: ... -class _Alloc: ... -class _SchemaRegistry: ... -class _Settings: ... +class _FakeStore(EventStore): + """Fake EventStore for testing.""" + def __init__(self) -> None: + pass # Skip parent __init__ + + +class _FakeAlloc(ResourceAllocationRepository): + """Fake ResourceAllocationRepository for testing.""" -class _StepOK(SagaStep[_Evt]): + def __init__(self) -> None: + pass # No special attributes needed + + +class _StepOK(SagaStep[_FakeEvent]): def __init__(self) -> None: super().__init__("ok") - async def execute(self, context, event) -> bool: # noqa: ARG002 + + async def execute(self, context: SagaContext, event: _FakeEvent) -> bool: return True + def get_compensation(self) -> CompensationStep | None: + return None + class _Saga(BaseSaga): @classmethod def get_name(cls) -> str: return "s" + @classmethod - def get_trigger_events(cls): + def get_trigger_events(cls) -> list[EventType]: return [EventType.EXECUTION_REQUESTED] - def get_steps(self): + + def get_steps(self) -> list[SagaStep[_FakeEvent]]: return [_StepOK()] def _orch() -> SagaOrchestrator: return SagaOrchestrator( config=SagaConfig(name="t", enable_compensation=True, store_events=True, publish_commands=False), - saga_repository=_Repo(), - producer=_Prod(), - schema_registry_manager=_SchemaRegistry(), # type: ignore[arg-type] - settings=_Settings(), # type: ignore[arg-type] - event_store=_Store(), - idempotency_manager=_Idem(), - resource_allocation_repository=_Alloc(), + saga_repository=_FakeRepo(), + producer=_FakeProd(), + schema_registry_manager=MagicMock(spec=SchemaRegistryManager), + settings=MagicMock(spec=Settings), + event_store=_FakeStore(), + idempotency_manager=_FakeIdem(), + resource_allocation_repository=_FakeAlloc(), logger=_test_logger, ) +def _make_event(et: EventType, execution_id: str) -> _FakeEvent: + return _FakeEvent(event_type=et, execution_id=execution_id) + + @pytest.mark.asyncio async def test_min_success_flow() -> None: orch = _orch() - orch.register_saga(_Saga) # type: ignore[arg-type] - orch._running = True - await orch._handle_event(_Evt(EventType.EXECUTION_REQUESTED, "e")) - assert orch._running is True # basic sanity; deep behavior covered by integration + orch.register_saga(_Saga) + # Set orchestrator running state via lifecycle property + orch._lifecycle_started = True + await orch._handle_event(_make_event(EventType.EXECUTION_REQUESTED, "e")) + # basic sanity; deep behavior covered by integration + assert orch.is_running is True @pytest.mark.asyncio async def test_should_trigger_and_existing_short_circuit() -> None: - orch = _orch() - orch.register_saga(_Saga) # type: ignore[arg-type] - assert orch._should_trigger_saga(_Saga, _Evt(EventType.EXECUTION_REQUESTED, "e")) is True + fake_repo = _FakeRepo() + orch = SagaOrchestrator( + config=SagaConfig(name="t", enable_compensation=True, store_events=True, publish_commands=False), + saga_repository=fake_repo, + producer=_FakeProd(), + schema_registry_manager=MagicMock(spec=SchemaRegistryManager), + settings=MagicMock(spec=Settings), + event_store=_FakeStore(), + idempotency_manager=_FakeIdem(), + resource_allocation_repository=_FakeAlloc(), + logger=_test_logger, + ) + orch.register_saga(_Saga) + assert orch._should_trigger_saga(_Saga, _make_event(EventType.EXECUTION_REQUESTED, "e")) is True # Existing short-circuit returns existing ID - repo = orch._repo # type: ignore[attr-defined] s = Saga(saga_id="sX", saga_name="s", execution_id="e", state=SagaState.RUNNING) - repo.existing[("e", "s")] = s - sid = await orch._start_saga("s", _Evt(EventType.EXECUTION_REQUESTED, "e")) + fake_repo.existing[("e", "s")] = s + sid = await orch._start_saga("s", _make_event(EventType.EXECUTION_REQUESTED, "e")) assert sid == "sX" diff --git a/backend/tests/unit/services/saga/test_saga_step_and_base.py b/backend/tests/unit/services/saga/test_saga_step_and_base.py index a8ab93bd..693832d5 100644 --- a/backend/tests/unit/services/saga/test_saga_step_and_base.py +++ b/backend/tests/unit/services/saga/test_saga_step_and_base.py @@ -1,8 +1,12 @@ -import pytest +import asyncio +from unittest.mock import MagicMock -from app.services.saga.saga_step import SagaContext, CompensationStep +import pytest +from app.domain.enums.events import EventType +from app.infrastructure.kafka.events import BaseEvent +from app.infrastructure.kafka.events.metadata import AvroEventMetadata from app.services.saga.base_saga import BaseSaga - +from app.services.saga.saga_step import CompensationStep, SagaContext, SagaStep pytestmark = pytest.mark.unit @@ -13,9 +17,11 @@ def test_saga_context_public_dict_filters_and_encodes() -> None: ctx.set("b", {"x": 2}) ctx.set("c", [1, 2, 3]) ctx.set("_private", {"won't": "leak"}) + # Complex non-JSON object -> should be dropped class X: pass + ctx.set("complex", X()) # Nested complex objects get encoded by jsonable_encoder # The nested dict with a complex object gets partially encoded @@ -37,10 +43,6 @@ async def compensate(self, context: SagaContext) -> bool: # noqa: ARG002 @pytest.mark.asyncio async def test_context_adders() -> None: - from app.infrastructure.kafka.events.metadata import AvroEventMetadata - from app.infrastructure.kafka.events.base import BaseEvent - from app.domain.enums.events import EventType - class E(BaseEvent): event_type: EventType = EventType.SYSTEM_ERROR topic = None # type: ignore[assignment] @@ -60,23 +62,33 @@ def test_base_saga_abstract_calls_cover_pass_lines() -> None: assert BaseSaga.get_trigger_events() is None # Instance-less call to abstract instance method to hit 'pass' assert BaseSaga.get_steps(None) is None # type: ignore[arg-type] + # And the default bind hook returns None when called + class Dummy(BaseSaga): @classmethod - def get_name(cls): return "d" + def get_name(cls) -> str: + return "d" + @classmethod - def get_trigger_events(cls): return [] - def get_steps(self): return [] - assert Dummy().bind_dependencies() is None + def get_trigger_events(cls) -> list[EventType]: + return [] + + def get_steps(self) -> list[SagaStep[BaseEvent]]: + return [] + + Dummy().bind_dependencies() def test_saga_step_str_and_can_execute() -> None: - from app.services.saga.saga_step import SagaStep - class S(SagaStep): - async def execute(self, context, event): return True - def get_compensation(self): return None + class S(SagaStep[BaseEvent]): + async def execute(self, context: SagaContext, event: BaseEvent) -> bool: + return True + + def get_compensation(self) -> CompensationStep | None: + return None + s = S("nm") assert str(s) == "SagaStep(nm)" # can_execute default True - import asyncio - assert asyncio.run(s.can_execute(SagaContext("s","e"), object())) is True + assert asyncio.run(s.can_execute(SagaContext("s", "e"), MagicMock(spec=BaseEvent))) is True diff --git a/backend/tests/unit/services/sse/test_kafka_redis_bridge.py b/backend/tests/unit/services/sse/test_kafka_redis_bridge.py index e4b0cded..48f1b936 100644 --- a/backend/tests/unit/services/sse/test_kafka_redis_bridge.py +++ b/backend/tests/unit/services/sse/test_kafka_redis_bridge.py @@ -1,74 +1,76 @@ -import asyncio import logging -import pytest - -pytestmark = pytest.mark.unit +from typing import ClassVar +from unittest.mock import MagicMock +import pytest +from app.core.metrics.events import EventMetrics from app.domain.enums.events import EventType +from app.domain.enums.kafka import KafkaTopic +from app.events.core import EventDispatcher +from app.events.schema.schema_registry import SchemaRegistryManager +from app.infrastructure.kafka.events import BaseEvent +from app.infrastructure.kafka.events.metadata import AvroEventMetadata from app.services.sse.kafka_redis_bridge import SSEKafkaRedisBridge +from app.services.sse.redis_bus import SSERedisBus +from app.settings import Settings -_test_logger = logging.getLogger("test.services.sse.kafka_redis_bridge") - - -class _FakeSchema: ... - - -class _FakeSettings: - KAFKA_BOOTSTRAP_SERVERS = "kafka:9092" - SSE_CONSUMER_POOL_SIZE = 1 +pytestmark = pytest.mark.unit +_test_logger = logging.getLogger("test.services.sse.kafka_redis_bridge") -class _FakeEventMetrics: ... +class _FakeBus(SSERedisBus): + """Fake SSERedisBus for testing.""" -class _FakeBus: def __init__(self) -> None: - self.published: list[tuple[str, object]] = [] + self.published: list[tuple[str, BaseEvent]] = [] - async def publish_event(self, execution_id: str, event: object) -> None: + async def publish_event(self, execution_id: str, event: BaseEvent) -> None: self.published.append((execution_id, event)) -class _StubDispatcher: - def __init__(self) -> None: - self.handlers: dict[EventType, object] = {} - - def register_handler(self, et: EventType, fn: object) -> None: - self.handlers[et] = fn +def _make_metadata() -> AvroEventMetadata: + return AvroEventMetadata(service_name="test", service_version="1.0") -class _DummyEvent: - def __init__(self, execution_id: str | None, et: EventType) -> None: - self.event_type = et - self.execution_id = execution_id +class _DummyEvent(BaseEvent): + """Dummy event for testing.""" + execution_id: str | None = None + topic: ClassVar[KafkaTopic] = KafkaTopic.EXECUTION_EVENTS - def model_dump(self) -> dict: + def model_dump(self, **kwargs: object) -> dict[str, str | None]: return {"execution_id": self.execution_id} @pytest.mark.asyncio async def test_register_and_route_events_without_kafka() -> None: # Build the bridge but don't call start(); directly test routing handlers + fake_bus = _FakeBus() + mock_settings = MagicMock(spec=Settings) + mock_settings.KAFKA_BOOTSTRAP_SERVERS = "kafka:9092" + mock_settings.SSE_CONSUMER_POOL_SIZE = 1 + bridge = SSEKafkaRedisBridge( - schema_registry=_FakeSchema(), - settings=_FakeSettings(), - event_metrics=_FakeEventMetrics(), - sse_bus=_FakeBus(), + schema_registry=MagicMock(spec=SchemaRegistryManager), + settings=mock_settings, + event_metrics=MagicMock(spec=EventMetrics), + sse_bus=fake_bus, logger=_test_logger, ) - disp = _StubDispatcher() + disp = EventDispatcher(_test_logger) bridge._register_routing_handlers(disp) - assert EventType.EXECUTION_STARTED in disp.handlers + handlers = disp.get_handlers(EventType.EXECUTION_STARTED) + assert len(handlers) > 0 # Event without execution_id is ignored - h = disp.handlers[EventType.EXECUTION_STARTED] - await h(_DummyEvent(None, EventType.EXECUTION_STARTED)) - assert bridge.sse_bus.published == [] + h = handlers[0] + await h(_DummyEvent(event_type=EventType.EXECUTION_STARTED, metadata=_make_metadata(), execution_id=None)) + assert fake_bus.published == [] # Proper event is published - await h(_DummyEvent("exec-123", EventType.EXECUTION_STARTED)) - assert bridge.sse_bus.published and bridge.sse_bus.published[-1][0] == "exec-123" + await h(_DummyEvent(event_type=EventType.EXECUTION_STARTED, metadata=_make_metadata(), execution_id="exec-123")) + assert fake_bus.published and fake_bus.published[-1][0] == "exec-123" s = bridge.get_stats() assert s["num_consumers"] == 0 and s["is_running"] is False diff --git a/backend/tests/unit/services/sse/test_shutdown_manager.py b/backend/tests/unit/services/sse/test_shutdown_manager.py index 6db2190e..7025f15b 100644 --- a/backend/tests/unit/services/sse/test_shutdown_manager.py +++ b/backend/tests/unit/services/sse/test_shutdown_manager.py @@ -3,21 +3,26 @@ import pytest +from app.core.lifecycle import LifecycleEnabled from app.services.sse.sse_shutdown_manager import SSEShutdownManager _test_logger = logging.getLogger("test.services.sse.shutdown_manager") -class DummyRouter: +class _FakeRouter(LifecycleEnabled): + """Fake router that tracks whether aclose was called.""" + def __init__(self) -> None: + super().__init__() self.stopped = False + self._lifecycle_started = True # Simulate already-started router - async def aclose(self) -> None: + async def _on_stop(self) -> None: self.stopped = True @pytest.mark.asyncio -async def test_shutdown_graceful_notify_and_drain(): +async def test_shutdown_graceful_notify_and_drain() -> None: mgr = SSEShutdownManager(drain_timeout=1.0, notification_timeout=0.01, force_close_timeout=0.1, logger=_test_logger) # Register two connections and arrange that they unregister when notified @@ -25,7 +30,7 @@ async def test_shutdown_graceful_notify_and_drain(): ev2 = await mgr.register_connection("e1", "c2") assert ev1 is not None and ev2 is not None - async def on_shutdown(event, cid): # noqa: ANN001 + async def on_shutdown(event: asyncio.Event, cid: str) -> None: await asyncio.wait_for(event.wait(), timeout=0.5) await mgr.unregister_connection("e1", cid) @@ -41,9 +46,9 @@ async def on_shutdown(event, cid): # noqa: ANN001 @pytest.mark.asyncio -async def test_shutdown_force_close_calls_router_stop_and_rejects_new(): +async def test_shutdown_force_close_calls_router_stop_and_rejects_new() -> None: mgr = SSEShutdownManager(drain_timeout=0.01, notification_timeout=0.01, force_close_timeout=0.01, logger=_test_logger) - router = DummyRouter() + router = _FakeRouter() mgr.set_router(router) # Register a connection but never unregister -> force close path @@ -63,7 +68,7 @@ async def test_shutdown_force_close_calls_router_stop_and_rejects_new(): @pytest.mark.asyncio -async def test_get_shutdown_status_transitions(): +async def test_get_shutdown_status_transitions() -> None: m = SSEShutdownManager(drain_timeout=0.01, notification_timeout=0.0, force_close_timeout=0.0, logger=_test_logger) st0 = m.get_shutdown_status() assert st0.phase == "ready" diff --git a/backend/tests/unit/services/sse/test_sse_service.py b/backend/tests/unit/services/sse/test_sse_service.py index 63299b4e..b6526da4 100644 --- a/backend/tests/unit/services/sse/test_sse_service.py +++ b/backend/tests/unit/services/sse/test_sse_service.py @@ -1,32 +1,37 @@ import asyncio +import json import logging from datetime import datetime, timezone -from typing import Any, Type +from typing import Any +from unittest.mock import MagicMock import pytest +from app.db.repositories.sse_repository import SSERepository +from app.domain.enums.events import EventType +from app.domain.enums.execution import ExecutionStatus +from app.domain.execution import DomainExecution, ResourceUsageDomain +from app.domain.sse import ShutdownStatus, SSEExecutionStatusDomain, SSEHealthDomain +from app.services.sse.kafka_redis_bridge import SSEKafkaRedisBridge +from app.services.sse.redis_bus import SSERedisBus, SSERedisSubscription +from app.services.sse.sse_service import SSEService +from app.services.sse.sse_shutdown_manager import SSEShutdownManager +from app.settings import Settings from pydantic import BaseModel pytestmark = pytest.mark.unit _test_logger = logging.getLogger("test.services.sse.sse_service") -from app.domain.enums.events import EventType -from app.domain.execution import DomainExecution, ResourceUsageDomain -from app.domain.sse import ShutdownStatus, SSEHealthDomain -from app.schemas_pydantic.sse import RedisNotificationMessage, RedisSSEMessage -from app.services.sse.sse_service import SSEService - -T = Any # TypeVar for fake - -class _FakeSubscription: +class _FakeSubscription(SSERedisSubscription): def __init__(self) -> None: + # Skip parent __init__ - no real Redis pubsub self._q: asyncio.Queue[dict[str, Any] | None] = asyncio.Queue() self.closed = False - async def get(self, model: Type[BaseModel], timeout: float = 0.5) -> T | None: + async def get[T: BaseModel](self, model: type[T]) -> T | None: try: - raw = await asyncio.wait_for(self._q.get(), timeout=timeout) + raw = await asyncio.wait_for(self._q.get(), timeout=0.5) if raw is None: return None return model.model_validate(raw) @@ -35,54 +40,55 @@ async def get(self, model: Type[BaseModel], timeout: float = 0.5) -> T | None: except Exception: return None - async def push(self, msg: dict[str, Any]) -> None: + async def push(self, msg: dict[str, Any] | None) -> None: self._q.put_nowait(msg) async def close(self) -> None: self.closed = True -class _FakeBus: +class _FakeBus(SSERedisBus): def __init__(self) -> None: + # Skip parent __init__ self.exec_sub = _FakeSubscription() self.notif_sub = _FakeSubscription() - async def open_subscription(self, execution_id: str) -> _FakeSubscription: # noqa: ARG002 + async def open_subscription(self, execution_id: str) -> SSERedisSubscription: # noqa: ARG002 return self.exec_sub - async def open_notification_subscription(self, user_id: str) -> _FakeSubscription: # noqa: ARG002 + async def open_notification_subscription(self, user_id: str) -> SSERedisSubscription: # noqa: ARG002 return self.notif_sub -class _FakeRepo: - class _Status: - def __init__(self, execution_id: str) -> None: - self.execution_id = execution_id - self.status = "running" - self.timestamp = datetime.now(timezone.utc).isoformat() - +class _FakeRepo(SSERepository): def __init__(self) -> None: + # Skip parent __init__ self.exec_for_result: DomainExecution | None = None - async def get_execution_status(self, execution_id: str) -> "_FakeRepo._Status": - return _FakeRepo._Status(execution_id) + async def get_execution_status(self, execution_id: str) -> SSEExecutionStatusDomain | None: + return SSEExecutionStatusDomain( + execution_id=execution_id, + status=ExecutionStatus.RUNNING, + timestamp=datetime.now(timezone.utc).isoformat(), + ) async def get_execution(self, execution_id: str) -> DomainExecution | None: # noqa: ARG002 return self.exec_for_result -class _FakeShutdown: +class _FakeShutdown(SSEShutdownManager): def __init__(self) -> None: + # Skip parent __init__ self._evt = asyncio.Event() self._initiated = False self.registered: list[tuple[str, str]] = [] self.unregistered: list[tuple[str, str]] = [] - async def register_connection(self, execution_id: str, connection_id: str): + async def register_connection(self, execution_id: str, connection_id: str) -> asyncio.Event: self.registered.append((execution_id, connection_id)) return self._evt - async def unregister_connection(self, execution_id: str, connection_id: str): + async def unregister_connection(self, execution_id: str, connection_id: str) -> None: self.unregistered.append((execution_id, connection_id)) def is_shutting_down(self) -> bool: @@ -102,19 +108,24 @@ def initiate(self) -> None: self._evt.set() -class _FakeSettings: - SSE_HEARTBEAT_INTERVAL = 0 # not used for execution; helpful for notification test - +class _FakeRouter(SSEKafkaRedisBridge): + def __init__(self) -> None: + # Skip parent __init__ + pass -class _FakeRouter: def get_stats(self) -> dict[str, int | bool]: return {"num_consumers": 3, "active_executions": 2, "is_running": True, "total_buffers": 0} -def _decode(evt: dict[str, Any]) -> dict[str, Any]: - import json +def _make_fake_settings() -> Settings: + mock = MagicMock(spec=Settings) + mock.SSE_HEARTBEAT_INTERVAL = 0 + return mock - return json.loads(evt["data"]) # type: ignore[index] + +def _decode(evt: dict[str, Any]) -> dict[str, Any]: + result: dict[str, Any] = json.loads(evt["data"]) + return result @pytest.mark.asyncio @@ -122,12 +133,17 @@ async def test_execution_stream_closes_on_failed_event() -> None: repo = _FakeRepo() bus = _FakeBus() sm = _FakeShutdown() - svc = SSEService(repository=repo, router=_FakeRouter(), sse_bus=bus, shutdown_manager=sm, settings=_FakeSettings(), logger=_test_logger) + svc = SSEService(repository=repo, router=_FakeRouter(), sse_bus=bus, shutdown_manager=sm, + settings=_make_fake_settings(), logger=_test_logger) agen = svc.create_execution_stream("exec-1", user_id="u1") first = await agen.__anext__() assert _decode(first)["event_type"] == "connected" + # Should emit subscribed after Redis subscription is ready + subscribed = await agen.__anext__() + assert _decode(subscribed)["event_type"] == "subscribed" + # Should emit initial status stat = await agen.__anext__() assert _decode(stat)["event_type"] == "status" @@ -148,7 +164,7 @@ async def test_execution_stream_result_stored_includes_result_payload() -> None: repo.exec_for_result = DomainExecution( execution_id="exec-2", script="", - status="completed", # type: ignore[arg-type] + status=ExecutionStatus.COMPLETED, stdout="out", stderr="", lang="python", @@ -159,10 +175,12 @@ async def test_execution_stream_result_stored_includes_result_payload() -> None: ) bus = _FakeBus() sm = _FakeShutdown() - svc = SSEService(repository=repo, router=_FakeRouter(), sse_bus=bus, shutdown_manager=sm, settings=_FakeSettings(), logger=_test_logger) + svc = SSEService(repository=repo, router=_FakeRouter(), sse_bus=bus, shutdown_manager=sm, + settings=_make_fake_settings(), logger=_test_logger) agen = svc.create_execution_stream("exec-2", user_id="u1") await agen.__anext__() # connected + await agen.__anext__() # subscribed await agen.__anext__() # status await bus.exec_sub.push({"event_type": EventType.RESULT_STORED, "execution_id": "exec-2", "data": {}}) @@ -180,14 +198,19 @@ async def test_notification_stream_connected_and_heartbeat_and_message() -> None repo = _FakeRepo() bus = _FakeBus() sm = _FakeShutdown() - settings = _FakeSettings() + settings = _make_fake_settings() settings.SSE_HEARTBEAT_INTERVAL = 0 # emit immediately - svc = SSEService(repository=repo, router=_FakeRouter(), sse_bus=bus, shutdown_manager=sm, settings=settings, logger=_test_logger) + svc = SSEService(repository=repo, router=_FakeRouter(), sse_bus=bus, shutdown_manager=sm, settings=settings, + logger=_test_logger) agen = svc.create_notification_stream("u1") connected = await agen.__anext__() assert _decode(connected)["event_type"] == "connected" + # Should emit subscribed after Redis subscription is ready + subscribed = await agen.__anext__() + assert _decode(subscribed)["event_type"] == "subscribed" + # With 0 interval, next yield should be heartbeat hb = await agen.__anext__() assert _decode(hb)["event_type"] == "heartbeat" @@ -209,7 +232,7 @@ async def test_notification_stream_connected_and_heartbeat_and_message() -> None # Stop the stream by initiating shutdown and advancing once more (loop checks flag) sm.initiate() # It may loop until it sees the flag; push a None to release get(timeout) - await bus.notif_sub.push(None) # type: ignore[arg-type] + await bus.notif_sub.push(None) # Give the generator a chance to observe the flag and finish with pytest.raises(StopAsyncIteration): await asyncio.wait_for(agen.__anext__(), timeout=0.2) @@ -217,7 +240,8 @@ async def test_notification_stream_connected_and_heartbeat_and_message() -> None @pytest.mark.asyncio async def test_health_status_shape() -> None: - svc = SSEService(repository=_FakeRepo(), router=_FakeRouter(), sse_bus=_FakeBus(), shutdown_manager=_FakeShutdown(), settings=_FakeSettings(), logger=_test_logger) + svc = SSEService(repository=_FakeRepo(), router=_FakeRouter(), sse_bus=_FakeBus(), shutdown_manager=_FakeShutdown(), + settings=_make_fake_settings(), logger=_test_logger) h = await svc.get_health_status() assert isinstance(h, SSEHealthDomain) assert h.active_consumers == 3 and h.active_executions == 2 diff --git a/backend/tests/unit/services/sse/test_sse_shutdown_manager.py b/backend/tests/unit/services/sse/test_sse_shutdown_manager.py index 4e7300b3..f97350c2 100644 --- a/backend/tests/unit/services/sse/test_sse_shutdown_manager.py +++ b/backend/tests/unit/services/sse/test_sse_shutdown_manager.py @@ -2,19 +2,24 @@ import logging import pytest +from app.core.lifecycle import LifecycleEnabled +from app.services.sse.sse_shutdown_manager import SSEShutdownManager +from tests.helpers.eventually import eventually pytestmark = pytest.mark.unit -from app.services.sse.sse_shutdown_manager import SSEShutdownManager - _test_logger = logging.getLogger("test.services.sse.sse_shutdown_manager") -class _FakeRouter: +class _FakeRouter(LifecycleEnabled): + """Fake router that tracks whether aclose was called.""" + def __init__(self) -> None: + super().__init__() self.stopped = False + self._lifecycle_started = True # Simulate already-started router - async def stop(self) -> None: + async def _on_stop(self) -> None: self.stopped = True @@ -32,9 +37,7 @@ async def test_register_unregister_and_shutdown_flow() -> None: task = asyncio.create_task(mgr.initiate_shutdown()) # Wait until manager enters NOTIFYING phase (event-driven) - from tests.helpers.eventually import eventually - - async def _is_notifying(): + async def _is_notifying() -> bool: return mgr.get_shutdown_status().phase == "notifying" await eventually(_is_notifying, timeout=1.0, interval=0.02) @@ -51,16 +54,16 @@ async def _is_notifying(): @pytest.mark.asyncio async def test_reject_new_connection_during_shutdown() -> None: - mgr = SSEShutdownManager(drain_timeout=0.1, notification_timeout=0.01, force_close_timeout=0.01, logger=_test_logger) + mgr = SSEShutdownManager(drain_timeout=0.1, notification_timeout=0.01, force_close_timeout=0.01, + logger=_test_logger) # Pre-register one active connection to reflect realistic state e = await mgr.register_connection("e", "c0") assert e is not None # Start shutdown and wait until initiated t = asyncio.create_task(mgr.initiate_shutdown()) - from tests.helpers.eventually import eventually - async def _initiated(): + async def _initiated() -> None: assert mgr.is_shutting_down() is True await eventually(_initiated, timeout=1.0, interval=0.02) diff --git a/backend/uv.lock b/backend/uv.lock index 8bf078fe..ed11e615 100644 --- a/backend/uv.lock +++ b/backend/uv.lock @@ -1275,10 +1275,10 @@ dev = [ { name = "mypy-extensions", specifier = "==1.1.0" }, { name = "pipdeptree", specifier = "==2.23.4" }, { name = "pluggy", specifier = "==1.5.0" }, - { name = "pytest", specifier = "==8.3.3" }, + { name = "pytest", specifier = "==8.4.2" }, { name = "pytest-asyncio", specifier = "==1.3.0" }, { name = "pytest-cov", specifier = "==5.0.0" }, - { name = "pytest-env", specifier = ">=1.1.5" }, + { name = "pytest-env", specifier = "==1.2.0" }, { name = "pytest-xdist", specifier = "==3.6.1" }, { name = "ruff", specifier = "==0.14.10" }, { name = "types-cachetools", specifier = "==6.2.0.20250827" }, @@ -2436,17 +2436,18 @@ wheels = [ [[package]] name = "pytest" -version = "8.3.3" +version = "8.4.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "colorama", marker = "sys_platform == 'win32'" }, { name = "iniconfig" }, { name = "packaging" }, { name = "pluggy" }, + { name = "pygments" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/8b/6c/62bbd536103af674e227c41a8f3dcd022d591f6eed5facb5a0f31ee33bbc/pytest-8.3.3.tar.gz", hash = "sha256:70b98107bd648308a7952b06e6ca9a50bc660be218d53c257cc1fc94fda10181", size = 1442487, upload-time = "2024-09-10T10:52:15.003Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a3/5c/00a0e072241553e1a7496d638deababa67c5058571567b92a7eaa258397c/pytest-8.4.2.tar.gz", hash = "sha256:86c0d0b93306b961d58d62a4db4879f27fe25513d4b969df351abdddb3c30e01", size = 1519618, upload-time = "2025-09-04T14:34:22.711Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/6b/77/7440a06a8ead44c7757a64362dd22df5760f9b12dc5f11b6188cd2fc27a0/pytest-8.3.3-py3-none-any.whl", hash = "sha256:a6853c7375b2663155079443d2e45de913a911a11d669df02a50814944db57b2", size = 342341, upload-time = "2024-09-10T10:52:12.54Z" }, + { url = "https://files.pythonhosted.org/packages/a8/a4/20da314d277121d6534b3a980b29035dcd51e6744bd79075a6ce8fa4eb8d/pytest-8.4.2-py3-none-any.whl", hash = "sha256:872f880de3fc3a5bdc88a11b39c9710c3497a547cfa9320bc3c5e62fbf272e79", size = 365750, upload-time = "2025-09-04T14:34:20.226Z" }, ] [[package]] @@ -2477,14 +2478,14 @@ wheels = [ [[package]] name = "pytest-env" -version = "1.1.5" +version = "1.2.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "pytest" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/1f/31/27f28431a16b83cab7a636dce59cf397517807d247caa38ee67d65e71ef8/pytest_env-1.1.5.tar.gz", hash = "sha256:91209840aa0e43385073ac464a554ad2947cc2fd663a9debf88d03b01e0cc1cf", size = 8911, upload-time = "2024-09-17T22:39:18.566Z" } +sdist = { url = "https://files.pythonhosted.org/packages/13/12/9c87d0ca45d5992473208bcef2828169fa7d39b8d7fc6e3401f5c08b8bf7/pytest_env-1.2.0.tar.gz", hash = "sha256:475e2ebe8626cee01f491f304a74b12137742397d6c784ea4bc258f069232b80", size = 8973, upload-time = "2025-10-09T19:15:47.42Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/de/b8/87cfb16045c9d4092cfcf526135d73b88101aac83bc1adcf82dfb5fd3833/pytest_env-1.1.5-py3-none-any.whl", hash = "sha256:ce90cf8772878515c24b31cd97c7fa1f4481cd68d588419fd45f10ecaee6bc30", size = 6141, upload-time = "2024-09-17T22:39:16.942Z" }, + { url = "https://files.pythonhosted.org/packages/27/98/822b924a4a3eb58aacba84444c7439fce32680592f394de26af9c76e2569/pytest_env-1.2.0-py3-none-any.whl", hash = "sha256:d7e5b7198f9b83c795377c09feefa45d56083834e60d04767efd64819fc9da00", size = 6251, upload-time = "2025-10-09T19:15:46.077Z" }, ] [[package]] diff --git a/backend/workers/dlq_processor.py b/backend/workers/dlq_processor.py index 711d1ff2..b03d3dc1 100644 --- a/backend/workers/dlq_processor.py +++ b/backend/workers/dlq_processor.py @@ -9,7 +9,7 @@ from app.db.docs import ALL_DOCUMENTS from app.dlq import DLQMessage, RetryPolicy, RetryStrategy from app.dlq.manager import DLQManager -from app.settings import Settings, get_settings +from app.settings import Settings from beanie import init_beanie @@ -100,10 +100,8 @@ async def alert_on_discard(message: DLQMessage, reason: str) -> None: manager.add_callback("on_discard", alert_on_discard) -async def main(settings: Settings | None = None) -> None: +async def main(settings: Settings) -> None: """Run the DLQ processor.""" - if settings is None: - settings = get_settings() container = create_dlq_processor_container(settings) logger = await container.get(logging.Logger) @@ -134,4 +132,4 @@ def signal_handler() -> None: if __name__ == "__main__": - asyncio.run(main()) + asyncio.run(main(Settings())) diff --git a/backend/workers/run_coordinator.py b/backend/workers/run_coordinator.py index ef617444..12004bf1 100644 --- a/backend/workers/run_coordinator.py +++ b/backend/workers/run_coordinator.py @@ -10,14 +10,12 @@ from app.domain.enums.kafka import GroupId from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas from app.services.coordinator.coordinator import ExecutionCoordinator -from app.settings import Settings, get_settings +from app.settings import Settings from beanie import init_beanie -async def run_coordinator(settings: Settings | None = None) -> None: +async def run_coordinator(settings: Settings) -> None: """Run the execution coordinator service.""" - if settings is None: - settings = get_settings() container = create_coordinator_container(settings) logger = await container.get(logging.Logger) @@ -54,7 +52,7 @@ async def run_coordinator(settings: Settings | None = None) -> None: def main() -> None: """Main entry point for coordinator worker""" - settings = get_settings() + settings = Settings() logger = setup_logger(settings.LOG_LEVEL) logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") @@ -64,6 +62,7 @@ def main() -> None: if settings.ENABLE_TRACING: init_tracing( service_name=GroupId.EXECUTION_COORDINATOR, + settings=settings, logger=logger, service_version=settings.TRACING_SERVICE_VERSION, enable_console_exporter=False, diff --git a/backend/workers/run_event_replay.py b/backend/workers/run_event_replay.py index 949cf8af..95c38dad 100644 --- a/backend/workers/run_event_replay.py +++ b/backend/workers/run_event_replay.py @@ -9,7 +9,7 @@ from app.db.docs import ALL_DOCUMENTS from app.events.core import UnifiedProducer from app.services.event_replay.replay_service import EventReplayService -from app.settings import Settings, get_settings +from app.settings import Settings from beanie import init_beanie @@ -24,10 +24,8 @@ async def cleanup_task(replay_service: EventReplayService, logger: logging.Logge logger.error(f"Error during cleanup: {e}") -async def run_replay_service(settings: Settings | None = None) -> None: +async def run_replay_service(settings: Settings) -> None: """Run the event replay service with cleanup task.""" - if settings is None: - settings = get_settings() container = create_event_replay_container(settings) logger = await container.get(logging.Logger) @@ -61,7 +59,7 @@ async def _cancel_task() -> None: def main() -> None: """Main entry point for event replay service""" - settings = get_settings() + settings = Settings() logger = setup_logger(settings.LOG_LEVEL) logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") @@ -71,6 +69,7 @@ def main() -> None: if settings.ENABLE_TRACING: init_tracing( service_name="event-replay", + settings=settings, logger=logger, service_version=settings.TRACING_SERVICE_VERSION, enable_console_exporter=False, diff --git a/backend/workers/run_k8s_worker.py b/backend/workers/run_k8s_worker.py index 49b945fa..d3b857ad 100644 --- a/backend/workers/run_k8s_worker.py +++ b/backend/workers/run_k8s_worker.py @@ -10,14 +10,12 @@ from app.domain.enums.kafka import GroupId from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas from app.services.k8s_worker.worker import KubernetesWorker -from app.settings import Settings, get_settings +from app.settings import Settings from beanie import init_beanie -async def run_kubernetes_worker(settings: Settings | None = None) -> None: +async def run_kubernetes_worker(settings: Settings) -> None: """Run the Kubernetes worker service.""" - if settings is None: - settings = get_settings() container = create_k8s_worker_container(settings) logger = await container.get(logging.Logger) @@ -54,7 +52,7 @@ async def run_kubernetes_worker(settings: Settings | None = None) -> None: def main() -> None: """Main entry point for Kubernetes worker""" - settings = get_settings() + settings = Settings() logger = setup_logger(settings.LOG_LEVEL) logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") @@ -64,6 +62,7 @@ def main() -> None: if settings.ENABLE_TRACING: init_tracing( service_name=GroupId.K8S_WORKER, + settings=settings, logger=logger, service_version=settings.TRACING_SERVICE_VERSION, enable_console_exporter=False, diff --git a/backend/workers/run_pod_monitor.py b/backend/workers/run_pod_monitor.py index 9c1fe09e..4b4dd325 100644 --- a/backend/workers/run_pod_monitor.py +++ b/backend/workers/run_pod_monitor.py @@ -10,16 +10,14 @@ from app.domain.enums.kafka import GroupId from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas from app.services.pod_monitor.monitor import MonitorState, PodMonitor -from app.settings import Settings, get_settings +from app.settings import Settings from beanie import init_beanie RECONCILIATION_LOG_INTERVAL: int = 60 -async def run_pod_monitor(settings: Settings | None = None) -> None: +async def run_pod_monitor(settings: Settings) -> None: """Run the pod monitor service.""" - if settings is None: - settings = get_settings() container = create_pod_monitor_container(settings) logger = await container.get(logging.Logger) @@ -56,7 +54,7 @@ async def run_pod_monitor(settings: Settings | None = None) -> None: def main() -> None: """Main entry point for pod monitor worker""" - settings = get_settings() + settings = Settings() logger = setup_logger(settings.LOG_LEVEL) logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") @@ -66,6 +64,7 @@ def main() -> None: if settings.ENABLE_TRACING: init_tracing( service_name=GroupId.POD_MONITOR, + settings=settings, logger=logger, service_version=settings.TRACING_SERVICE_VERSION, enable_console_exporter=False, diff --git a/backend/workers/run_result_processor.py b/backend/workers/run_result_processor.py index 0151ad9f..11cb7a72 100644 --- a/backend/workers/run_result_processor.py +++ b/backend/workers/run_result_processor.py @@ -13,14 +13,12 @@ from app.events.schema.schema_registry import SchemaRegistryManager from app.services.idempotency import IdempotencyManager from app.services.result_processor.processor import ProcessingState, ResultProcessor -from app.settings import Settings, get_settings +from app.settings import Settings from beanie import init_beanie from pymongo.asynchronous.mongo_client import AsyncMongoClient -async def run_result_processor(settings: Settings | None = None) -> None: - if settings is None: - settings = get_settings() +async def run_result_processor(settings: Settings) -> None: db_client: AsyncMongoClient[dict[str, object]] = AsyncMongoClient( settings.MONGODB_URL, tz_aware=True, serverSelectionTimeoutMS=5000 @@ -70,7 +68,7 @@ async def run_result_processor(settings: Settings | None = None) -> None: def main() -> None: """Main entry point for result processor worker""" - settings = get_settings() + settings = Settings() logger = setup_logger(settings.LOG_LEVEL) logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") @@ -80,6 +78,7 @@ def main() -> None: if settings.ENABLE_TRACING: init_tracing( service_name=GroupId.RESULT_PROCESSOR, + settings=settings, logger=logger, service_version=settings.TRACING_SERVICE_VERSION, enable_console_exporter=False, diff --git a/backend/workers/run_saga_orchestrator.py b/backend/workers/run_saga_orchestrator.py index 04ad8a8d..7fd0c359 100644 --- a/backend/workers/run_saga_orchestrator.py +++ b/backend/workers/run_saga_orchestrator.py @@ -10,14 +10,12 @@ from app.domain.enums.kafka import GroupId from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas from app.services.saga import SagaOrchestrator -from app.settings import Settings, get_settings +from app.settings import Settings from beanie import init_beanie -async def run_saga_orchestrator(settings: Settings | None = None) -> None: +async def run_saga_orchestrator(settings: Settings) -> None: """Run the saga orchestrator.""" - if settings is None: - settings = get_settings() container = create_saga_orchestrator_container(settings) logger = await container.get(logging.Logger) @@ -54,7 +52,7 @@ async def run_saga_orchestrator(settings: Settings | None = None) -> None: def main() -> None: """Main entry point for saga orchestrator worker""" - settings = get_settings() + settings = Settings() logger = setup_logger(settings.LOG_LEVEL) logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") @@ -64,6 +62,7 @@ def main() -> None: if settings.ENABLE_TRACING: init_tracing( service_name=GroupId.SAGA_ORCHESTRATOR, + settings=settings, logger=logger, service_version=settings.TRACING_SERVICE_VERSION, enable_console_exporter=False, diff --git a/frontend/src/components/NotificationCenter.svelte b/frontend/src/components/NotificationCenter.svelte index 21ed1930..d27c6227 100644 --- a/frontend/src/components/NotificationCenter.svelte +++ b/frontend/src/components/NotificationCenter.svelte @@ -91,8 +91,8 @@ try { const data = JSON.parse(event.data); - // Ignore heartbeat and connection messages - if (data.event === 'heartbeat' || data.event === 'connected') { + // Ignore heartbeat, connection, and subscription confirmation messages + if (data.event_type === 'heartbeat' || data.event_type === 'connected' || data.event_type === 'subscribed') { return; } diff --git a/frontend/src/lib/editor/execution.svelte.ts b/frontend/src/lib/editor/execution.svelte.ts index 3736604a..bf7b1804 100644 --- a/frontend/src/lib/editor/execution.svelte.ts +++ b/frontend/src/lib/editor/execution.svelte.ts @@ -65,7 +65,7 @@ export function createExecutionState() { const eventData = JSON.parse(event.data); const eventType = eventData?.event_type || eventData?.type; - if (eventType === 'heartbeat' || eventType === 'connected') return; + if (eventType === 'heartbeat' || eventType === 'connected' || eventType === 'subscribed') return; if (eventData.status) { phase = toExecutionPhase(eventData.status, phase);