Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,22 @@ services:
default:
aliases:
- processing_service
depends_on:
- celeryworker_ml

# Celery worker for the minimal processing service; consumes ML tasks from RabbitMQ.
celeryworker_ml:
build:
context: ./processing_services/minimal
command: ./celery_worker/start_celery.sh
environment:
# Broker URL for Celery; credentials/host must match the rabbitmq service in this stack.
- CELERY_BROKER_URL=amqp://rabbituser:rabbitpass@rabbitmq:5672//
extra_hosts:
# Resolve "minio" to the Docker host so the worker can reach host-published MinIO.
- minio:host-gateway
networks:
- antenna_network
volumes: # mount host time/zone files read-only so the container clock matches the host (avoids clock drift)
- /etc/localtime:/etc/localtime:ro
- /etc/timezone:/etc/timezone:ro

networks:
antenna_network:
Expand Down
2 changes: 1 addition & 1 deletion processing_services/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ If your goal is to run an ML backend locally, simply copy the `example` app and

1. Update `processing_services/example/requirements.txt` with required packages (i.e. PyTorch, etc)
2. Rebuild container to install updated dependencies. Start the minimal and example ML backends: `docker compose -f processing_services/docker-compose.yml up -d --build ml_backend_example`
3. To test that everything works, register a new processing service in Antenna with endpoint URL http://ml_backend_example:2000. All ML backends are connected to the main docker compose stack using the `antenna_network`.
3. To test that everything works, register a new processing service in Antenna with endpoint URL http://ml_backend_example:2000. All ML backends are connected to the main docker compose stack using the `antenna_network`.


## Add Algorithms, Pipelines, and ML Backend/Processing Services
Expand Down
28 changes: 28 additions & 0 deletions processing_services/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,20 @@ services:
networks:
- antenna_network

# Celery worker paired with the minimal ML backend; consumes tasks from RabbitMQ.
celeryworker_minimal:
build:
context: ./minimal
command: ./celery_worker/start_celery.sh
environment:
# Broker URL for Celery; credentials/host must match the rabbitmq service in the main stack.
- CELERY_BROKER_URL=amqp://rabbituser:rabbitpass@rabbitmq:5672//
extra_hosts:
# Resolve "minio" to the Docker host so the worker can reach host-published MinIO.
- minio:host-gateway
networks:
- antenna_network
volumes: # mount host time/zone files read-only so the container clock matches the host (avoids clock drift)
- /etc/localtime:/etc/localtime:ro
- /etc/timezone:/etc/timezone:ro

ml_backend_example:
build:
context: ./example
Expand All @@ -23,6 +37,20 @@ services:
networks:
- antenna_network

# Celery worker paired with the example ML backend; consumes tasks from RabbitMQ.
celeryworker_example:
build:
context: ./example
command: ./celery_worker/start_celery.sh
environment:
# Broker URL for Celery; credentials/host must match the rabbitmq service in the main stack.
- CELERY_BROKER_URL=amqp://rabbituser:rabbitpass@rabbitmq:5672//
extra_hosts:
# Resolve "minio" to the Docker host so the worker can reach host-published MinIO.
- minio:host-gateway
networks:
- antenna_network
volumes: # mount host time/zone files read-only so the container clock matches the host (avoids clock drift)
- /etc/localtime:/etc/localtime:ro
- /etc/timezone:/etc/timezone:ro

networks:
antenna_network:
name: antenna_network
6 changes: 5 additions & 1 deletion processing_services/example/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
# Image for the example ML backend: serves the FastAPI app and bundles the
# Celery worker entrypoint used by the celeryworker_example compose service.
FROM python:3.11-slim

# Set up ml backend FastAPI
WORKDIR /app

# Copy the whole service source (api/, celery_worker/, main.py, requirements.txt).
COPY . /app

RUN pip install -r ./requirements.txt

# Make the Celery worker launcher executable; compose overrides CMD with this script.
RUN chmod +x ./celery_worker/start_celery.sh

# Default command: run the FastAPI app (compose worker services override this).
CMD ["python", "/app/main.py"]
162 changes: 6 additions & 156 deletions processing_services/example/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,8 @@

import fastapi

from .pipelines import (
Pipeline,
ZeroShotHFClassifierPipeline,
ZeroShotObjectDetectorPipeline,
ZeroShotObjectDetectorWithConstantClassifierPipeline,
ZeroShotObjectDetectorWithRandomSpeciesClassifierPipeline,
)
from .schemas import (
AlgorithmConfigResponse,
Detection,
DetectionRequest,
PipelineRequest,
PipelineRequestConfigParameters,
PipelineResultsResponse,
ProcessingServiceInfoResponse,
SourceImage,
)
from .utils import is_base64, is_url
from .processing import pipeline_choices, pipelines, process_pipeline_request
from .schemas import PipelineRequest, PipelineResultsResponse, ProcessingServiceInfoResponse

# Configure root logger
logging.basicConfig(
Expand All @@ -35,18 +19,6 @@

app = fastapi.FastAPI()


pipelines: list[type[Pipeline]] = [
ZeroShotHFClassifierPipeline,
ZeroShotObjectDetectorPipeline,
ZeroShotObjectDetectorWithConstantClassifierPipeline,
ZeroShotObjectDetectorWithRandomSpeciesClassifierPipeline,
]
pipeline_choices: dict[str, type[Pipeline]] = {pipeline.config.slug: pipeline for pipeline in pipelines}
algorithm_choices: dict[str, AlgorithmConfigResponse] = {
algorithm.key: algorithm for pipeline in pipelines for algorithm in pipeline.config.algorithms
}

# -----------
# API endpoints
# -----------
Expand All @@ -63,7 +35,6 @@ async def info() -> ProcessingServiceInfoResponse:
name="Custom ML Backend",
description=("A template for running custom models locally."),
pipelines=[pipeline.config for pipeline in pipelines],
# algorithms=list(algorithm_choices.values()),
)
return info

Expand Down Expand Up @@ -92,134 +63,13 @@ async def readyz():

@app.post("/process", tags=["services"])
async def process(data: PipelineRequest) -> PipelineResultsResponse:
pipeline_slug = data.pipeline
request_config = data.config

source_images = [SourceImage(**img.model_dump()) for img in data.source_images]
# Open source images once before processing
for img in source_images:
img.open(raise_exception=True)

detections = create_detections(
source_images=source_images,
detection_requests=data.detections,
)

try:
Pipeline = pipeline_choices[pipeline_slug]
except KeyError:
raise fastapi.HTTPException(status_code=422, detail=f"Invalid pipeline choice: {pipeline_slug}")

pipeline_request_config = PipelineRequestConfigParameters(**dict(request_config)) if request_config else {}
try:
pipeline = Pipeline(
source_images=source_images,
request_config=pipeline_request_config,
existing_detections=detections,
)
pipeline.compile()
resp: PipelineResultsResponse = process_pipeline_request(data)
except Exception as e:
logger.error(f"Error compiling pipeline: {e}")
raise fastapi.HTTPException(status_code=422, detail=f"{e}")

try:
response = pipeline.run()
except Exception as e:
logger.error(f"Error running pipeline: {e}")
raise fastapi.HTTPException(status_code=422, detail=f"{e}")

return response


# -----------
# Helper functions
# -----------

logger.error(f"Error processing pipeline request: {e}")
raise fastapi.HTTPException(status_code=422, detail=str(e))

def create_detections(
source_images: list[SourceImage],
detection_requests: list[DetectionRequest] | None,
):
if not detection_requests:
return []

# Group detection requests by source image id
source_image_map = {img.id: img for img in source_images}
grouped_detection_requests = {}
for request in detection_requests:
if request.source_image.id not in grouped_detection_requests:
grouped_detection_requests[request.source_image.id] = []
grouped_detection_requests[request.source_image.id].append(request)

# Process each source image and its detection requests
detections = []
for source_image_id, requests in grouped_detection_requests.items():
if source_image_id not in source_image_map:
raise ValueError(
f"A detection request for source image {source_image_id} was received, "
"but no source image with that ID was provided."
)

logger.info(f"Processing existing detections for source image {source_image_id}.")

for request in requests:
source_image = source_image_map[source_image_id]
cropped_image_id = (
f"{source_image.id}-crop-{request.bbox.x1}-{request.bbox.y1}-{request.bbox.x2}-{request.bbox.y2}"
)
if not request.crop_image_url:
logger.info("Detection request does not have a crop_image_url, crop the original source image.")
assert source_image._pil is not None, "Source image must be opened before cropping."
cropped_image_pil = source_image._pil.crop(
(request.bbox.x1, request.bbox.y1, request.bbox.x2, request.bbox.y2)
)
else:
try:
logger.info(f"Opening existing cropped image from {request.crop_image_url}.")
if is_url(request.crop_image_url):
cropped_image = SourceImage(
id=cropped_image_id,
url=request.crop_image_url,
)
elif is_base64(request.crop_image_url):
logger.info("Decoding base64 cropped image.")
cropped_image = SourceImage(
id=cropped_image_id,
b64=request.crop_image_url,
)
else:
# Must be a filepath
cropped_image = SourceImage(
id=cropped_image_id,
filepath=request.crop_image_url,
)
cropped_image.open(raise_exception=True)
cropped_image_pil = cropped_image._pil
except Exception as e:
logger.warning(f"Error opening cropped image: {e}")
logger.info(f"Falling back to cropping the original source image {source_image_id}.")
assert source_image._pil is not None, "Source image must be opened before cropping."
cropped_image_pil = source_image._pil.crop(
(request.bbox.x1, request.bbox.y1, request.bbox.x2, request.bbox.y2)
)

# Create a Detection object
det = Detection(
source_image=SourceImage(
id=source_image.id,
url=source_image.url,
),
bbox=request.bbox,
id=cropped_image_id,
url=request.crop_image_url or source_image.url,
algorithm=request.algorithm,
)
# Set the _pil attribute to the cropped image
det._pil = cropped_image_pil
detections.append(det)
logger.info(f"Created detection {det.id} for source image {source_image_id}.")

return detections
return resp


if __name__ == "__main__":
Expand Down
Loading