|
| 1 | +"""Shared test fixtures and utilities for docling-jobkit tests.""" |
| 2 | + |
| 3 | +import logging |
| 4 | +from typing import Any |
| 5 | + |
| 6 | +import pytest_asyncio |
| 7 | +from aiohttp import web |
| 8 | + |
| 9 | +from docling_jobkit.datamodel.callback import ProgressKind |
| 10 | + |
| 11 | + |
| 12 | +def pytest_configure(config): |
| 13 | + """Configure logging for tests.""" |
| 14 | + logging.getLogger("docling").setLevel(logging.INFO) |
| 15 | + |
| 16 | + |
| 17 | +class CallbackServer: |
| 18 | + """Mock HTTP server to capture callback invocations.""" |
| 19 | + |
| 20 | + def __init__(self): |
| 21 | + self.callbacks: list[dict[str, Any]] = [] |
| 22 | + self.app = web.Application() |
| 23 | + self.app.router.add_post("/callback", self.handle_callback) |
| 24 | + self.runner = None |
| 25 | + self.site = None |
| 26 | + |
| 27 | + async def handle_callback(self, request: web.Request) -> web.Response: |
| 28 | + """Handle incoming callback requests.""" |
| 29 | + data = await request.json() |
| 30 | + self.callbacks.append(data) |
| 31 | + logging.info(f"Received callback: {data.get('progress', {}).get('kind')}") |
| 32 | + return web.Response(status=200) |
| 33 | + |
| 34 | + async def start(self, port: int = 8765): |
| 35 | + """Start the callback server.""" |
| 36 | + self.runner = web.AppRunner(self.app) |
| 37 | + await self.runner.setup() |
| 38 | + self.site = web.TCPSite(self.runner, "localhost", port) |
| 39 | + await self.site.start() |
| 40 | + |
| 41 | + async def stop(self): |
| 42 | + """Stop the callback server.""" |
| 43 | + if self.site: |
| 44 | + await self.site.stop() |
| 45 | + if self.runner: |
| 46 | + await self.runner.cleanup() |
| 47 | + |
| 48 | + def get_callbacks_by_kind(self, kind: ProgressKind) -> list[dict[str, Any]]: |
| 49 | + """Get all callbacks of a specific kind.""" |
| 50 | + return [ |
| 51 | + cb for cb in self.callbacks if cb.get("progress", {}).get("kind") == kind |
| 52 | + ] |
| 53 | + |
| 54 | + |
| 55 | +@pytest_asyncio.fixture |
| 56 | +async def callback_server(): |
| 57 | + """Fixture to provide a mock callback server.""" |
| 58 | + server = CallbackServer() |
| 59 | + await server.start() |
| 60 | + yield server |
| 61 | + await server.stop() |
| 62 | + |
| 63 | + |
| 64 | +@pytest_asyncio.fixture |
| 65 | +async def callback_server_rq(): |
| 66 | + """Fixture to provide a mock callback server for RQ tests (different port).""" |
| 67 | + server = CallbackServer() |
| 68 | + await server.start(port=8766) |
| 69 | + yield server |
| 70 | + await server.stop() |
0 commit comments