29 changes: 29 additions & 0 deletions .vscode/launch.json
@@ -0,0 +1,29 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python Debugger: Current File",
"type": "debugpy",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal"
},
{
"name": "Run worker",
"type": "debugpy",
"request": "launch",
"module": "trapdata.cli.base",
"args": ["worker"]
},
{
"name": "Run api",
"type": "debugpy",
"request": "launch",
"module": "trapdata.cli.base",
"args": ["api"]
}
]
}
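
For reference, a minimal sketch of the equivalent command-line invocations these two debug configurations map to; debugpy's "module" setting runs the module the same way python -m does, and the subprocess wrapper below is only illustrative:

import subprocess
import sys

# Run one of these in its own terminal; each maps to a debug configuration above.
subprocess.run([sys.executable, "-m", "trapdata.cli.base", "worker"], check=True)  # "Run worker"
# subprocess.run([sys.executable, "-m", "trapdata.cli.base", "api"], check=True)   # "Run api"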
53 changes: 34 additions & 19 deletions trapdata/api/api.py
@@ -5,6 +5,7 @@

import enum
import time
from contextlib import asynccontextmanager

import fastapi
import pydantic
@@ -36,7 +37,18 @@
from .schemas import PipelineResultsResponse as PipelineResponse_
from .schemas import ProcessingServiceInfoResponse, SourceImage, SourceImageResponse

app = fastapi.FastAPI()

@asynccontextmanager
async def lifespan(app: fastapi.FastAPI):
# Startup: build and cache the service info once so it is not rebuilt on every /info request
app.state.service_info = initialize_service_info()
logger.info("Initialized service info")
yield
# Shutdown event: Clean up resources (if necessary)
logger.info("Shutting down API")


app = fastapi.FastAPI(lifespan=lifespan)
app.add_middleware(GZipMiddleware)


@@ -157,13 +169,6 @@ def make_pipeline_config_response(
)


# @TODO This requires loading all models into memory! Can we avoid this?
@carlosgjs (Author) commented on Dec 5, 2025:
Moved this to the FastAPI initialization so we don't pay the cost when running ami worker.

PIPELINE_CONFIGS = [
make_pipeline_config_response(classifier_class, slug=key)
for key, classifier_class in CLASSIFIER_CHOICES.items()
]


class PipelineRequest(PipelineRequest_):
pipeline: PipelineChoice = pydantic.Field(
description=PipelineRequest_.model_fields["pipeline"].description,
@@ -313,17 +318,7 @@ async def process(data: PipelineRequest) -> PipelineResponse:

@app.get("/info", tags=["services"])
async def info() -> ProcessingServiceInfoResponse:
info = ProcessingServiceInfoResponse(
name="Antenna Inference API",
description=(
"The primary endpoint for processing images for the Antenna platform. "
"This API provides access to multiple detection and classification "
"algorithms by multiple labs for processing images of moths."
),
pipelines=PIPELINE_CONFIGS,
# algorithms=list(algorithm_choices.values()),
)
return info
return app.state.service_info
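
Not part of this diff: a minimal sketch of exercising the cached service info with FastAPI's TestClient, assuming the app module path trapdata.api.api; using the client as a context manager runs the lifespan startup so app.state.service_info is populated before the request:

from fastapi.testclient import TestClient

from trapdata.api.api import app

with TestClient(app) as client:
    # Lifespan startup has already built app.state.service_info at this point.
    response = client.get("/info")
    assert response.status_code == 200
    assert response.json()["name"] == "Antenna Inference API"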


# Check if the server is online
@@ -361,6 +356,26 @@ async def readyz():
# pass


def initialize_service_info() -> ProcessingServiceInfoResponse:
# @TODO This requires loading all models into memory! Can we avoid this?
pipeline_configs = [
make_pipeline_config_response(classifier_class, slug=key)
for key, classifier_class in CLASSIFIER_CHOICES.items()
]

_info = ProcessingServiceInfoResponse(
name="Antenna Inference API",
description=(
"The primary endpoint for processing images for the Antenna platform. "
"This API provides access to multiple detection and classification "
"algorithms by multiple labs for processing images of moths."
),
pipelines=pipeline_configs,
# algorithms=list(algorithm_choices.values()),
)
return _info


if __name__ == "__main__":
import uvicorn

241 changes: 240 additions & 1 deletion trapdata/api/datasets.py
@@ -1,12 +1,16 @@
import os
import typing
from io import BytesIO

import requests
import torch
import torch.utils.data
import torchvision
from PIL import Image

from trapdata.common.logs import logger

from .schemas import DetectionResponse, SourceImage
from .schemas import DetectionResponse, PipelineProcessingTask, SourceImage


class LocalizationImageDataset(torch.utils.data.Dataset):
@@ -87,3 +91,238 @@ def __getitem__(self, idx):

# return (ids_batch, image_batch)
return (source_image.id, detection_idx), image_data


class RESTDataset(torch.utils.data.IterableDataset):
"""
An IterableDataset that fetches tasks from a REST API endpoint and loads images.

The dataset continuously polls the API for tasks, loads the associated images,
and yields them as PyTorch tensors along with metadata.
"""

def __init__(
self,
base_url: str,
job_id: int,
batch_size: int = 1,
image_transforms: typing.Optional[torchvision.transforms.Compose] = None,
auth_token: typing.Optional[str] = None,
):
"""
Initialize the REST dataset.

Args:
base_url: Base URL for the API (e.g., "http://localhost:8000")
job_id: The job ID to fetch tasks for
batch_size: Number of tasks to request per batch
image_transforms: Optional transforms to apply to loaded images
auth_token: API authentication token. If not provided, reads from
ANTENNA_API_TOKEN environment variable
"""
super().__init__()
self.base_url = base_url.rstrip("/")
self.job_id = job_id
self.batch_size = batch_size
self.image_transforms = image_transforms or torchvision.transforms.ToTensor()
self.auth_token = auth_token or os.environ.get("ANTENNA_API_TOKEN")

def _fetch_tasks(self) -> list[PipelineProcessingTask]:
"""
Fetch a batch of tasks from the REST API.

Returns:
List of task dictionaries from the API response
"""
url = f"{self.base_url}/api/v2/jobs/{self.job_id}/tasks"
params = {"batch": self.batch_size}

headers = {}
if self.auth_token:
headers["Authorization"] = f"Token {self.auth_token}"

try:
response = requests.get(
url,
params=params,
timeout=30,
headers=headers,
)
response.raise_for_status()
data = response.json()
tasks = [PipelineProcessingTask(**task) for task in data.get("tasks", [])]
return tasks
except requests.RequestException as e:
logger.error(f"Failed to fetch tasks from {url}: {e}")
return []
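
Illustrative only: the rough response payload this method expects, based on the fields read later in this class; the real PipelineProcessingTask schema may define additional or required fields beyond these:

example_payload = {
    "tasks": [
        {
            # Placeholder values; keys mirror the attributes used below
            # (task.image_id, task.image_url, task.reply_subject).
            "image_id": "img-1",
            "image_url": "https://example.com/img-1.jpg",
            "reply_subject": "jobs.123.tasks.1",
        }
    ]
}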

def _load_image(self, image_url: str) -> typing.Optional[torch.Tensor]:
"""
Load an image from a URL and convert it to a PyTorch tensor.

Args:
image_url: URL of the image to load

Returns:
Image as a PyTorch tensor, or None if loading failed
"""
try:
response = requests.get(image_url, timeout=30)
response.raise_for_status()
image = Image.open(BytesIO(response.content))

# Convert to RGB if necessary
if image.mode != "RGB":
image = image.convert("RGB")

# Apply transforms
image_tensor = self.image_transforms(image)
return image_tensor
except Exception as e:
logger.error(f"Failed to load image from {image_url}: {e}")
return None

def __iter__(self):
"""
Iterate over tasks from the REST API.

Yields:
Dictionary containing:
- image: PyTorch tensor of the loaded image, or None if loading failed
- reply_subject: Reply subject for the task
- image_id: Image ID
- image_url: Image URL of the source image
- error: Error message, present only when image loading failed
"""
try:
# Get worker info for debugging
worker_info = torch.utils.data.get_worker_info()
worker_id = worker_info.id if worker_info else 0
num_workers = worker_info.num_workers if worker_info else 1

logger.info(
f"Worker {worker_id}/{num_workers} starting iteration for job {self.job_id}"
)

while True:
tasks = self._fetch_tasks()
# _, t = log_time()
# _, t = t(f"Worker {worker_id}: Fetched {len(tasks)} tasks from API")

# If no tasks returned, dataset is finished
if not tasks:
logger.info(
f"Worker {worker_id}: No more tasks for job {self.job_id}, terminating"
)
break

for task in tasks:
errors = []
# Load the image
# _, t = log_time()
image_tensor = (
self._load_image(task.image_url) if task.image_url else None
)
# _, t = t(f"Loaded image from {image_url}")

if image_tensor is None:
errors.append("failed to load image")

if errors:
logger.warning(
f"Worker {worker_id}: Errors in task for image '{task.image_id}': {', '.join(errors)}"
)

# Yield the data row
row = {
"image": image_tensor,
"reply_subject": task.reply_subject,
"image_id": task.image_id,
"image_url": task.image_url,
}
if errors:
row["error"] = ("; ".join(errors) if errors else None,)
yield row

logger.info(f"Worker {worker_id}: Iterator finished")
except Exception as e:
logger.error(f"Worker {worker_id}: Exception in iterator: {e}")
raise
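
A minimal usage sketch, within this module's namespace, of iterating the dataset directly without a DataLoader; assumes a local Antenna API with queued tasks and uses a placeholder job id:

dataset = RESTDataset(base_url="http://localhost:8000", job_id=123, batch_size=2)
for row in dataset:
    if row.get("error"):
        logger.warning(f"Image {row['image_id']} failed: {row['error']}")
        continue
    logger.info(f"Loaded image {row['image_id']} with shape {tuple(row['image'].shape)}")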


def rest_collate_fn(batch: list[dict]) -> dict:
"""
Custom collate function that separates failed and successful items.

Returns a dict with:
- images: List of valid tensors
- reply_subjects: List of reply subjects for valid images
- image_ids: List of image IDs for valid images
- image_urls: List of image URLs for valid images
- failed_items: List of dicts with metadata for failed items
"""
successful = []
failed = []

for item in batch:
if item["image"] is None or item.get("error"):
# Failed item
failed.append(
{
"reply_subject": item["reply_subject"],
"image_id": item["image_id"],
"image_url": item.get("image_url"),
"error": item.get("error", "Unknown error"),
}
)
else:
# Successful item
successful.append(item)

# Collate successful items
if successful:
result = {
"image": torch.stack([item["image"] for item in successful]),
"reply_subject": [item["reply_subject"] for item in successful],
"image_id": [item["image_id"] for item in successful],
"image_url": [item.get("image_url") for item in successful],
}
else:
# Empty batch - all items failed; "image" is omitted since there are no tensors to stack
result = {
"reply_subject": [],
"image_id": [],
"image_url": [],
}

result["failed_items"] = failed

return result
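
A small sketch of what rest_collate_fn produces, assuming one loaded image and one failed item; tensor shape, ids, URLs, and reply subjects are placeholders:

example_batch = [
    {
        "image": torch.zeros(3, 224, 224),
        "reply_subject": "jobs.123.tasks.1",
        "image_id": "img-1",
        "image_url": "https://example.com/img-1.jpg",
    },
    {
        "image": None,
        "reply_subject": "jobs.123.tasks.2",
        "image_id": "img-2",
        "image_url": "https://example.com/img-2.jpg",
        "error": "failed to load image",
    },
]
collated = rest_collate_fn(example_batch)
# collated["image"] has shape (1, 3, 224, 224); collated["failed_items"] holds the second item.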


def get_rest_dataloader(
job_id: int,
base_url: str = "http://localhost:8000",
batch_size: int = 4,
num_workers: int = 2,
auth_token: typing.Optional[str] = None,
) -> torch.utils.data.DataLoader:
"""
Create a DataLoader that streams tasks for a job from the REST API.

Args:
job_id: Job ID to fetch tasks for
base_url: Base URL for the REST API (default: http://localhost:8000)
batch_size: Number of tasks/images per batch (default: 4)
num_workers: Number of DataLoader workers (default: 2)
auth_token: Optional API token; falls back to the ANTENNA_API_TOKEN environment variable
"""
assert base_url is not None, "Base URL must be provided"
base_url = base_url.rstrip("/")

dataset = RESTDataset(
base_url=base_url, job_id=job_id, batch_size=batch_size, auth_token=auth_token
)

return torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
num_workers=num_workers,
collate_fn=rest_collate_fn,
)
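
A minimal end-to-end sketch, assuming a local Antenna API on port 8000 and a job with queued tasks; the job id is a placeholder and the token is read from the environment, as in RESTDataset:

import os

from trapdata.api.datasets import get_rest_dataloader
from trapdata.common.logs import logger

dataloader = get_rest_dataloader(
    job_id=123,
    base_url="http://localhost:8000",
    batch_size=4,
    num_workers=2,
    auth_token=os.environ.get("ANTENNA_API_TOKEN"),
)
for batch in dataloader:
    for failed in batch["failed_items"]:
        logger.warning(f"Skipping {failed['image_id']}: {failed['error']}")
    if "image" in batch:  # present only when at least one image loaded successfully
        logger.info(f"Processing a batch of {batch['image'].shape[0]} images")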