Commit 3d654e6

Authored by carlosgjs, carlos-irreverentlabs, annavik, Copilot, and mihow
PSv2: Track and display image count progress and state (#1121)
* Update ML job counts in the async case

* Update date picker version and tweak layout logic (#1105)
  - fix: update date picker version and tweak layout logic
  - feat: set start month based on selected date

* fix: Properly handle async job state with celery tasks (#1114)
  - Apply suggestion from @Copilot
  - Delete implemented plan

  Co-authored-by: Carlos Garcia Jurado Suarez <carlos@irreverentlabs.com>
  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* PSv2: Implement queue clean-up upon job completion (#1113)
  - feat: PSv2 - queue/Redis clean-up upon job completion
  - fix: catch specific exception
  - chore: move tests to a subdir

  Co-authored-by: Carlos Garcia Jurado Suarez <carlos@irreverentlabs.com>
  Co-authored-by: Michael Bunsen <notbot@gmail.com>

* fix: PSv2: Workers should not try to fetch tasks from v1 jobs (#1118)

  Introduces the dispatch_mode field on the Job model to track how each job
  dispatches its workload. This allows API clients (including the AMI worker)
  to filter jobs by dispatch mode — for example, fetching only async_api jobs
  so workers don't pull synchronous or internal jobs.

  JobDispatchMode enum (ami/jobs/models.py):
  - internal — work handled entirely within the platform (Celery worker, no
    external calls). Default for all jobs.
  - sync_api — worker calls an external processing service API synchronously
    and waits for each response.
  - async_api — worker publishes items to NATS for external processing
    service workers to pick up independently.

  Database and model changes:
  - Added dispatch_mode CharField with TextChoices, defaulting to internal,
    with the migration in ami/jobs/migrations/0019_job_dispatch_mode.py.
  - ML jobs set dispatch_mode = async_api when the project's
    async_pipeline_workers feature flag is enabled.
  - ML jobs set dispatch_mode = sync_api on the synchronous processing path
    (previously unset).

  API and filtering:
  - dispatch_mode is exposed (read-only) in the job list and detail serializers.
  - Filterable via query parameter: ?dispatch_mode=async_api
  - The /tasks endpoint now returns 400 for non-async_api jobs, since only
    those have NATS tasks to fetch.
  - Architecture doc: docs/claude/job-dispatch-modes.md documents the three
    modes, naming decisions, and the per-job-type mapping.

  Co-authored-by: Carlos Garcia Jurado Suarez <carlos@irreverentlabs.com>
  Co-authored-by: Michael Bunsen <notbot@gmail.com>
  Co-authored-by: Claude <noreply@anthropic.com>

* PSv2 cleanup: use is_complete() and dispatch_mode in job progress handler (#1125)
  - refactor: replace the hardcoded `stage == "results"` check with
    `job.progress.is_complete()`, which verifies ALL stages are done, making
    it work for any job type. Replace the feature-flag check in cleanup with
    `dispatch_mode == ASYNC_API`, which is immutable for the job's lifetime
    and more correct than re-reading a mutable flag that could change between
    job creation and completion.
  - test: update cleanup tests for the is_complete() and dispatch_mode checks.
    Set dispatch_mode=ASYNC_API on test jobs to match the new cleanup guard.
    Complete all stages (collect, process, results) in the completion test,
    since is_complete() correctly requires all stages to be done.

  Co-authored-by: Claude <noreply@anthropic.com>

* Track captures and failures

* Update tests, CR feedback, log error images

* CR feedback; fix type checking

* refactor: rename _get_progress to _commit_update in TaskStateManager.
  Clarify naming to distinguish mutating vs read-only methods:
  - _commit_update(): private, writes mutations to Redis, returns progress
  - get_progress(): public, read-only snapshot (added in #1129)
  - update_state(): public API, acquires lock, calls _commit_update()

* fix: unify FAILURE_THRESHOLD and convert TaskProgress to a dataclass
  - Single FAILURE_THRESHOLD constant in tasks.py, imported by models.py
  - Fix the async path to use `> FAILURE_THRESHOLD` (was `>=`) to match the
    sync path's boundary behavior at exactly 50%
  - Convert TaskProgress from a namedtuple to a dataclass with defaults, so
    new fields don't break existing callers

* refactor: rename TaskProgress to JobStateProgress. Clarify that this
  dataclass tracks job-level progress in Redis, not individual task/image
  progress. Aligns with the naming of JobProgress (the Django/Pydantic model
  equivalent).

* docs: update the NATS todo and planning docs with session learnings. Mark
  connection handling as done (PR #1130); add worktree/remote mapping and
  Docker testing notes for future sessions.

* Rename TaskStateManager to AsyncJobStateManager

* Track results counts in the job itself instead of Redis

* Small simplification

* Reset counts to 0 on reset

* chore: remove local planning docs from the PR branch

* docs: clarify the three-layer job state architecture in docstrings.
  Explain the relationship between AsyncJobStateManager (Redis), JobProgress
  (JSONB), and JobState (enum). Clarify that all counts in JobStateProgress
  refer to source images (captures).

Co-authored-by: Carlos Garcia Jurado Suarez <carlos@irreverentlabs.com>
Co-authored-by: Anna Viklund <annamariaviklund@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Michael Bunsen <notbot@gmail.com>
Co-authored-by: Claude <noreply@anthropic.com>
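The dispatch-mode filtering described in #1118 can be sketched in isolation. This is a minimal sketch, not the real implementation: the actual field is a Django CharField with TextChoices, and the real filtering happens in the API's queryset; here a plain str-backed Enum and an in-memory filter stand in for both. Only the three mode values come from the commit message.

```python
from enum import Enum


class JobDispatchMode(str, Enum):
    """How a job dispatches its workload (values from the commit message)."""

    INTERNAL = "internal"    # handled entirely by the platform's Celery worker
    SYNC_API = "sync_api"    # worker calls an external API and waits per request
    ASYNC_API = "async_api"  # worker publishes items to NATS for external workers


def jobs_for_worker(jobs: list[dict]) -> list[dict]:
    """Mimic ?dispatch_mode=async_api: workers only fetch NATS-backed jobs."""
    return [j for j in jobs if j.get("dispatch_mode") == JobDispatchMode.ASYNC_API.value]


jobs = [
    {"id": 1, "dispatch_mode": "internal"},
    {"id": 2, "dispatch_mode": "async_api"},
    {"id": 3, "dispatch_mode": "sync_api"},
]
print([j["id"] for j in jobs_for_worker(jobs)])  # [2]
```

Filtering on the worker side is what keeps v1 (sync/internal) jobs out of the NATS task-fetch path; the server-side 400 on /tasks for non-async_api jobs is a second guard for the same invariant.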
1 parent e901c8e commit 3d654e6

File tree

8 files changed: +431 −181 lines changed


ami/jobs/models.py

Lines changed: 18 additions & 3 deletions
@@ -109,7 +109,7 @@ def python_slugify(value: str) -> str:
 
 
 class JobProgressSummary(pydantic.BaseModel):
-    """Summary of all stages of a job"""
+    """Top-level status and progress for a job, shown in the UI."""
 
     status: JobState = JobState.CREATED
     progress: float = 0
@@ -132,7 +132,17 @@ class JobProgressStageDetail(ConfigurableStage, JobProgressSummary):
 
 
 class JobProgress(pydantic.BaseModel):
-    """The full progress of a job and its stages."""
+    """
+    The user-facing progress of a job, stored as JSONB on the Job model.
+
+    This is what the UI displays and what external APIs read. Contains named
+    stages ("process", "results") with per-stage params (progress percentage,
+    detections/classifications/captures counts, failed count).
+
+    For async (NATS) jobs, updated by _update_job_progress() in ami/jobs/tasks.py,
+    which copies snapshots from the internal Redis-backed AsyncJobStateManager.
+    For sync jobs, updated directly in MLJob.process_images().
+    """
 
     summary: JobProgressSummary
     stages: list[JobProgressStageDetail]
@@ -222,6 +232,10 @@ def reset(self, status: JobState = JobState.CREATED):
         for stage in self.stages:
             stage.progress = 0
             stage.status = status
+            # Reset numeric param values to 0
+            for param in stage.params:
+                if isinstance(param.value, (int, float)):
+                    param.value = 0
 
     def is_complete(self) -> bool:
         """
@@ -561,7 +575,8 @@ def process_images(cls, job, images):
 
         job.logger.info(f"All tasks completed for job {job.pk}")
 
-        FAILURE_THRESHOLD = 0.5
+        from ami.jobs.tasks import FAILURE_THRESHOLD
+
         if image_count and (percent_successful < FAILURE_THRESHOLD):
             job.progress.update_stage("process", status=JobState.FAILURE)
             job.save()
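The reset() change above zeroes any numeric per-stage param (counts like detections or captures) while leaving non-numeric params untouched. A standalone sketch of that behavior, using hypothetical Param/Stage stand-ins rather than the real Pydantic models:

```python
from dataclasses import dataclass, field


@dataclass
class Param:
    key: str
    value: object = None


@dataclass
class Stage:
    progress: float = 0.0
    status: str = "CREATED"
    params: list[Param] = field(default_factory=list)


def reset_stage(stage: Stage, status: str = "CREATED") -> None:
    """Mirror the loop added to JobProgress.reset(): zero progress and numeric params."""
    stage.progress = 0
    stage.status = status
    # Only numeric values are zeroed; string/None params are left as-is
    for param in stage.params:
        if isinstance(param.value, (int, float)):
            param.value = 0


stage = Stage(progress=0.8, status="STARTED",
              params=[Param("detections", 42), Param("note", "keep me")])
reset_stage(stage)
print(stage.params[0].value, stage.params[1].value)  # 0 keep me
```

Without this, a re-run job would display stale counts from the previous run until new results arrived.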

ami/jobs/tasks.py

Lines changed: 117 additions & 14 deletions
@@ -3,18 +3,25 @@
 import logging
 import time
 from collections.abc import Callable
+from typing import TYPE_CHECKING
 
 from asgiref.sync import async_to_sync
 from celery.signals import task_failure, task_postrun, task_prerun
 from django.db import transaction
 
+from ami.ml.orchestration.async_job_state import AsyncJobStateManager
 from ami.ml.orchestration.nats_queue import TaskQueueManager
-from ami.ml.orchestration.task_state import TaskStateManager
 from ami.ml.schemas import PipelineResultsError, PipelineResultsResponse
 from ami.tasks import default_soft_time_limit, default_time_limit
 from config import celery_app
 
+if TYPE_CHECKING:
+    from ami.jobs.models import JobState
+
 logger = logging.getLogger(__name__)
+# Minimum success rate. Jobs with fewer than this fraction of images
+# processed successfully are marked as failed. Also used in MLJob.process_images().
+FAILURE_THRESHOLD = 0.5
 
 
 @celery_app.task(bind=True, soft_time_limit=default_soft_time_limit, time_limit=default_time_limit)
@@ -59,23 +66,27 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub
         result_data: Dictionary containing the pipeline result
         reply_subject: NATS reply subject for acknowledgment
     """
-    from ami.jobs.models import Job  # avoid circular import
+    from ami.jobs.models import Job, JobState  # avoid circular import
 
     _, t = log_time()
 
     # Validate with Pydantic - check for error response first
+    error_result = None
     if "error" in result_data:
         error_result = PipelineResultsError(**result_data)
         processed_image_ids = {str(error_result.image_id)} if error_result.image_id else set()
-        logger.error(f"Pipeline returned error for job {job_id}, image {error_result.image_id}: {error_result.error}")
+        failed_image_ids = processed_image_ids  # Same as processed for errors
         pipeline_result = None
     else:
         pipeline_result = PipelineResultsResponse(**result_data)
         processed_image_ids = {str(img.id) for img in pipeline_result.source_images}
+        failed_image_ids = set()  # No failures for successful results
 
-    state_manager = TaskStateManager(job_id)
+    state_manager = AsyncJobStateManager(job_id)
 
-    progress_info = state_manager.update_state(processed_image_ids, stage="process", request_id=self.request.id)
+    progress_info = state_manager.update_state(
+        processed_image_ids, stage="process", request_id=self.request.id, failed_image_ids=failed_image_ids
+    )
     if not progress_info:
         logger.warning(
             f"Another task is already processing results for job {job_id}. "
@@ -84,16 +95,31 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub
         raise self.retry(countdown=5, max_retries=10)
 
     try:
-        _update_job_progress(job_id, "process", progress_info.percentage)
+        complete_state = JobState.SUCCESS
+        if progress_info.total > 0 and (progress_info.failed / progress_info.total) > FAILURE_THRESHOLD:
+            complete_state = JobState.FAILURE
+        _update_job_progress(
+            job_id,
+            "process",
+            progress_info.percentage,
+            complete_state=complete_state,
+            processed=progress_info.processed,
+            remaining=progress_info.remaining,
+            failed=progress_info.failed,
+        )
 
         _, t = t(f"TIME: Updated job {job_id} progress in PROCESS stage progress to {progress_info.percentage*100}%")
         job = Job.objects.get(pk=job_id)
         job.logger.info(f"Processing pipeline result for job {job_id}, reply_subject: {reply_subject}")
         job.logger.info(
             f" Job {job_id} progress: {progress_info.processed}/{progress_info.total} images processed "
-            f"({progress_info.percentage*100}%), {progress_info.remaining} remaining, {len(processed_image_ids)} just "
-            "processed"
+            f"({progress_info.percentage*100}%), {progress_info.remaining} remaining, {progress_info.failed} failed, "
+            f"{len(processed_image_ids)} just processed"
         )
+        if error_result:
+            job.logger.error(
+                f"Pipeline returned error for job {job_id}, image {error_result.image_id}: {error_result.error}"
+            )
     except Job.DoesNotExist:
         # don't raise and ack so that we don't retry since the job doesn't exist
         logger.error(f"Job {job_id} not found")
@@ -102,6 +128,7 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub
 
     try:
         # Save to database (this is the slow operation)
+        detections_count, classifications_count, captures_count = 0, 0, 0
         if pipeline_result:
             # should never happen since otherwise we could not be processing results here
             assert job.pipeline is not None, "Job pipeline is None"
@@ -112,18 +139,41 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub
                 f"Saved pipeline results to database with {len(pipeline_result.detections)} detections"
                 f", percentage: {progress_info.percentage*100}%"
             )
+            # Calculate detection and classification counts from this result
+            detections_count = len(pipeline_result.detections)
+            classifications_count = sum(len(detection.classifications) for detection in pipeline_result.detections)
+            captures_count = len(pipeline_result.source_images)
 
         _ack_task_via_nats(reply_subject, job.logger)
         # Update job stage with calculated progress
-        progress_info = state_manager.update_state(processed_image_ids, stage="results", request_id=self.request.id)
+
+        progress_info = state_manager.update_state(
+            processed_image_ids,
+            stage="results",
+            request_id=self.request.id,
+        )
 
         if not progress_info:
             logger.warning(
                 f"Another task is already processing results for job {job_id}. "
                 f"Retrying task {self.request.id} in 5 seconds..."
            )
            raise self.retry(countdown=5, max_retries=10)
-        _update_job_progress(job_id, "results", progress_info.percentage)
+
+        # update complete state based on latest progress info after saving results
+        complete_state = JobState.SUCCESS
+        if progress_info.total > 0 and (progress_info.failed / progress_info.total) > FAILURE_THRESHOLD:
+            complete_state = JobState.FAILURE
+
+        _update_job_progress(
+            job_id,
+            "results",
+            progress_info.percentage,
+            complete_state=complete_state,
+            detections=detections_count,
+            classifications=classifications_count,
+            captures=captures_count,
+        )
 
     except Exception as e:
         job.logger.error(
@@ -149,19 +199,72 @@ async def ack_task():
         # Don't fail the task if ACK fails - data is already saved
 
 
-def _update_job_progress(job_id: int, stage: str, progress_percentage: float) -> None:
+def _get_current_counts_from_job_progress(job, stage: str) -> tuple[int, int, int]:
+    """
+    Get current detections, classifications, and captures counts from job progress.
+
+    Args:
+        job: The Job instance
+        stage: The stage name to read counts from
+
+    Returns:
+        Tuple of (detections, classifications, captures) counts, defaulting to 0 if not found
+    """
+    try:
+        stage_obj = job.progress.get_stage(stage)
+
+        # Initialize defaults
+        detections = 0
+        classifications = 0
+        captures = 0
+
+        # Search through the params list for our count values
+        for param in stage_obj.params:
+            if param.key == "detections":
+                detections = param.value or 0
+            elif param.key == "classifications":
+                classifications = param.value or 0
+            elif param.key == "captures":
+                captures = param.value or 0
+
+        return detections, classifications, captures
+    except (ValueError, AttributeError):
+        # Stage doesn't exist or doesn't have these attributes yet
+        return 0, 0, 0
+
+
+def _update_job_progress(
+    job_id: int, stage: str, progress_percentage: float, complete_state: "JobState", **state_params
+) -> None:
     from ami.jobs.models import Job, JobState  # avoid circular import
 
     with transaction.atomic():
         job = Job.objects.select_for_update().get(pk=job_id)
+
+        # For results stage, accumulate detections/classifications/captures counts
+        if stage == "results":
+            current_detections, current_classifications, current_captures = _get_current_counts_from_job_progress(
+                job, stage
+            )
+
+            # Add new counts to existing counts
+            new_detections = state_params.get("detections", 0)
+            new_classifications = state_params.get("classifications", 0)
+            new_captures = state_params.get("captures", 0)
+
+            state_params["detections"] = current_detections + new_detections
+            state_params["classifications"] = current_classifications + new_classifications
+            state_params["captures"] = current_captures + new_captures
+
         job.progress.update_stage(
             stage,
-            status=JobState.SUCCESS if progress_percentage >= 1.0 else JobState.STARTED,
+            status=complete_state if progress_percentage >= 1.0 else JobState.STARTED,
             progress=progress_percentage,
+            **state_params,
         )
         if job.progress.is_complete():
-            job.status = JobState.SUCCESS
-            job.progress.summary.status = JobState.SUCCESS
+            job.status = complete_state
+            job.progress.summary.status = complete_state
             job.finished_at = datetime.datetime.now()  # Use naive datetime in local time
         job.logger.info(f"Updated job {job_id} progress in stage '{stage}' to {progress_percentage*100}%")
         job.save()
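The success/failure decision repeated in both stages of the diff above hinges on the strict `>` comparison against FAILURE_THRESHOLD: the commit message notes this was changed from `>=` so that a job with exactly 50% failed images still succeeds, matching the sync path's `percent_successful < FAILURE_THRESHOLD` check. A minimal standalone sketch of that boundary behavior (state names only; not the real Django code):

```python
FAILURE_THRESHOLD = 0.5  # fraction of failed images beyond which a job is failed


def complete_state(failed: int, total: int) -> str:
    """Strict '>' means exactly 50% failed is still SUCCESS (matches the sync path)."""
    if total > 0 and (failed / total) > FAILURE_THRESHOLD:
        return "FAILURE"
    return "SUCCESS"


print(complete_state(5, 10))  # SUCCESS (exactly at the 50% boundary)
print(complete_state(6, 10))  # FAILURE
print(complete_state(0, 0))   # SUCCESS (no images; guarded by total > 0)
```

The `total > 0` guard also prevents a ZeroDivisionError when a progress update arrives before any images are registered.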

ami/jobs/test_tasks.py

Lines changed: 7 additions & 7 deletions
@@ -17,7 +17,7 @@
 from ami.jobs.tasks import process_nats_pipeline_result
 from ami.main.models import Detection, Project, SourceImage, SourceImageCollection
 from ami.ml.models import Pipeline
-from ami.ml.orchestration.task_state import TaskStateManager, _lock_key
+from ami.ml.orchestration.async_job_state import AsyncJobStateManager, _lock_key
 from ami.ml.schemas import PipelineResultsError, PipelineResultsResponse, SourceImageResponse
 from ami.users.models import User
 
@@ -64,7 +64,7 @@ def setUp(self):
 
         # Initialize state manager
         self.image_ids = [str(img.pk) for img in self.images]
-        self.state_manager = TaskStateManager(self.job.pk)
+        self.state_manager = AsyncJobStateManager(self.job.pk)
         self.state_manager.initialize_job(self.image_ids)
 
     def tearDown(self):
@@ -90,7 +90,7 @@ def _assert_progress_updated(
         self, job_id: int, expected_processed: int, expected_total: int, stage: str = "process"
     ):
         """Assert TaskStateManager state is correct."""
-        manager = TaskStateManager(job_id)
+        manager = AsyncJobStateManager(job_id)
         progress = manager.get_progress(stage)
         self.assertIsNotNone(progress, f"Progress not found for stage '{stage}'")
         self.assertEqual(progress.processed, expected_processed)
@@ -157,7 +157,7 @@ def test_process_nats_pipeline_result_error_no_image_id(self, mock_manager_class
 
         # Assert: Progress was NOT updated (empty set of processed images)
         # Since no image_id was provided, processed_image_ids = set()
-        manager = TaskStateManager(self.job.pk)
+        manager = AsyncJobStateManager(self.job.pk)
         progress = manager.get_progress("process")
         self.assertEqual(progress.processed, 0)  # No images marked as processed
 
@@ -208,7 +208,7 @@ def test_process_nats_pipeline_result_mixed_results(self, mock_manager_class):
         )
 
         # Assert: All 3 images marked as processed in TaskStateManager
-        manager = TaskStateManager(self.job.pk)
+        manager = AsyncJobStateManager(self.job.pk)
         process_progress = manager.get_progress("process")
         self.assertIsNotNone(process_progress)
         self.assertEqual(process_progress.processed, 3)
@@ -266,7 +266,7 @@ def test_process_nats_pipeline_result_error_concurrent_locking(self, mock_manage
         )
 
         # Assert: Progress was NOT updated (lock not acquired)
-        manager = TaskStateManager(self.job.pk)
+        manager = AsyncJobStateManager(self.job.pk)
         progress = manager.get_progress("process")
         self.assertEqual(progress.processed, 0)
 
@@ -342,7 +342,7 @@ def setUp(self):
         )
 
         # Initialize state manager
-        state_manager = TaskStateManager(self.job.pk)
+        state_manager = AsyncJobStateManager(self.job.pk)
         state_manager.initialize_job([str(self.image.pk)])
 
     def tearDown(self):
