diff --git a/ami/main/api/views.py b/ami/main/api/views.py index 9a2770ac8..1dd493099 100644 --- a/ami/main/api/views.py +++ b/ami/main/api/views.py @@ -13,6 +13,7 @@ from django_filters.rest_framework import DjangoFilterBackend from drf_spectacular.types import OpenApiTypes from drf_spectacular.utils import OpenApiParameter, extend_schema +from pydantic import ValidationError from rest_framework import exceptions as api_exceptions from rest_framework import filters, serializers, status, viewsets from rest_framework.decorators import action @@ -31,6 +32,8 @@ from ami.base.views import ProjectMixin from ami.main.api.schemas import project_id_doc_param from ami.main.api.serializers import TagSerializer +from ami.ml.models.processing_service import ProcessingService +from ami.ml.schemas import AsyncPipelineRegistrationRequest from ami.utils.requests import get_default_classification_threshold from ami.utils.storages import ConnectionTestResult @@ -206,6 +209,65 @@ def charts(self, request, pk=None): project = self.get_object() return Response({"summary_data": project.summary_data()}) + @action(detail=True, methods=["post"], url_path="pipelines") + def pipelines(self, request, pk=None): + """ + Receive pipeline registrations for a project. This endpoint is called by the + V2 ML processing services to register available pipelines for a project. + + Expected payload: PipelineRegistrationResponse (pydantic schema) containing a + list of PipelineConfigResponse objects under the `pipelines` key. + + Behavior: + - If the project has no associated ProcessingService, create a dummy one and + associate it with the project. + - Call ProcessingService.create_pipelines() with the provided pipeline configs + and limit the operation to this project. + + Returns the PipelineRegistrationResponse returned by create_pipelines(). + """ + # Parse the incoming payload using the pydantic schema so we convert dicts to + # the expected PipelineConfigResponse models + + try: + parsed: AsyncPipelineRegistrationRequest = AsyncPipelineRegistrationRequest.parse_obj(request.data) + except ValidationError as err: + logger.debug(f"Invalid pipeline registration payload: {err}") + return Response({"detail": str(err)}, status=status.HTTP_400_BAD_REQUEST) + + project: Project = self.get_object() + + # TODO: Discuss the right approach for associating pipelines with projects in V2. + processing_service = ProcessingService.objects.filter( + projects=project, name=parsed.processing_service_name + ).first() + + if not processing_service: + # Create a dummy processing service and associate it with the project + processing_service = ProcessingService.objects.create( + name=parsed.processing_service_name, + endpoint_url=None, # TODO: depends on https://github.com/RolnickLab/antenna/pull/1090 + ) + processing_service.projects.add(project) + processing_service.save() + logger.info(f"Created dummy processing service {processing_service} for project {project.pk}") + + pipeline_configs = None + if parsed and parsed.pipeline_response: + pipeline_configs = parsed.pipeline_response.pipelines + + # Call create_pipelines limited to this project + response = processing_service.create_pipelines( + pipeline_configs=pipeline_configs, + projects=Project.objects.filter(pk=project.pk), + ) + + # Save any changes to the processing service + processing_service.save() + + # response is a pydantic model; return its dict representation + return Response(response.dict()) + @extend_schema( parameters=[ OpenApiParameter( diff --git a/ami/ml/schemas.py b/ami/ml/schemas.py index 478b4c8fd..9b2c0dfde 100644 --- a/ami/ml/schemas.py +++ b/ami/ml/schemas.py @@ -307,3 +307,12 @@ class PipelineRegistrationResponse(pydantic.BaseModel): pipelines: list[PipelineConfigResponse] = [] pipelines_created: list[str] = [] algorithms_created: list[str] = [] + + +class AsyncPipelineRegistrationRequest(pydantic.BaseModel): + """ + Request to register pipelines from an async processing service + """ + + processing_service_name: str + pipeline_response: PipelineRegistrationResponse diff --git a/requirements/base.txt b/requirements/base.txt index dd9de69d5..ed40ea5f7 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -52,6 +52,7 @@ django-anymail[sendgrid]==10.0 # https://github.com/anymail/django-anymail Werkzeug[watchdog]==2.3.6 # https://github.com/pallets/werkzeug ipdb==0.13.13 # https://github.com/gotcha/ipdb psycopg[binary]==3.1.9 # https://github.com/psycopg/psycopg +# psycopg==3.1.9 # https://github.com/psycopg/psycopg # the non-binary version is needed for some platforms watchfiles==0.19.0 # https://github.com/samuelcolvin/watchfiles # Testing