Skip to content
Draft
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
b60eab0
merge
carlos-irreverentlabs Jan 16, 2026
644927f
Merge remote-tracking branch 'upstream/main'
carlosgjs Jan 22, 2026
218f7aa
Merge remote-tracking branch 'upstream/main'
carlosgjs Feb 3, 2026
90da389
Merge remote-tracking branch 'upstream/main'
carlosgjs Feb 10, 2026
842e9b3
PSv2: Use connection pooling and retries for NATS
carlosgjs Feb 11, 2026
227a8db
Refactor and fix nats tests
carlosgjs Feb 11, 2026
2acf620
Tighten formatting
carlosgjs Feb 11, 2026
0632ce0
format
carlosgjs Feb 11, 2026
c5f8106
CR feedback
carlosgjs Feb 11, 2026
8805dbe
Apply suggestions from code review
carlosgjs Feb 11, 2026
8618d3c
Merge remote-tracking branch 'upstream/main'
carlosgjs Feb 13, 2026
c384199
refactor: simplify NATS connection handling — keep retry decorator, d…
mihow Feb 12, 2026
98a17f1
Merge branch 'main' into carlosg/natsconn
carlosgjs Feb 13, 2026
cf42506
revert: restore NATS connection pool — avoid per-operation connection…
mihow Feb 13, 2026
dc798ea
refactor: add switchable NATS connection strategies
mihow Feb 13, 2026
4d66c07
refactor: simplify NATS connection module — pool-only, archive original
mihow Feb 13, 2026
41bbeb3
docs: clarify where connection pool provides reuse vs. single-use
mihow Feb 13, 2026
ead53d1
fix: use `from None` to suppress noisy exception chain in _get_pool
mihow Feb 13, 2026
9737301
docs: update AGENTS.md test commands to use docker-compose.ci.yml
mihow Feb 13, 2026
fa0f84b
fix: correct mock setup in NATS task tests to match plain instantiation
mihow Feb 14, 2026
c7b2014
fix: address PR review feedback for NATS connection module
mihow Feb 17, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions .agents/AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,29 +107,34 @@ docker compose restart django celeryworker

### Backend (Django)

Run tests:
Run tests (use `docker-compose.ci.yml` to avoid conflicts with the local dev stack):
```bash
docker compose -f docker-compose.ci.yml run --rm django python manage.py test
```

Run a specific test module:
```bash
docker compose run --rm django python manage.py test
docker compose -f docker-compose.ci.yml run --rm django python manage.py test ami.ml.orchestration.tests.test_nats_connection
```

Run specific test pattern:
```bash
docker compose run --rm django python manage.py test -k pattern
docker compose -f docker-compose.ci.yml run --rm django python manage.py test -k pattern
```

Run tests with debugger on failure:
```bash
docker compose run --rm django python manage.py test -k pattern --failfast --pdb
docker compose -f docker-compose.ci.yml run --rm django python manage.py test -k pattern --failfast --pdb
```

Speed up test development (reuse database):
```bash
docker compose run --rm django python manage.py test --keepdb
docker compose -f docker-compose.ci.yml run --rm django python manage.py test --keepdb
```

Run pytest (alternative test runner):
```bash
docker compose run --rm django pytest --ds=config.settings.test --reuse-db
docker compose -f docker-compose.ci.yml run --rm django pytest --ds=config.settings.test --reuse-db
```

Django shell:
Expand Down Expand Up @@ -654,13 +659,13 @@ images = SourceImage.objects.annotate(det_count=Count('detections'))

```bash
# Run specific test class
docker compose run --rm django python manage.py test ami.main.tests.test_models.ProjectTestCase
docker compose -f docker-compose.ci.yml run --rm django python manage.py test ami.main.tests.test_models.ProjectTestCase

# Run specific test method
docker compose run --rm django python manage.py test ami.main.tests.test_models.ProjectTestCase.test_project_creation
docker compose -f docker-compose.ci.yml run --rm django python manage.py test ami.main.tests.test_models.ProjectTestCase.test_project_creation

# Run with pattern matching
docker compose run --rm django python manage.py test -k test_detection
docker compose -f docker-compose.ci.yml run --rm django python manage.py test -k test_detection
```

### Pre-commit Hooks
Expand Down
4 changes: 2 additions & 2 deletions ami/jobs/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ def _ack_task_via_nats(reply_subject: str, job_logger: logging.Logger) -> None:
try:

async def ack_task():
async with TaskQueueManager() as manager:
return await manager.acknowledge_task(reply_subject)
manager = TaskQueueManager()
return await manager.acknowledge_task(reply_subject)

ack_success = async_to_sync(ack_task)()

Expand Down
10 changes: 5 additions & 5 deletions ami/jobs/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,11 +242,11 @@ def tasks(self, request, pk=None):

async def get_tasks():
tasks = []
async with TaskQueueManager() as manager:
for _ in range(batch):
task = await manager.reserve_task(job.pk, timeout=0.1)
if task:
tasks.append(task.dict())
manager = TaskQueueManager()
for _ in range(batch):
task = await manager.reserve_task(job.pk, timeout=0.1)
if task:
tasks.append(task.dict())
return tasks

# Use async_to_sync to properly handle the async call
Expand Down
36 changes: 18 additions & 18 deletions ami/ml/orchestration/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ def cleanup_async_job_resources(job: "Job") -> bool:

# Cleanup NATS resources
async def cleanup():
async with TaskQueueManager() as manager:
return await manager.cleanup_job_resources(job.pk)
manager = TaskQueueManager()
return await manager.cleanup_job_resources(job.pk)

try:
nats_success = async_to_sync(cleanup)()
Expand Down Expand Up @@ -96,22 +96,22 @@ async def queue_all_images():
successful_queues = 0
failed_queues = 0

async with TaskQueueManager() as manager:
for image_pk, task in tasks:
try:
logger.info(f"Queueing image {image_pk} to stream for job '{job.pk}': {task.image_url}")
success = await manager.publish_task(
job_id=job.pk,
data=task,
)
except Exception as e:
logger.error(f"Failed to queue image {image_pk} to stream for job '{job.pk}': {e}")
success = False

if success:
successful_queues += 1
else:
failed_queues += 1
manager = TaskQueueManager()
for image_pk, task in tasks:
try:
logger.info(f"Queueing image {image_pk} to stream for job '{job.pk}': {task.image_url}")
success = await manager.publish_task(
job_id=job.pk,
data=task,
)
except Exception as e:
logger.error(f"Failed to queue image {image_pk} to stream for job '{job.pk}': {e}")
success = False

if success:
successful_queues += 1
else:
failed_queues += 1

return successful_queues, failed_queues

Expand Down
227 changes: 227 additions & 0 deletions ami/ml/orchestration/nats_connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
"""
NATS connection management for both Celery workers and Django processes.

Provides a ConnectionPool keyed by event loop. The pool reuses a single NATS
connection for all async operations *within* one async_to_sync() boundary.
It does NOT provide reuse across separate async_to_sync() calls — each call
creates a new event loop, so a new connection is established.

Where the pool helps:
The main beneficiary is queue_images_to_nats() in jobs.py, which wraps
1000+ publish_task() awaits in a single async_to_sync() call. All of those
awaits share one event loop and therefore one NATS connection. Without the
pool, each publish would open its own TCP connection (~1500 per job).
Similarly, JobViewSet.tasks() batches multiple reserve_task() calls in one
async_to_sync() boundary.

Where it doesn't help:
Single-operation boundaries like _ack_task_via_nats() (one ACK per call)
get no reuse — the pool is effectively single-use there. The overhead is
negligible (one dict lookup), and the retry_on_connection_error decorator
provides resilience regardless.

Why keyed by event loop:
asyncio.Lock and nats.Client are bound to the loop they were created on.
Sharing them across loops causes "attached to a different loop" errors.
Keying by loop ensures isolation. WeakKeyDictionary auto-cleans when loops
are garbage collected, so short-lived loops don't leak.

Archived alternative:
ContextManagerConnection preserves the original pre-pool implementation
(one connection per `async with` block) as a drop-in fallback.
"""

import asyncio
import logging
import threading
from typing import TYPE_CHECKING
from weakref import WeakKeyDictionary

import nats
from django.conf import settings
from nats.js import JetStreamContext

if TYPE_CHECKING:
from nats.aio.client import Client as NATSClient

logger = logging.getLogger(__name__)


class ConnectionPool:
    """
    Manages a single persistent NATS connection per event loop.

    This is safe because:
    - asyncio.Lock and NATS Client are bound to the event loop they were created on
    - Each event loop gets its own isolated connection and lock
    - Works correctly with async_to_sync() which creates per-thread event loops
    - Prevents "attached to a different loop" errors in Celery tasks and Django views

    Instantiating TaskQueueManager() is cheap — multiple instances share the same
    underlying connection via this pool.
    """

    def __init__(self):
        # Connection and JetStream context are created lazily on first use.
        self._nc: "NATSClient | None" = None
        self._js: JetStreamContext | None = None
        self._lock: asyncio.Lock | None = None  # Lazy-initialized when needed

    def _ensure_lock(self) -> asyncio.Lock:
        """Lazily create lock bound to current event loop."""
        if self._lock is None:
            self._lock = asyncio.Lock()
        return self._lock

    async def get_connection(self) -> tuple["NATSClient", JetStreamContext]:
        """
        Get or create the event loop's NATS connection. Checks connection health
        and recreates if stale.

        Returns:
            Tuple of (NATS connection, JetStream context)
        Raises:
            RuntimeError: If connection cannot be established
        """
        # Fast path (no lock needed): connection exists, is open, and is connected.
        # This is the hot path — most calls hit this and return immediately.
        if self._nc is not None and self._js is not None and not self._nc.is_closed and self._nc.is_connected:
            return self._nc, self._js

        # Connection is stale or doesn't exist — clear references before reconnecting
        if self._nc is not None:
            logger.warning("NATS connection is closed or disconnected, will reconnect")
            self._nc = None
            self._js = None

        # Slow path: acquire lock to prevent concurrent reconnection attempts
        lock = self._ensure_lock()
        async with lock:
            # Double-check after acquiring lock (another coroutine may have reconnected)
            if self._nc is not None and self._js is not None and not self._nc.is_closed and self._nc.is_connected:
                return self._nc, self._js

            nats_url = settings.NATS_URL
            try:
                logger.info(f"Creating NATS connection to {nats_url}")
                nc = await nats.connect(nats_url)
                js = nc.jetstream()
            except Exception as e:
                logger.error(f"Failed to connect to NATS: {e}")
                raise RuntimeError(f"Could not establish NATS connection: {e}") from e
            # Publish both references only after both succeeded. Assigning
            # self._nc before jetstream() (as before) could leave a
            # half-initialized pool holding an open connection that the next
            # call would drop without closing.
            self._nc = nc
            self._js = js
            logger.info(f"Successfully connected to NATS at {nats_url}")
            return self._nc, self._js

    async def close(self):
        """Close the NATS connection if it exists and drop cached references."""
        nc = self._nc
        # Clear references first so a failure during close() — or an
        # already-closed connection — cannot leave the pool pointing at a
        # dead client.
        self._nc = None
        self._js = None
        if nc is not None and not nc.is_closed:
            logger.info("Closing NATS connection")
            await nc.close()

    async def reset(self):
        """
        Close the current connection and clear all state so the next call to
        get_connection() creates a fresh one.

        Called by retry_on_connection_error when an operation hits a connection
        error (e.g. network blip, NATS restart). The lock is also cleared so it
        gets recreated bound to the current event loop.
        """
        logger.warning("Resetting NATS connection pool due to connection error")
        nc = self._nc
        # Clear state up front so the pool is clean even if close() raises.
        self._nc = None
        self._js = None
        self._lock = None  # Clear lock so new one is created for fresh connection
        if nc is not None:
            try:
                if not nc.is_closed:
                    await nc.close()
                    logger.debug("Successfully closed existing NATS connection during reset")
            except Exception as e:
                # Swallow errors - connection may already be broken
                logger.debug(f"Error closing connection during reset (expected): {e}")


class ContextManagerConnection:
    """
    Archived pre-pool implementation: one NATS connection per `async with` block.

    This preserves the original approach from before the connection pool existed.
    Every call to get_connection() opens a brand-new connection; closing it is
    the caller's responsibility. No reuse, no retry logic at this layer.

    Trade-offs vs ConnectionPool:
    - Simpler: no shared state, no locking, no event-loop keying
    - Expensive: ~1500 TCP connections per 1000-image job vs 1 with the pool
    - No automatic reconnection — caller must handle connection failures

    Kept as a drop-in fallback. To switch, change the class used in
    _create_pool() below from ConnectionPool to ContextManagerConnection.
    """

    async def get_connection(self) -> tuple["NATSClient", JetStreamContext]:
        """Open and return a brand-new NATS connection plus its JetStream context."""
        nats_url = settings.NATS_URL
        try:
            logger.debug(f"Creating per-operation NATS connection to {nats_url}")
            nc = await nats.connect(nats_url)
            return nc, nc.jetstream()
        except Exception as e:
            logger.error(f"Failed to connect to NATS: {e}")
            raise RuntimeError(f"Could not establish NATS connection: {e}") from e

    async def close(self):
        """No-op — connections are not tracked."""

    async def reset(self):
        """No-op — connections are not tracked."""


# Event-loop-keyed pools: one ConnectionPool per event loop.
# WeakKeyDictionary automatically cleans up when event loops are garbage collected,
# so short-lived loops (one per async_to_sync() call) do not accumulate entries.
_pools: WeakKeyDictionary[asyncio.AbstractEventLoop, ConnectionPool] = WeakKeyDictionary()
# Guards _pools itself (a plain threading.Lock, since pools may be created from
# multiple threads); each pool's own asyncio.Lock guards its connection.
_pools_lock = threading.Lock()


def _get_pool() -> ConnectionPool:
    """Return the ConnectionPool bound to the currently running event loop, creating it on first use."""
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # `from None` hides the unhelpful "no running event loop" chain.
        raise RuntimeError(
            "get_connection() must be called from an async context with a running event loop. "
            "If calling from sync code, use async_to_sync() to wrap the async function."
        ) from None

    with _pools_lock:
        pool = _pools.get(loop)
        if pool is None:
            pool = ConnectionPool()
            _pools[loop] = pool
            logger.debug(f"Created NATS connection pool for event loop {id(loop)}")
        return pool


async def get_connection() -> tuple["NATSClient", JetStreamContext]:
    """
    Get or create a NATS connection for the current event loop.

    Delegates to the loop-local ConnectionPool, so repeated calls within one
    event loop share a single connection.

    Returns:
        Tuple of (NATS connection, JetStream context)
    Raises:
        RuntimeError: If called outside of an async context (no running event loop)
    """
    return await _get_pool().get_connection()


async def reset_connection() -> None:
    """
    Reset the NATS connection for the current event loop.

    Closes the pool's current connection and clears its state, so the next
    get_connection() call establishes a fresh one.
    """
    await _get_pool().reset()
Loading
Loading