diff --git a/docker-compose.yml b/docker-compose.yml index 69cea61a7..e21b75614 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -199,7 +199,22 @@ services: default: aliases: - processing_service + depends_on: + - celeryworker_ml + celeryworker_ml: + build: + context: ./processing_services/minimal + command: ./celery_worker/start_celery.sh + environment: + - CELERY_BROKER_URL=amqp://rabbituser:rabbitpass@rabbitmq:5672// + extra_hosts: + - minio:host-gateway + networks: + - antenna_network + volumes: # keep container time in sync with the host to avoid clock drift + - /etc/localtime:/etc/localtime:ro + - /etc/timezone:/etc/timezone:ro networks: antenna_network: diff --git a/processing_services/README.md b/processing_services/README.md index 48bb254e7..ea0e960d3 100644 --- a/processing_services/README.md +++ b/processing_services/README.md @@ -20,7 +20,7 @@ If your goal is to run an ML backend locally, simply copy the `example` app and 1. Update `processing_services/example/requirements.txt` with required packages (i.e. PyTorch, etc) 2. Rebuild container to install updated dependencies. Start the minimal and example ML backends: `docker compose -f processing_services/docker-compose.yml up -d --build ml_backend_example` -3. To test that everything works, register a new processing service in Antenna with endpoint URL http://ml_backend_example:2000. All ML backends are connected to the main docker compose stack using the `ml_network`. +3. To test that everything works, register a new processing service in Antenna with endpoint URL http://ml_backend_example:2000. All ML backends are connected to the main docker compose stack using the `antenna_network`. ## Add Algorithms, Pipelines, and ML Backend/Processing Services diff --git a/processing_services/docker-compose.yml b/processing_services/docker-compose.yml index 91a21c100..7bdec1389 100644 --- a/processing_services/docker-compose.yml +++ b/processing_services/docker-compose.yml @@ -11,6 +11,20 @@ services: networks: - antenna_network + celeryworker_minimal: + build: + context: ./minimal + command: ./celery_worker/start_celery.sh + environment: + - CELERY_BROKER_URL=amqp://rabbituser:rabbitpass@rabbitmq:5672// + extra_hosts: + - minio:host-gateway + networks: + - antenna_network + volumes: # keep container time in sync with the host to avoid clock drift + - /etc/localtime:/etc/localtime:ro + - /etc/timezone:/etc/timezone:ro + ml_backend_example: build: context: ./example @@ -23,6 +37,20 @@ services: networks: - antenna_network + celeryworker_example: + build: + context: ./example + command: ./celery_worker/start_celery.sh + environment: + - CELERY_BROKER_URL=amqp://rabbituser:rabbitpass@rabbitmq:5672// + extra_hosts: + - minio:host-gateway + networks: + - antenna_network + volumes: # keep container time in sync with the host to avoid clock drift + - /etc/localtime:/etc/localtime:ro + - /etc/timezone:/etc/timezone:ro + networks: antenna_network: name: antenna_network diff --git a/processing_services/example/Dockerfile b/processing_services/example/Dockerfile index 3e0781f92..7128404d7 100644 --- a/processing_services/example/Dockerfile +++ b/processing_services/example/Dockerfile @@ -1,7 +1,11 @@ FROM python:3.11-slim -# Set up ml backend FastAPI WORKDIR /app + COPY . 
/app + RUN pip install -r ./requirements.txt + +RUN chmod +x ./celery_worker/start_celery.sh + CMD ["python", "/app/main.py"] diff --git a/processing_services/example/api/api.py b/processing_services/example/api/api.py index 79ce5d83c..7522de6f2 100644 --- a/processing_services/example/api/api.py +++ b/processing_services/example/api/api.py @@ -6,24 +6,8 @@ import fastapi -from .pipelines import ( - Pipeline, - ZeroShotHFClassifierPipeline, - ZeroShotObjectDetectorPipeline, - ZeroShotObjectDetectorWithConstantClassifierPipeline, - ZeroShotObjectDetectorWithRandomSpeciesClassifierPipeline, -) -from .schemas import ( - AlgorithmConfigResponse, - Detection, - DetectionRequest, - PipelineRequest, - PipelineRequestConfigParameters, - PipelineResultsResponse, - ProcessingServiceInfoResponse, - SourceImage, -) -from .utils import is_base64, is_url +from .processing import pipeline_choices, pipelines, process_pipeline_request +from .schemas import PipelineRequest, PipelineResultsResponse, ProcessingServiceInfoResponse # Configure root logger logging.basicConfig( @@ -35,18 +19,6 @@ app = fastapi.FastAPI() - -pipelines: list[type[Pipeline]] = [ - ZeroShotHFClassifierPipeline, - ZeroShotObjectDetectorPipeline, - ZeroShotObjectDetectorWithConstantClassifierPipeline, - ZeroShotObjectDetectorWithRandomSpeciesClassifierPipeline, -] -pipeline_choices: dict[str, type[Pipeline]] = {pipeline.config.slug: pipeline for pipeline in pipelines} -algorithm_choices: dict[str, AlgorithmConfigResponse] = { - algorithm.key: algorithm for pipeline in pipelines for algorithm in pipeline.config.algorithms -} - # ----------- # API endpoints # ----------- @@ -63,7 +35,6 @@ async def info() -> ProcessingServiceInfoResponse: name="Custom ML Backend", description=("A template for running custom models locally."), pipelines=[pipeline.config for pipeline in pipelines], - # algorithms=list(algorithm_choices.values()), ) return info @@ -92,134 +63,13 @@ async def readyz(): @app.post("/process", tags=["services"]) async def process(data: PipelineRequest) -> PipelineResultsResponse: - pipeline_slug = data.pipeline - request_config = data.config - - source_images = [SourceImage(**img.model_dump()) for img in data.source_images] - # Open source images once before processing - for img in source_images: - img.open(raise_exception=True) - - detections = create_detections( - source_images=source_images, - detection_requests=data.detections, - ) - - try: - Pipeline = pipeline_choices[pipeline_slug] - except KeyError: - raise fastapi.HTTPException(status_code=422, detail=f"Invalid pipeline choice: {pipeline_slug}") - - pipeline_request_config = PipelineRequestConfigParameters(**dict(request_config)) if request_config else {} try: - pipeline = Pipeline( - source_images=source_images, - request_config=pipeline_request_config, - existing_detections=detections, - ) - pipeline.compile() + resp: PipelineResultsResponse = process_pipeline_request(data) except Exception as e: - logger.error(f"Error compiling pipeline: {e}") - raise fastapi.HTTPException(status_code=422, detail=f"{e}") - - try: - response = pipeline.run() - except Exception as e: - logger.error(f"Error running pipeline: {e}") - raise fastapi.HTTPException(status_code=422, detail=f"{e}") - - return response - - -# ----------- -# Helper functions -# ----------- - + logger.error(f"Error processing pipeline request: {e}") + raise fastapi.HTTPException(status_code=422, detail=str(e)) -def create_detections( - source_images: list[SourceImage], - detection_requests: list[DetectionRequest] 
| None, -): - if not detection_requests: - return [] - - # Group detection requests by source image id - source_image_map = {img.id: img for img in source_images} - grouped_detection_requests = {} - for request in detection_requests: - if request.source_image.id not in grouped_detection_requests: - grouped_detection_requests[request.source_image.id] = [] - grouped_detection_requests[request.source_image.id].append(request) - - # Process each source image and its detection requests - detections = [] - for source_image_id, requests in grouped_detection_requests.items(): - if source_image_id not in source_image_map: - raise ValueError( - f"A detection request for source image {source_image_id} was received, " - "but no source image with that ID was provided." - ) - - logger.info(f"Processing existing detections for source image {source_image_id}.") - - for request in requests: - source_image = source_image_map[source_image_id] - cropped_image_id = ( - f"{source_image.id}-crop-{request.bbox.x1}-{request.bbox.y1}-{request.bbox.x2}-{request.bbox.y2}" - ) - if not request.crop_image_url: - logger.info("Detection request does not have a crop_image_url, crop the original source image.") - assert source_image._pil is not None, "Source image must be opened before cropping." - cropped_image_pil = source_image._pil.crop( - (request.bbox.x1, request.bbox.y1, request.bbox.x2, request.bbox.y2) - ) - else: - try: - logger.info(f"Opening existing cropped image from {request.crop_image_url}.") - if is_url(request.crop_image_url): - cropped_image = SourceImage( - id=cropped_image_id, - url=request.crop_image_url, - ) - elif is_base64(request.crop_image_url): - logger.info("Decoding base64 cropped image.") - cropped_image = SourceImage( - id=cropped_image_id, - b64=request.crop_image_url, - ) - else: - # Must be a filepath - cropped_image = SourceImage( - id=cropped_image_id, - filepath=request.crop_image_url, - ) - cropped_image.open(raise_exception=True) - cropped_image_pil = cropped_image._pil - except Exception as e: - logger.warning(f"Error opening cropped image: {e}") - logger.info(f"Falling back to cropping the original source image {source_image_id}.") - assert source_image._pil is not None, "Source image must be opened before cropping." 
- cropped_image_pil = source_image._pil.crop( - (request.bbox.x1, request.bbox.y1, request.bbox.x2, request.bbox.y2) - ) - - # Create a Detection object - det = Detection( - source_image=SourceImage( - id=source_image.id, - url=source_image.url, - ), - bbox=request.bbox, - id=cropped_image_id, - url=request.crop_image_url or source_image.url, - algorithm=request.algorithm, - ) - # Set the _pil attribute to the cropped image - det._pil = cropped_image_pil - detections.append(det) - logger.info(f"Created detection {det.id} for source image {source_image_id}.") - - return detections + return resp if __name__ == "__main__": diff --git a/processing_services/example/api/processing.py b/processing_services/example/api/processing.py new file mode 100644 index 000000000..1bc2e2651 --- /dev/null +++ b/processing_services/example/api/processing.py @@ -0,0 +1,170 @@ +import logging + +from .pipelines import ( + Pipeline, + ZeroShotHFClassifierPipeline, + ZeroShotObjectDetectorPipeline, + ZeroShotObjectDetectorWithConstantClassifierPipeline, + ZeroShotObjectDetectorWithRandomSpeciesClassifierPipeline, +) +from .schemas import ( + Detection, + DetectionRequest, + PipelineRequest, + PipelineRequestConfigParameters, + PipelineResultsResponse, + SourceImage, +) +from .utils import is_base64, is_url + +# Get the root logger +logger = logging.getLogger(__name__) + +pipelines: list[type[Pipeline]] = [ + ZeroShotHFClassifierPipeline, + ZeroShotObjectDetectorPipeline, + ZeroShotObjectDetectorWithConstantClassifierPipeline, + ZeroShotObjectDetectorWithRandomSpeciesClassifierPipeline, +] +pipeline_choices: dict[str, type[Pipeline]] = {pipeline.config.slug: pipeline for pipeline in pipelines} + + +def process_pipeline_request(data: PipelineRequest) -> PipelineResultsResponse: + """ + Process a pipeline request. + + Args: + data (PipelineRequest): The request data containing pipeline configuration and source images. + + Returns: + PipelineResultsResponse: The response containing the results of the pipeline processing. 
+ """ + logger.info(f"Processing pipeline request for pipeline: {data.pipeline}") + pipeline_slug = data.pipeline + request_config = data.config + + source_images = [SourceImage(**img.model_dump()) for img in data.source_images] + # Open source images once before processing + for img in source_images: + img.open(raise_exception=True) + + detections = create_detections( + source_images=source_images, + detection_requests=data.detections, + ) + + try: + Pipeline = pipeline_choices[pipeline_slug] + except KeyError: + raise ValueError(f"Invalid pipeline choice: {pipeline_slug}") + + pipeline_request_config = PipelineRequestConfigParameters(**dict(request_config)) if request_config else {} + try: + pipeline = Pipeline( + source_images=source_images, + request_config=pipeline_request_config, + existing_detections=detections, + ) + pipeline.compile() + except Exception as e: + logger.error(f"Error compiling pipeline: {e}") + raise Exception(f"Error compiling pipeline: {e}") + + try: + response = pipeline.run() + except Exception as e: + logger.error(f"Error running pipeline: {e}") + raise Exception(f"Error running pipeline: {e}") + + return response + + +# ----------- +# Helper functions +# ----------- + + +def create_detections( + source_images: list[SourceImage], + detection_requests: list[DetectionRequest] | None, +): + if not detection_requests: + return [] + + # Group detection requests by source image id + source_image_map = {img.id: img for img in source_images} + grouped_detection_requests = {} + for request in detection_requests: + if request.source_image.id not in grouped_detection_requests: + grouped_detection_requests[request.source_image.id] = [] + grouped_detection_requests[request.source_image.id].append(request) + + # Process each source image and its detection requests + detections = [] + for source_image_id, requests in grouped_detection_requests.items(): + if source_image_id not in source_image_map: + raise ValueError( + f"A detection request for source image {source_image_id} was received, " + "but no source image with that ID was provided." + ) + + logger.info(f"Processing existing detections for source image {source_image_id}.") + + for request in requests: + source_image = source_image_map[source_image_id] + cropped_image_id = ( + f"{source_image.id}-crop-{request.bbox.x1}-{request.bbox.y1}-{request.bbox.x2}-{request.bbox.y2}" + ) + if not request.crop_image_url: + logger.info("Detection request does not have a crop_image_url, crop the original source image.") + assert source_image._pil is not None, "Source image must be opened before cropping." 
+ cropped_image_pil = source_image._pil.crop( + (request.bbox.x1, request.bbox.y1, request.bbox.x2, request.bbox.y2) + ) + else: + try: + logger.info(f"Opening existing cropped image from {request.crop_image_url}.") + if is_url(request.crop_image_url): + cropped_image = SourceImage( + id=cropped_image_id, + url=request.crop_image_url, + ) + elif is_base64(request.crop_image_url): + logger.info("Decoding base64 cropped image.") + cropped_image = SourceImage( + id=cropped_image_id, + b64=request.crop_image_url, + ) + else: + # Must be a filepath + cropped_image = SourceImage( + id=cropped_image_id, + filepath=request.crop_image_url, + ) + cropped_image.open(raise_exception=True) + cropped_image_pil = cropped_image._pil + except Exception as e: + logger.warning(f"Error opening cropped image: {e}") + logger.info(f"Falling back to cropping the original source image {source_image_id}.") + assert source_image._pil is not None, "Source image must be opened before cropping." + cropped_image_pil = source_image._pil.crop( + (request.bbox.x1, request.bbox.y1, request.bbox.x2, request.bbox.y2) + ) + + # Create a Detection object + det = Detection( + source_image=SourceImage( + id=source_image.id, + url=source_image.url, + ), + bbox=request.bbox, + id=cropped_image_id, + url=request.crop_image_url or source_image.url, + algorithm=request.algorithm, + ) + # Set the _pil attribute to the cropped image + det._pil = cropped_image_pil + detections.append(det) + logger.info(f"Created detection {det.id} for source image {source_image_id}.") + + return detections diff --git a/processing_services/example/celery_worker/__init__.py b/processing_services/example/celery_worker/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/processing_services/example/celery_worker/get_queues.py b/processing_services/example/celery_worker/get_queues.py new file mode 100644 index 000000000..50ed498f3 --- /dev/null +++ b/processing_services/example/celery_worker/get_queues.py @@ -0,0 +1,9 @@ +from typing import get_args + +from api.schemas import PipelineChoice + +if __name__ == "__main__": + pipeline_names = get_args(PipelineChoice) + queue_names = [f"ml-pipeline-{name}" for name in pipeline_names] + queues = ",".join(queue_names) + print(queues) diff --git a/processing_services/example/celery_worker/start_celery.sh b/processing_services/example/celery_worker/start_celery.sh new file mode 100644 index 000000000..513904ea0 --- /dev/null +++ b/processing_services/example/celery_worker/start_celery.sh @@ -0,0 +1,7 @@ +#!/bin/bash +set -e + +QUEUES=$(python -m celery_worker.get_queues) + +echo "Starting Celery with queues: $QUEUES" +celery -A celery_worker.worker worker --queues="$QUEUES" --loglevel=info --pool=solo # --concurrency=1 diff --git a/processing_services/example/celery_worker/worker.py b/processing_services/example/celery_worker/worker.py new file mode 100644 index 000000000..905aa5fcc --- /dev/null +++ b/processing_services/example/celery_worker/worker.py @@ -0,0 +1,32 @@ +import os +from typing import get_args + +from api.processing import process_pipeline_request as process +from api.schemas import PipelineChoice, PipelineRequest, PipelineResultsResponse +from celery import Celery +from kombu import Queue + +celery_app = Celery( + "example_worker", + broker=os.environ.get("CELERY_BROKER_URL", "amqp://user:password@rabbitmq:5672//"), + backend="redis://redis:6379/0", +) + +PIPELINES: list[PipelineChoice] = list(get_args(PipelineChoice)) +QUEUE_NAMES = [f"ml-pipeline-{name}" for name in PIPELINES] + +celery_app.conf.task_queues = 
[Queue(name=queue_name) for queue_name in QUEUE_NAMES] + +celery_app.conf.update(task_default_exchange="pipeline", task_default_exchange_type="direct") + + +@celery_app.task(name="process_pipeline_request", soft_time_limit=60 * 4, time_limit=60 * 5) +def process_pipeline_request(pipeline_request: dict, project_id: int) -> dict: + print(f"Running pipeline on: {pipeline_request}") + request_data = PipelineRequest(**pipeline_request) + resp: PipelineResultsResponse = process(request_data) + return resp.dict() + + +# Don't really need this? unless we auto-discover tasks if apps use `@celery_app.task` and define __init__.py +celery_app.autodiscover_tasks() diff --git a/processing_services/example/requirements.txt b/processing_services/example/requirements.txt index eccbee47a..d7a318cc3 100644 --- a/processing_services/example/requirements.txt +++ b/processing_services/example/requirements.txt @@ -7,3 +7,5 @@ transformers==4.50.3 torch==2.6.0 torchvision==0.21.0 scipy==1.16.0 +celery==5.4.0 +redis==5.2.1 diff --git a/processing_services/minimal/Dockerfile b/processing_services/minimal/Dockerfile index 0686b4471..7128404d7 100644 --- a/processing_services/minimal/Dockerfile +++ b/processing_services/minimal/Dockerfile @@ -6,4 +6,6 @@ COPY . /app RUN pip install -r ./requirements.txt +RUN chmod +x ./celery_worker/start_celery.sh + CMD ["python", "/app/main.py"] diff --git a/processing_services/minimal/api/api.py b/processing_services/minimal/api/api.py index 400156894..617e55dfc 100644 --- a/processing_services/minimal/api/api.py +++ b/processing_services/minimal/api/api.py @@ -3,22 +3,11 @@ """ import logging -import time import fastapi -from .pipelines import ConstantPipeline, Pipeline, RandomDetectionRandomSpeciesPipeline -from .schemas import ( - AlgorithmConfigResponse, - Detection, - DetectionRequest, - PipelineRequest, - PipelineResultsResponse, - ProcessingServiceInfoResponse, - SourceImage, - SourceImageResponse, -) -from .utils import is_base64, is_url +from .processing import pipeline_choices, pipelines, process_pipeline_request +from .schemas import PipelineRequest, PipelineResultsResponse, ProcessingServiceInfoResponse # Configure root logger logging.basicConfig( @@ -30,12 +19,9 @@ app = fastapi.FastAPI() - -pipelines: list[type[Pipeline]] = [ConstantPipeline, RandomDetectionRandomSpeciesPipeline] -pipeline_choices: dict[str, type[Pipeline]] = {pipeline.config.slug: pipeline for pipeline in pipelines} -algorithm_choices: dict[str, AlgorithmConfigResponse] = { - algorithm.key: algorithm for pipeline in pipelines for algorithm in pipeline.config.algorithms -} +# ----------- +# API endpoints +# ----------- @app.get("/") @@ -47,12 +33,8 @@ async def root(): async def info() -> ProcessingServiceInfoResponse: info = ProcessingServiceInfoResponse( name="ML Backend Template", - description=( - "A template for an inference API that allows the user to run different sequences of machine learning " - "models and processing methods on images for the Antenna platform." 
- ), + description=("A lightweight template for running custom models locally."), pipelines=[pipeline.config for pipeline in pipelines], - # algorithms=list(algorithm_choices.values()), ) return info @@ -81,138 +63,13 @@ async def readyz(): @app.post("/process", tags=["services"]) async def process(data: PipelineRequest) -> PipelineResultsResponse: - pipeline_slug = data.pipeline - - source_images = [SourceImage(**img.model_dump()) for img in data.source_images] - # Open source images once before processing - for img in source_images: - img.open(raise_exception=True) - source_image_results = [SourceImageResponse(**image.model_dump()) for image in data.source_images] - - detections = create_detections( - source_images=source_images, - detection_requests=data.detections, - ) - - start_time = time.time() - try: - Pipeline = pipeline_choices[pipeline_slug] - except KeyError: - raise fastapi.HTTPException(status_code=422, detail=f"Invalid pipeline choice: {pipeline_slug}") - - try: - pipeline = Pipeline( - source_images=source_images, - existing_detections=detections, - ) - results = pipeline.run() + resp: PipelineResultsResponse = process_pipeline_request(data) except Exception as e: - logger.error(f"Error running pipeline: {e}") - raise fastapi.HTTPException(status_code=422, detail=f"{e}") - - end_time = time.time() - seconds_elapsed = float(end_time - start_time) - - response = PipelineResultsResponse( - pipeline=pipeline_slug, - # algorithms={algorithm.key: algorithm for algorithm in pipeline.config.algorithms}, - source_images=source_image_results, - detections=results, - total_time=seconds_elapsed, - ) - return response - - -# ----------- -# Helper functions -# ----------- - + logger.error(f"Error processing pipeline request: {e}") + raise fastapi.HTTPException(status_code=422, detail=str(e)) -def create_detections( - source_images: list[SourceImage], - detection_requests: list[DetectionRequest] | None, -): - if not detection_requests: - return [] - - # Group detection requests by source image id - source_image_map = {img.id: img for img in source_images} - grouped_detection_requests = {} - for request in detection_requests: - if request.source_image.id not in grouped_detection_requests: - grouped_detection_requests[request.source_image.id] = [] - grouped_detection_requests[request.source_image.id].append(request) - - # Process each source image and its detection requests - detections = [] - for source_image_id, requests in grouped_detection_requests.items(): - if source_image_id not in source_image_map: - raise ValueError( - f"A detection request for source image {source_image_id} was received, " - "but no source image with that ID was provided." - ) - - logger.info(f"Processing existing detections for source image {source_image_id}.") - - for request in requests: - source_image = source_image_map[source_image_id] - cropped_image_id = ( - f"{source_image.id}-crop-{request.bbox.x1}-{request.bbox.y1}-{request.bbox.x2}-{request.bbox.y2}" - ) - if not request.crop_image_url: - logger.info("Detection request does not have a crop_image_url, crop the original source image.") - assert source_image._pil is not None, "Source image must be opened before cropping." 
- cropped_image_pil = source_image._pil.crop( - (request.bbox.x1, request.bbox.y1, request.bbox.x2, request.bbox.y2) - ) - else: - try: - logger.info(f"Opening existing cropped image from {request.crop_image_url}.") - if is_url(request.crop_image_url): - cropped_image = SourceImage( - id=cropped_image_id, - url=request.crop_image_url, - ) - elif is_base64(request.crop_image_url): - logger.info("Decoding base64 cropped image.") - cropped_image = SourceImage( - id=cropped_image_id, - b64=request.crop_image_url, - ) - else: - # Must be a filepath - cropped_image = SourceImage( - id=cropped_image_id, - filepath=request.crop_image_url, - ) - cropped_image.open(raise_exception=True) - cropped_image_pil = cropped_image._pil - except Exception as e: - logger.warning(f"Error opening cropped image: {e}") - logger.info(f"Falling back to cropping the original source image {source_image_id}.") - assert source_image._pil is not None, "Source image must be opened before cropping." - cropped_image_pil = source_image._pil.crop( - (request.bbox.x1, request.bbox.y1, request.bbox.x2, request.bbox.y2) - ) - - # Create a Detection object - det = Detection( - source_image=SourceImage( - id=source_image.id, - url=source_image.url, - ), - bbox=request.bbox, - id=cropped_image_id, - url=request.crop_image_url or source_image.url, - algorithm=request.algorithm, - ) - # Set the _pil attribute to the cropped image - det._pil = cropped_image_pil - detections.append(det) - logger.info(f"Created detection {det.id} for source image {source_image_id}.") - - return detections + return resp if __name__ == "__main__": diff --git a/processing_services/minimal/api/processing.py b/processing_services/minimal/api/processing.py new file mode 100644 index 000000000..2156c4a61 --- /dev/null +++ b/processing_services/minimal/api/processing.py @@ -0,0 +1,164 @@ +import logging +import time + +from .pipelines import ConstantPipeline, Pipeline, RandomDetectionRandomSpeciesPipeline +from .schemas import ( + Detection, + DetectionRequest, + PipelineRequest, + PipelineResultsResponse, + SourceImage, + SourceImageResponse, +) +from .utils import is_base64, is_url + +# Get the root logger +logger = logging.getLogger(__name__) + +pipelines: list[type[Pipeline]] = [ConstantPipeline, RandomDetectionRandomSpeciesPipeline] +pipeline_choices: dict[str, type[Pipeline]] = {pipeline.config.slug: pipeline for pipeline in pipelines} + + +def process_pipeline_request(data: PipelineRequest) -> PipelineResultsResponse: + """ + Process a pipeline request. + + Args: + data (PipelineRequest): The request data containing pipeline configuration and source images. + + Returns: + PipelineResultsResponse: The response containing the results of the pipeline processing. 
+ """ + logger.info(f"Processing pipeline request for pipeline: {data.pipeline}") + pipeline_slug = data.pipeline + + source_images = [SourceImage(**img.model_dump()) for img in data.source_images] + # Open source images once before processing + for img in source_images: + img.open(raise_exception=True) + source_image_results = [SourceImageResponse(**image.model_dump()) for image in data.source_images] + + detections = create_detections( + source_images=source_images, + detection_requests=data.detections, + ) + + start_time = time.time() + + try: + Pipeline = pipeline_choices[pipeline_slug] + except KeyError: + raise ValueError(f"Invalid pipeline choice: {pipeline_slug}") + + try: + pipeline = Pipeline( + source_images=source_images, + existing_detections=detections, + ) + results = pipeline.run() + except Exception as e: + logger.error(f"Error running pipeline: {e}") + raise Exception(f"Error running pipeline: {e}") + + end_time = time.time() + seconds_elapsed = float(end_time - start_time) + + response = PipelineResultsResponse( + pipeline=pipeline_slug, + algorithms={algorithm.key: algorithm for algorithm in pipeline.config.algorithms}, + source_images=source_image_results, + detections=results, + total_time=seconds_elapsed, + ) + return response + + +# ----------- +# Helper functions +# ----------- + + +def create_detections( + source_images: list[SourceImage], + detection_requests: list[DetectionRequest] | None, +): + if not detection_requests: + return [] + + # Group detection requests by source image id + source_image_map = {img.id: img for img in source_images} + grouped_detection_requests = {} + for request in detection_requests: + if request.source_image.id not in grouped_detection_requests: + grouped_detection_requests[request.source_image.id] = [] + grouped_detection_requests[request.source_image.id].append(request) + + # Process each source image and its detection requests + detections = [] + for source_image_id, requests in grouped_detection_requests.items(): + if source_image_id not in source_image_map: + raise ValueError( + f"A detection request for source image {source_image_id} was received, " + "but no source image with that ID was provided." + ) + + logger.info(f"Processing existing detections for source image {source_image_id}.") + + for request in requests: + source_image = source_image_map[source_image_id] + cropped_image_id = ( + f"{source_image.id}-crop-{request.bbox.x1}-{request.bbox.y1}-{request.bbox.x2}-{request.bbox.y2}" + ) + if not request.crop_image_url: + logger.info("Detection request does not have a crop_image_url, crop the original source image.") + assert source_image._pil is not None, "Source image must be opened before cropping." 
+ cropped_image_pil = source_image._pil.crop( + (request.bbox.x1, request.bbox.y1, request.bbox.x2, request.bbox.y2) + ) + else: + try: + logger.info(f"Opening existing cropped image from {request.crop_image_url}.") + if is_url(request.crop_image_url): + cropped_image = SourceImage( + id=cropped_image_id, + url=request.crop_image_url, + ) + elif is_base64(request.crop_image_url): + logger.info("Decoding base64 cropped image.") + cropped_image = SourceImage( + id=cropped_image_id, + b64=request.crop_image_url, + ) + else: + # Must be a filepath + cropped_image = SourceImage( + id=cropped_image_id, + filepath=request.crop_image_url, + ) + cropped_image.open(raise_exception=True) + cropped_image_pil = cropped_image._pil + except Exception as e: + logger.warning(f"Error opening cropped image: {e}") + logger.info(f"Falling back to cropping the original source image {source_image_id}.") + assert source_image._pil is not None, "Source image must be opened before cropping." + cropped_image_pil = source_image._pil.crop( + (request.bbox.x1, request.bbox.y1, request.bbox.x2, request.bbox.y2) + ) + + # Create a Detection object + det = Detection( + source_image=SourceImage( + id=source_image.id, + url=source_image.url, + ), + bbox=request.bbox, + id=cropped_image_id, + url=request.crop_image_url or source_image.url, + algorithm=request.algorithm, + ) + # Set the _pil attribute to the cropped image + det._pil = cropped_image_pil + detections.append(det) + logger.info(f"Created detection {det.id} for source image {source_image_id}.") + + return detections diff --git a/processing_services/minimal/api/schemas.py b/processing_services/minimal/api/schemas.py index b0febba1b..cadb68aa5 100644 --- a/processing_services/minimal/api/schemas.py +++ b/processing_services/minimal/api/schemas.py @@ -203,7 +203,7 @@ class Config: extra = "ignore" -PipelineChoice = typing.Literal["random", "constant", "random-detection-random-species"] +PipelineChoice = typing.Literal["constant", "random-detection-random-species"] class PipelineRequest(pydantic.BaseModel): @@ -216,7 +216,7 @@ class PipelineRequest(pydantic.BaseModel): class Config: json_schema_extra = { "example": { - "pipeline": "random", + "pipeline": "constant", "source_images": [ { "id": "123", @@ -241,6 +241,7 @@ class PipelineResultsResponse(pydantic.BaseModel): ) source_images: list[SourceImageResponse] detections: list[DetectionResponse] + errors: str | None = None class PipelineStageParam(pydantic.BaseModel): diff --git a/processing_services/minimal/celery_worker/__init__.py b/processing_services/minimal/celery_worker/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/processing_services/minimal/celery_worker/get_queues.py b/processing_services/minimal/celery_worker/get_queues.py new file mode 100644 index 000000000..50ed498f3 --- /dev/null +++ b/processing_services/minimal/celery_worker/get_queues.py @@ -0,0 +1,9 @@ +from typing import get_args + +from api.schemas import PipelineChoice + +if __name__ == "__main__": + pipeline_names = get_args(PipelineChoice) + queue_names = [f"ml-pipeline-{name}" for name in pipeline_names] + queues = ",".join(queue_names) + print(queues) diff --git a/processing_services/minimal/celery_worker/start_celery.sh b/processing_services/minimal/celery_worker/start_celery.sh new file mode 100644 index 000000000..ebe662bba --- /dev/null +++ b/processing_services/minimal/celery_worker/start_celery.sh @@ -0,0 +1,7 @@ +#!/bin/bash +set -e + +QUEUES=$(python -m celery_worker.get_queues) + +echo "Starting Celery 
with queues: $QUEUES" +celery -A celery_worker.worker worker --queues="$QUEUES" --loglevel=info diff --git a/processing_services/minimal/celery_worker/worker.py b/processing_services/minimal/celery_worker/worker.py new file mode 100644 index 000000000..951918353 --- /dev/null +++ b/processing_services/minimal/celery_worker/worker.py @@ -0,0 +1,32 @@ +import os +from typing import get_args + +from api.processing import process_pipeline_request as process +from api.schemas import PipelineChoice, PipelineRequest, PipelineResultsResponse +from celery import Celery +from kombu import Queue + +celery_app = Celery( + "minimal_worker", + broker=os.environ.get("CELERY_BROKER_URL", "amqp://user:password@rabbitmq:5672//"), + backend="redis://redis:6379/0", +) + +PIPELINES: list[PipelineChoice] = list(get_args(PipelineChoice)) +QUEUE_NAMES = [f"ml-pipeline-{name}" for name in PIPELINES] + +celery_app.conf.task_queues = [Queue(name=queue_name) for queue_name in QUEUE_NAMES] + +celery_app.conf.update(task_default_exchange="pipeline", task_default_exchange_type="direct") + + +@celery_app.task(name="process_pipeline_request", soft_time_limit=60 * 4, time_limit=60 * 5) +def process_pipeline_request(pipeline_request: dict, project_id: int) -> dict: + print(f"Running pipeline on: {pipeline_request}") + request_data = PipelineRequest(**pipeline_request) + resp: PipelineResultsResponse = process(request_data) + return resp.dict() + + +# Don't really need this? unless we auto-discover tasks if apps use `@celery_app.task` and define __init__.py +celery_app.autodiscover_tasks() diff --git a/processing_services/minimal/requirements.txt b/processing_services/minimal/requirements.txt index 6494fa201..4d4a967b1 100644 --- a/processing_services/minimal/requirements.txt +++ b/processing_services/minimal/requirements.txt @@ -3,3 +3,5 @@ uvicorn==0.35.0 pydantic==2.11.7 Pillow==11.3.0 requests==2.32.4 +celery==5.4.0 +redis==5.2.1
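
For reference, a minimal sketch of how a Celery client (for example, the Antenna backend) might enqueue a job on one of the new per-pipeline queues. The task name, queue naming scheme, and argument order come from `celery_worker/worker.py` above; the client app name, project id, and source image URL are hypothetical placeholders, and the broker/backend URLs assume the values used in the compose files.

```python
import os

from celery import Celery

# Client-side Celery app (name is a placeholder); broker/backend fall back to the
# values used in the docker-compose files in this PR.
client = Celery(
    "antenna_client",
    broker=os.environ.get("CELERY_BROKER_URL", "amqp://rabbituser:rabbitpass@rabbitmq:5672//"),
    backend="redis://redis:6379/0",
)

pipeline_slug = "constant"  # any slug from PipelineChoice in api/schemas.py
pipeline_request = {
    "pipeline": pipeline_slug,
    "source_images": [
        # Hypothetical image URL, for illustration only.
        {"id": "123", "url": "https://example.com/trap-image.jpg"},
    ],
}

# Route to the per-pipeline queue the worker listens on; task name and argument
# order (pipeline_request dict, project_id) match celery_worker/worker.py.
result = client.send_task(
    "process_pipeline_request",
    args=[pipeline_request, 1],
    queue=f"ml-pipeline-{pipeline_slug}",
)
print(result.get(timeout=300))  # PipelineResultsResponse serialized as a dict
```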
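
The same payload can still be posted synchronously to the FastAPI `/process` endpoint described in `processing_services/README.md` — a rough sketch, assuming the stack is running and the request is made from a container attached to `antenna_network`:

```python
import requests

# pipeline_request is the same dict as in the Celery sketch above; the URL comes
# from the README and only resolves on the antenna_network.
response = requests.post(
    "http://ml_backend_example:2000/process",
    json=pipeline_request,
    timeout=300,
)
response.raise_for_status()
results = response.json()
print(results["pipeline"], len(results["detections"]))
```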