-
Notifications
You must be signed in to change notification settings - Fork 12
Support cancelling ML async jobs #1144
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 19 commits
b60eab0
644927f
218f7aa
90da389
8618d3c
bd1be5f
b102ae1
b43b615
9827ed2
bc908aa
4c3802a
b717e80
02085dc
fa2964f
bcf6bce
883c4f8
0ab3f29
4917775
b43db16
9ffe966
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,7 +16,7 @@ | |
| from config import celery_app | ||
|
|
||
| if TYPE_CHECKING: | ||
| from ami.jobs.models import JobState | ||
| from ami.jobs.models import Job, JobState | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
| # Minimum success rate. Jobs with fewer than this fraction of images | ||
|
|
@@ -94,6 +94,13 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub | |
| ) | ||
| raise self.retry(countdown=5, max_retries=10) | ||
|
|
||
| if progress_info.unknown: | ||
| logger.warning( | ||
| f"Progress info is unknown for job {job_id} when processing results. Job may be cancelled." | ||
| f"Or this could be a transient Redis error and the NATS task will be retried." | ||
carlosgjs marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| ) | ||
| return | ||
|
Comment on lines
+97
to
+102
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Unknown-progress guard is missing for the results stage.
The new guard at line 97 covers the process stage, but the second (results-stage) `update_state` call has no equivalent guard.
🐛 Proposed fix — mirror the guard after the results-stage `update_state`:
if not progress_info:
logger.warning(
f"Another task is already processing results for job {job_id}. "
f"Retrying task {self.request.id} in 5 seconds..."
)
raise self.retry(countdown=5, max_retries=10)
+
+ if progress_info.unknown:
+ logger.warning(
+ f"Progress info is unknown for job {job_id} when processing results (results stage). "
+ f" Job may be cancelled. Or this could be a transient Redis error and the NATS task will be retried."
+ )
+ return
+
# update complete state based on latest progress info after saving results
Apply this diff immediately after line 168 (the existing `if not progress_info:` retry block).
🤖 Prompt for AI Agents |
||
|
|
||
| try: | ||
| complete_state = JobState.SUCCESS | ||
| if progress_info.total > 0 and (progress_info.failed / progress_info.total) > FAILURE_THRESHOLD: | ||
|
|
@@ -272,10 +279,10 @@ def _update_job_progress( | |
| # Clean up async resources for completed jobs that use NATS/Redis | ||
| if job.progress.is_complete(): | ||
| job = Job.objects.get(pk=job_id) # Re-fetch outside transaction | ||
| _cleanup_job_if_needed(job) | ||
| cleanup_async_job_if_needed(job) | ||
|
|
||
|
|
||
| def _cleanup_job_if_needed(job) -> None: | ||
| def cleanup_async_job_if_needed(job: "Job") -> None: | ||
| """ | ||
| Clean up async resources (NATS/Redis) if this job uses them. | ||
|
|
||
|
|
@@ -330,7 +337,7 @@ def update_job_status(sender, task_id, task, state: str, retval=None, **kwargs): | |
|
|
||
| # Clean up async resources for revoked jobs | ||
| if state == JobState.REVOKED: | ||
| _cleanup_job_if_needed(job) | ||
| cleanup_async_job_if_needed(job) | ||
|
|
||
|
|
||
| @task_failure.connect(sender=run_job, retry=False) | ||
|
|
@@ -345,7 +352,7 @@ def update_job_failure(sender, task_id, exception, *args, **kwargs): | |
| job.save() | ||
|
|
||
| # Clean up async resources for failed jobs | ||
| _cleanup_job_if_needed(job) | ||
| cleanup_async_job_if_needed(job) | ||
|
|
||
|
|
||
| def log_time(start: float = 0, msg: str | None = None) -> tuple[float, Callable]: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -95,21 +95,31 @@ def _get_consumer_name(self, job_id: int) -> str: | |
| """Get consumer name from job_id.""" | ||
| return f"job-{job_id}-consumer" | ||
|
|
||
| async def _ensure_stream(self, job_id: int): | ||
| """Ensure stream exists for the given job.""" | ||
| async def _stream_exists(self, job_id: int) -> bool: | ||
| """Check if stream exists for the given job.""" | ||
| if self.js is None: | ||
| raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") | ||
|
|
||
| stream_name = self._get_stream_name(job_id) | ||
| subject = self._get_subject(job_id) | ||
|
|
||
| try: | ||
| await asyncio.wait_for(self.js.stream_info(stream_name), timeout=NATS_JETSTREAM_TIMEOUT) | ||
| logger.debug(f"Stream {stream_name} already exists") | ||
| return True | ||
| except asyncio.TimeoutError: | ||
| raise # NATS unreachable — let caller handle it rather than creating a stream blindly | ||
| except Exception as e: | ||
| logger.warning(f"Stream {stream_name} does not exist: {e}") | ||
| except Exception: | ||
| return False | ||
|
Comment on lines
+98
to
+111
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Check what exception nats.py raises for missing streams
cd /tmp && pip install nats-py >/dev/null 2>&1
# Try to find the exception types in nats-py
python3 << 'EOF'
try:
# Try different import paths
try:
from nats.js.errors import NotFoundError
print("Found: nats.js.errors.NotFoundError")
except ImportError:
pass
try:
from nats.errors import NotFoundError
print("Found: nats.errors.NotFoundError")
except ImportError:
pass
# List all exceptions in nats module
import nats
import inspect
# Check nats.js module
try:
import nats.js
print("\n=== nats.js exceptions ===")
for name in dir(nats.js):
obj = getattr(nats.js, name)
if inspect.isclass(obj) and issubclass(obj, BaseException):
print(f" {name}")
except Exception as e:
print(f"Error inspecting nats.js: {e}")
# Check for errors submodule
try:
import nats.js.errors as errors
print("\n=== nats.js.errors exceptions ===")
for name in dir(errors):
if not name.startswith('_'):
obj = getattr(errors, name)
if inspect.isclass(obj) and issubclass(obj, BaseException):
print(f" {name}")
except Exception as e:
print(f"Error inspecting nats.js.errors: {e}")
except Exception as e:
print(f"Error: {e}")
import traceback
traceback.print_exc()
EOF
Repository: RolnickLab/antenna
Length of output: 712
🏁 Script executed:
# Also check the actual file to see the context
cat -n ami/ml/orchestration/nats_queue.py | sed -n '95,135p'
Repository: RolnickLab/antenna
Length of output: 2088
🏁 Script executed:
# Check for other usages of stream_info or exception handling in the file
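rg "stream_info|except" ami/ml/orchestration/nats_queue.py -n
Repository: RolnickLab/antenna
Length of output: 665
Narrow exception handling to avoid masking transient NATS errors.
The blanket `except Exception` in `_stream_exists` reports any failure as "stream does not exist", so a transient NATS error is indistinguishable from a missing stream. Catch only the not-found error raised by the NATS client (e.g. `nats.js.errors.NotFoundError`) and let other exceptions propagate.
🧰 Tools
🪛 Ruff (0.15.2)
[warning] 101-101: Avoid specifying long messages outside the exception class (TRY003)
[warning] 107-107: Consider moving this statement to an `else` block (TRY300)
[warning] 110-110: Do not catch blind exception: `Exception` (BLE001)
🤖 Prompt for AI Agents |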
||
|
|
||
| async def _ensure_stream(self, job_id: int): | ||
| """Ensure stream exists for the given job.""" | ||
| if self.js is None: | ||
| raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") | ||
|
|
||
| if not await self._stream_exists(job_id): | ||
| stream_name = self._get_stream_name(job_id) | ||
| subject = self._get_subject(job_id) | ||
|
|
||
| logger.warning(f"Stream {stream_name} does not exist") | ||
| # Stream doesn't exist, create it | ||
| await asyncio.wait_for( | ||
| self.js.add_stream( | ||
|
|
@@ -207,7 +217,10 @@ async def reserve_tasks(self, job_id: int, count: int, timeout: float = 5) -> li | |
| raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") | ||
|
|
||
| try: | ||
| await self._ensure_stream(job_id) | ||
| if not await self._stream_exists(job_id): | ||
| logger.debug(f"Stream for job '{job_id}' does not exist when reserving task") | ||
| return [] | ||
|
|
||
| await self._ensure_consumer(job_id) | ||
|
|
||
| consumer_name = self._get_consumer_name(job_id) | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.