Skip to content

Commit d350600

Browse files
committed
Merge #1117: PSv2 processing service name parameter
Resolve conflicts in ami/jobs/views.py and ami/jobs/tests.py: - views.py: keep both dispatch_mode guard and service name logging - tests.py: keep both test classes, fix service name test to set async dispatch_mode
2 parents dec0e19 + 5b53380 commit d350600

File tree

4 files changed

+88
-31
lines changed

4 files changed

+88
-31
lines changed

ami/jobs/schemas.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,10 @@
2020
required=False,
2121
type=int,
2222
)
23+
24+
processing_service_name_param = OpenApiParameter(
25+
name="processing_service_name",
26+
description="Name of the calling processing service",
27+
required=False,
28+
type=str,
29+
)

ami/jobs/tests.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -520,6 +520,54 @@ def test_result_endpoint_validation(self):
520520
self.assertEqual(resp.status_code, 400)
521521
self.assertIn("result", resp.json()[0].lower())
522522

523+
def test_processing_service_name_parameter(self):
524+
"""Test that processing_service_name parameter is accepted on job endpoints."""
525+
self.client.force_authenticate(user=self.user)
526+
test_service_name = "Test Service"
527+
528+
# Test list endpoint
529+
list_url = reverse_with_params(
530+
"api:job-list", params={"project_id": self.project.pk, "processing_service_name": test_service_name}
531+
)
532+
resp = self.client.get(list_url)
533+
self.assertEqual(resp.status_code, 200)
534+
535+
# Test tasks endpoint (requires job with pipeline)
536+
pipeline = self._create_pipeline()
537+
job = self._create_ml_job("Job for service name test", pipeline)
538+
job.dispatch_mode = JobDispatchMode.ASYNC_API
539+
job.save(update_fields=["dispatch_mode"])
540+
541+
tasks_url = reverse_with_params(
542+
"api:job-tasks",
543+
args=[job.pk],
544+
params={"project_id": self.project.pk, "batch": 1, "processing_service_name": test_service_name},
545+
)
546+
resp = self.client.get(tasks_url)
547+
self.assertEqual(resp.status_code, 200)
548+
549+
# Test result endpoint
550+
result_url = reverse_with_params(
551+
"api:job-result",
552+
args=[job.pk],
553+
params={"project_id": self.project.pk, "processing_service_name": test_service_name},
554+
)
555+
result_data = [
556+
{
557+
"reply_subject": "test.reply.1",
558+
"result": {
559+
"pipeline": "test-pipeline",
560+
"algorithms": {},
561+
"total_time": 1.5,
562+
"source_images": [],
563+
"detections": [],
564+
"errors": None,
565+
},
566+
}
567+
]
568+
resp = self.client.post(result_url, result_data, format="json")
569+
self.assertEqual(resp.status_code, 200)
570+
523571

524572
class TestJobDispatchModeFiltering(APITestCase):
525573
"""Test job filtering by dispatch_mode."""

ami/jobs/views.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from ami.base.permissions import ObjectPermission
1717
from ami.base.views import ProjectMixin
18-
from ami.jobs.schemas import batch_param, ids_only_param, incomplete_only_param
18+
from ami.jobs.schemas import batch_param, ids_only_param, incomplete_only_param, processing_service_name_param
1919
from ami.jobs.tasks import process_nats_pipeline_result
2020
from ami.main.api.schemas import project_id_doc_param
2121
from ami.main.api.views import DefaultViewSet
@@ -204,13 +204,16 @@ def get_queryset(self) -> QuerySet:
204204
project_id_doc_param,
205205
ids_only_param,
206206
incomplete_only_param,
207+
processing_service_name_param,
207208
]
208209
)
209210
def list(self, request, *args, **kwargs):
211+
_ = _log_processing_service_name(request, "list requested", logger)
212+
210213
return super().list(request, *args, **kwargs)
211214

212215
@extend_schema(
213-
parameters=[batch_param],
216+
parameters=[batch_param, processing_service_name_param],
214217
responses={200: dict},
215218
)
216219
@action(detail=True, methods=["get"], name="tasks")
@@ -233,6 +236,8 @@ def tasks(self, request, pk=None):
233236
if job.dispatch_mode != JobDispatchMode.ASYNC_API:
234237
raise ValidationError("Only async_api jobs have fetchable tasks")
235238

239+
_ = _log_processing_service_name(request, f"tasks ({batch}) requested for job {job.pk}", job.logger)
240+
236241
# Validate that the job has a pipeline
237242
if not job.pipeline:
238243
raise ValidationError("This job does not have a pipeline configured")
@@ -254,6 +259,9 @@ async def get_tasks():
254259

255260
return Response({"tasks": tasks})
256261

262+
@extend_schema(
263+
parameters=[processing_service_name_param],
264+
)
257265
@action(detail=True, methods=["post"], name="result")
258266
def result(self, request, pk=None):
259267
"""
@@ -266,6 +274,8 @@ def result(self, request, pk=None):
266274

267275
job = self.get_object()
268276

277+
_ = _log_processing_service_name(request, f"result received for job {job.pk}", job.logger)
278+
269279
# Validate request data is a list
270280
if isinstance(request.data, list):
271281
results = request.data
@@ -325,3 +335,24 @@ def result(self, request, pk=None):
325335
},
326336
status=500,
327337
)
338+
339+
340+
def _log_processing_service_name(request, context: str, logger: logging.Logger) -> str | None:
341+
"""
342+
Log the processing_service_name from query parameters.
343+
344+
Args:
345+
request: The HTTP request object
346+
context: A string describing the operation (e.g., "tasks requested", "result received")
347+
logger: A logging.Logger instance to use for logging
348+
Returns:
349+
The processing_service_name if provided, otherwise None
350+
"""
351+
processing_service_name = request.query_params.get("processing_service_name", None)
352+
353+
if processing_service_name:
354+
logger.info(f"Jobs {context} by processing service: {processing_service_name}")
355+
else:
356+
logger.debug(f"Jobs {context} without processing service name")
357+
358+
return processing_service_name

ami/utils/requests.py

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22

33
import requests
44
from django.forms import BooleanField, FloatField
5-
from drf_spectacular.types import OpenApiTypes
6-
from drf_spectacular.utils import OpenApiParameter
75
from requests.adapters import HTTPAdapter
86
from rest_framework.request import Request
97
from urllib3.util import Retry
@@ -144,30 +142,3 @@ def get_default_classification_threshold(project: "Project | None" = None, reque
144142
return project.default_filters_score_threshold
145143
else:
146144
return default_threshold
147-
148-
149-
project_id_doc_param = OpenApiParameter(
150-
name="project_id",
151-
description="Filter by project ID",
152-
required=False,
153-
type=int,
154-
)
155-
156-
ids_only_param = OpenApiParameter(
157-
name="ids_only",
158-
description="Return only job IDs instead of full job objects",
159-
required=False,
160-
type=OpenApiTypes.BOOL,
161-
)
162-
incomplete_only_param = OpenApiParameter(
163-
name="incomplete_only",
164-
description="Filter to only show incomplete jobs (excludes SUCCESS, FAILURE, REVOKED)",
165-
required=False,
166-
type=OpenApiTypes.BOOL,
167-
)
168-
batch_param = OpenApiParameter(
169-
name="batch",
170-
description="Number of tasks to pull in the batch",
171-
required=False,
172-
type=OpenApiTypes.INT,
173-
)

0 commit comments

Comments
 (0)