Support cancelling ML async jobs #1144
base: main
Changes from all commits
Commits: b60eab0, 644927f, 218f7aa, 90da389, 8618d3c, bd1be5f, b102ae1, b43b615, 9827ed2, bc908aa, 4c3802a, b717e80, 02085dc, fa2964f, bcf6bce, 883c4f8, 0ab3f29, 4917775, b43db16
File: ami/jobs/tasks.py

```diff
@@ -16,7 +16,7 @@
 from config import celery_app

 if TYPE_CHECKING:
-    from ami.jobs.models import JobState
+    from ami.jobs.models import Job, JobState

 logger = logging.getLogger(__name__)
 # Minimum success rate. Jobs with fewer than this fraction of images
@@ -94,6 +94,13 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub
         )
         raise self.retry(countdown=5, max_retries=10)

+    if progress_info.unknown:
+        logger.warning(
+            f"Progress info is unknown for job {job_id} when processing results. Job may be cancelled."
+            f"Or this could be a transient Redis error and the NATS task will be retried."
+        )
+        return
```
|
||||||
| f"Or this could be a transient Redis error and the NATS task will be retried." | |
| f" Or this could be a transient Redis error and the NATS task will be retried." |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing space between implicitly concatenated f-strings.
Lines 99–100 use adjacent f-strings that Python concatenates without any separator. The result is "...cancelled.Or this could..." — missing a space after the period.
Proposed fix
if progress_info.unknown:
logger.warning(
f"Progress info is unknown for job {job_id} when processing results. Job may be cancelled."
- f"Or this could be a transient Redis error and the NATS task will be retried."
+ f" Or this could be a transient Redis error and the NATS task will be retried."
)
return🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@ami/jobs/tasks.py` around lines 97 - 102, The logger.warning message inside
the progress_info.unknown branch concatenates two adjacent f-strings and drops
the space between sentences; update the warning in the block that checks
progress_info.unknown (where job_id is referenced) so the two sentences are
separated (e.g., add a trailing space to the first f-string, a leading space on
the second, or merge into one f-string) ensuring the logged text reads
"...cancelled. Or this could..." when calling logger.warning.
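The bug is just Python's implicit concatenation of adjacent string literals, which inserts nothing between them. A minimal reproduction (the job id is hard-coded for illustration):

```python
job_id = 42  # illustrative value

# Adjacent (f-)string literals are joined at compile time with no separator.
broken = (
    f"Progress info is unknown for job {job_id} when processing results. Job may be cancelled."
    f"Or this could be a transient Redis error and the NATS task will be retried."
)

# Writing the space explicitly at the start of the second literal fixes it.
fixed = (
    f"Progress info is unknown for job {job_id} when processing results. Job may be cancelled."
    f" Or this could be a transient Redis error and the NATS task will be retried."
)

print("cancelled.Or" in broken)  # True: the two sentences run together
print("cancelled. Or" in fixed)  # True: sentences are properly separated
```

Linters such as flake8's implicit-str-concat plugin and Ruff (rule ISC001 and friends) can flag this class of bug automatically.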
File: ami/ml/orchestration/nats_queue.py

```diff
@@ -95,21 +95,31 @@ def _get_consumer_name(self, job_id: int) -> str:
         """Get consumer name from job_id."""
         return f"job-{job_id}-consumer"

-    async def _ensure_stream(self, job_id: int):
-        """Ensure stream exists for the given job."""
+    async def _stream_exists(self, job_id: int) -> bool:
+        """Check if stream exists for the given job."""
         if self.js is None:
             raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.")

         stream_name = self._get_stream_name(job_id)
         subject = self._get_subject(job_id)

         try:
             await asyncio.wait_for(self.js.stream_info(stream_name), timeout=NATS_JETSTREAM_TIMEOUT)
             logger.debug(f"Stream {stream_name} already exists")
-        except Exception as e:
-            logger.warning(f"Stream {stream_name} does not exist: {e}")
+            return True
+        except asyncio.TimeoutError:
+            raise  # NATS unreachable — let caller handle it rather than creating a stream blindly
+        except Exception:
+            return False
```
Review comment (Contributor, on lines 98-111): Narrow the exception handling to avoid masking transient NATS errors. The catch-all `except Exception: return False` treats every failure as "stream does not exist", so a transient connection or protocol error would be silently misread as a missing stream. Catching the specific not-found error raised by nats-py (`nats.js.errors.NotFoundError`) and letting other exceptions propagate would avoid this.

Ruff (0.15.2) also flags this range:
- line 101: TRY003, avoid specifying long messages outside the exception class
- line 107: TRY300, consider moving this statement to an `else` block
- line 110: BLE001, do not catch blind exception `Exception`
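A sketch of the narrowed handling the reviewer suggests. The exception class here is a stand-in for nats-py's `NotFoundError`, and the fake JetStream client is purely illustrative:

```python
import asyncio


class NotFoundError(Exception):
    """Stand-in for nats-py's stream-not-found error (illustrative only)."""


async def stream_exists(js, stream_name: str, timeout: float = 5.0) -> bool:
    """Return False only for a genuinely missing stream.

    Timeouts and other transport errors propagate to the caller instead of
    being silently converted into "stream does not exist".
    """
    try:
        await asyncio.wait_for(js.stream_info(stream_name), timeout=timeout)
        return True
    except NotFoundError:
        return False
    # asyncio.TimeoutError and any other exception bubble up unhandled


class FakeJetStream:
    """Minimal fake that knows about exactly one stream."""

    async def stream_info(self, name: str) -> dict:
        if name != "job-1-stream":
            raise NotFoundError(name)
        return {"config": {"name": name}}


print(asyncio.run(stream_exists(FakeJetStream(), "job-1-stream")))  # True
print(asyncio.run(stream_exists(FakeJetStream(), "job-2-stream")))  # False
```

With the real client, only the `NotFoundError` import and the `js` object change; the control flow stays the same.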
The same hunk continues with the rebuilt `_ensure_stream`, which now delegates the existence check:

```diff
+
+    async def _ensure_stream(self, job_id: int):
+        """Ensure stream exists for the given job."""
+        if self.js is None:
+            raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.")
+
+        if not await self._stream_exists(job_id):
+            stream_name = self._get_stream_name(job_id)
+            subject = self._get_subject(job_id)
+
+            logger.warning(f"Stream {stream_name} does not exist")
+            # Stream doesn't exist, create it
+            await asyncio.wait_for(
+                self.js.add_stream(
```
```diff
@@ -207,7 +217,10 @@ async def reserve_tasks(self, job_id: int, count: int, timeout: float = 5) -> li
             raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.")

-        try:
-            await self._ensure_stream(job_id)
+        if not await self._stream_exists(job_id):
+            logger.debug(f"Stream for job '{job_id}' does not exist when reserving task")
+            return []

         await self._ensure_consumer(job_id)

         consumer_name = self._get_consumer_name(job_id)
```
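The behavioural change in `reserve_tasks` is that a missing stream now yields an empty result instead of recreating the stream, since a deleted stream is how a cancelled job manifests. A minimal sketch with a fake manager (all names here are illustrative, not the project's real API):

```python
import asyncio


class FakeTaskQueueManager:
    """Illustrative stand-in for the PR's TaskQueueManager (hypothetical)."""

    def __init__(self, existing_streams):
        self.existing_streams = set(existing_streams)
        # Each existing stream holds a couple of queued tasks.
        self.tasks = {s: ["task-a", "task-b"] for s in existing_streams}

    async def _stream_exists(self, job_id: int) -> bool:
        return f"job-{job_id}-stream" in self.existing_streams

    async def reserve_tasks(self, job_id: int, count: int) -> list:
        # A missing stream means the job was cancelled (its stream was
        # deleted) or never started; hand back nothing instead of
        # recreating the stream as the old _ensure_stream path did.
        if not await self._stream_exists(job_id):
            return []
        return self.tasks[f"job-{job_id}-stream"][:count]


async def main():
    mgr = FakeTaskQueueManager(["job-1-stream"])
    print(await mgr.reserve_tasks(1, 1))  # ['task-a']
    print(await mgr.reserve_tasks(2, 1))  # [] (job 2's stream is gone)


asyncio.run(main())
```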
```diff
@@ -902,7 +902,7 @@ def test_small_size_filter_assigns_not_identifiable(self):
         )


-class TestTaskStateManager(TestCase):
+class TestAsyncJobStateManager(TestCase):
     """Test TaskStateManager for job progress tracking."""
```

Review suggestion: the docstring still names the old class; rename it to match:

```diff
-    """Test TaskStateManager for job progress tracking."""
+    """Test AsyncJobStateManager for job progress tracking."""
```
Review comment: The unconditional status change to REVOKED on lines 981-982 could create a race condition with the `task_postrun` signal handler. For sync jobs with a `task_id`, `task.revoke(terminate=True)` triggers the `task_postrun` signal, which calls `update_job_status` and sets the status to REVOKED. However, if the signal handler runs between line 978 and line 981, this code would still overwrite any status the signal handler set. Consider checking the dispatch mode or current status before unconditionally setting REVOKED, or removing the duplicate status update for jobs with a `task_id`.
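One shape the guarded transition could take, sketched with hypothetical names. A real fix against the Django model would also need the check and write to be atomic (for example a filtered `Job.objects.filter(...).update(...)` rather than an in-memory check), but the idea is the same:

```python
from dataclasses import dataclass
from enum import Enum


class JobState(str, Enum):
    # Hypothetical subset of the project's job states
    STARTED = "STARTED"
    SUCCESS = "SUCCESS"
    FAILURE = "FAILURE"
    REVOKED = "REVOKED"


TERMINAL_STATES = {JobState.SUCCESS, JobState.FAILURE, JobState.REVOKED}


@dataclass
class FakeJob:
    """In-memory stand-in for the Job model (illustrative only)."""
    status: JobState


def mark_revoked(job: FakeJob) -> bool:
    """Set REVOKED only if the job is not already in a terminal state.

    Returns True when the transition was applied, so the caller can tell
    whether it deferred to a status the signal handler already wrote.
    """
    if job.status in TERMINAL_STATES:
        return False
    job.status = JobState.REVOKED
    return True


running = FakeJob(JobState.STARTED)
assert mark_revoked(running) and running.status is JobState.REVOKED

finished = FakeJob(JobState.SUCCESS)  # e.g. already set by the task_postrun handler
assert not mark_revoked(finished) and finished.status is JobState.SUCCESS
```

This makes the cancel path idempotent with respect to the signal handler: whichever writer arrives second becomes a no-op instead of clobbering the other's status.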