-
Notifications
You must be signed in to change notification settings - Fork 11
[integration] Enable async and distributed processing for the ML backend #910
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
8aad275
61b45a4
4a03c7e
09d7dfb
996674e
ce973fc
bf7178d
41efa42
78babeb
8d28d01
bb22514
1dbc5f0
fe1a9f4
7747f3a
1978cbe
82ac82d
7d733f9
07d61d9
d129029
76ce2d8
035b952
1230386
45dbacf
7361fb2
3f722c8
075a7ec
85c676d
cbd7ae0
7d15ffb
c2881b4
14396ba
6613366
2cf0c0a
1dbf3b1
c82c076
fb874c4
f2ef5ff
b6ce90f
8fe8b1d
d0f4f26
fc8470d
3d3b820
5c7af56
e7e579e
cb74eac
ffea1aa
6cb852b
d0380b9
2594049
57e6691
f785dda
0a22d53
d139734
a83dd20
0707433
0103b7e
3a3b881
7c86612
c016d47
fa510ed
6da55a9
652f47f
f3b588a
5654ed0
875d3cb
be351f8
e550531
2fa57ef
5c21be6
a1e8fa3
57fba22
f8e374a
cd593bc
bde423a
a1238dc
03390e2
460f27c
6f87ca4
bd86042
b552d1e
41b8c16
caa11db
0fd6369
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,33 @@ | ||
| # Generated by Django 4.2.10 on 2025-11-04 11:44 | ||
|
|
||
| import datetime | ||
| from django.db import migrations, models | ||
|
|
||
|
|
||
| class Migration(migrations.Migration): | ||
| dependencies = [ | ||
| ("jobs", "0022_job_last_checked"), | ||
| ] | ||
|
|
||
| operations = [ | ||
| migrations.AlterField( | ||
| model_name="job", | ||
| name="last_checked", | ||
| field=models.DateTimeField(blank=True, default=datetime.datetime.now, null=True), | ||
| ), | ||
| migrations.AlterField( | ||
| model_name="mltaskrecord", | ||
| name="status", | ||
| field=models.CharField( | ||
| choices=[ | ||
| ("PENDING", "PENDING"), | ||
| ("STARTED", "STARTED"), | ||
| ("SUCCESS", "SUCCESS"), | ||
| ("FAIL", "FAIL"), | ||
| ("REVOKED", "REVOKED"), | ||
| ], | ||
| default="STARTED", | ||
| max_length=255, | ||
| ), | ||
| ), | ||
| ] | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -396,9 +396,9 @@ def check_inprogress_subtasks(cls, job: "Job") -> bool: | |||||||||||||||||||||||||
| inprogress_subtask.task_id = save_results_task.id | ||||||||||||||||||||||||||
| task_id = save_results_task.id | ||||||||||||||||||||||||||
| inprogress_subtask.save() | ||||||||||||||||||||||||||
| job.logger.info(f"Started save results task {inprogress_subtask.task_id}") | ||||||||||||||||||||||||||
| job.logger.debug(f"Started save results task {inprogress_subtask.task_id}") | ||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||
| job.logger.info("A save results task is already in progress, will not start another one yet.") | ||||||||||||||||||||||||||
| job.logger.debug("A save results task is already in progress, will not start another one yet.") | ||||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| task = AsyncResult(task_id) | ||||||||||||||||||||||||||
|
|
@@ -407,12 +407,12 @@ def check_inprogress_subtasks(cls, job: "Job") -> bool: | |||||||||||||||||||||||||
| inprogress_subtask.status = ( | ||||||||||||||||||||||||||
| MLSubtaskState.SUCCESS.name if task.successful() else MLSubtaskState.FAIL.name | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| inprogress_subtask.raw_traceback = task.traceback | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| if task.traceback: | ||||||||||||||||||||||||||
| # TODO: Error logs will have many tracebacks | ||||||||||||||||||||||||||
| # could add some processing to provide a concise error summary | ||||||||||||||||||||||||||
| job.logger.error(f"Subtask {task_name} ({task_id}) failed: {task.traceback}") | ||||||||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This logs the error, but then still tries to parse a successful result. Can you mark the task as failed and then continue to the next one? The status check did its job correctly! But the subtask failed. Can you update the MLTaskRecord to say it failed? |
||||||||||||||||||||||||||
| inprogress_subtask.status = MLSubtaskState.FAIL.name | ||||||||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this get saved?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yup, with Lines 441 to 452 in bd86042
|
||||||||||||||||||||||||||
| inprogress_subtask.raw_traceback = task.traceback | ||||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| results_dict = task.result | ||||||||||||||||||||||||||
| if task_name == MLSubtaskNames.process_pipeline_request.name: | ||||||||||||||||||||||||||
|
|
@@ -505,7 +505,7 @@ def check_inprogress_subtasks(cls, job: "Job") -> bool: | |||||||||||||||||||||||||
| f"{inprogress_subtasks.count()} inprogress subtasks remaining out of {total_subtasks} total subtasks." | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| inprogress_task_ids = [task.task_id for task in inprogress_subtasks] | ||||||||||||||||||||||||||
| job.logger.info(f"Subtask ids: {inprogress_task_ids}") # TODO: remove this? not very useful to the user | ||||||||||||||||||||||||||
| job.logger.debug(f"Subtask ids: {inprogress_task_ids}") | ||||||||||||||||||||||||||
| return False | ||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||
| job.logger.info("No inprogress subtasks left.") | ||||||||||||||||||||||||||
|
|
@@ -999,6 +999,7 @@ class MLSubtaskState(str, OrderedEnum): | |||||||||||||||||||||||||
| STARTED = "STARTED" | ||||||||||||||||||||||||||
| SUCCESS = "SUCCESS" | ||||||||||||||||||||||||||
| FAIL = "FAIL" | ||||||||||||||||||||||||||
| REVOKED = "REVOKED" | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| class MLTaskRecord(BaseModel): | ||||||||||||||||||||||||||
|
|
@@ -1041,6 +1042,17 @@ def clean(self): | |||||||||||||||||||||||||
| if self.status == MLSubtaskState.PENDING.name and self.task_name != MLSubtaskNames.save_results.name: | ||||||||||||||||||||||||||
| raise ValueError(f"{self.task_name} tasks cannot have a PENDING status.") | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def kill_task(self): | ||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
| Kill the celery task associated with this MLTaskRecord. | ||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
| from config.celery_app import app as celery_app | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| if self.task_id: | ||||||||||||||||||||||||||
| celery_app.control.revoke(self.task_id, terminate=True, signal="SIGTERM") | ||||||||||||||||||||||||||
| self.status = MLSubtaskState.REVOKED.name | ||||||||||||||||||||||||||
| self.save(update_fields=["status"]) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| class Job(BaseModel): | ||||||||||||||||||||||||||
| """A job to be run by the scheduler""" | ||||||||||||||||||||||||||
|
|
@@ -1050,7 +1062,7 @@ class Job(BaseModel): | |||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| name = models.CharField(max_length=255) | ||||||||||||||||||||||||||
| queue = models.CharField(max_length=255, default="default") | ||||||||||||||||||||||||||
| last_checked = models.DateTimeField(null=True, blank=True) | ||||||||||||||||||||||||||
| last_checked = models.DateTimeField(null=True, blank=True, default=datetime.datetime.now) | ||||||||||||||||||||||||||
| scheduled_at = models.DateTimeField(null=True, blank=True) | ||||||||||||||||||||||||||
| started_at = models.DateTimeField(null=True, blank=True) | ||||||||||||||||||||||||||
| finished_at = models.DateTimeField(null=True, blank=True) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,4 +1,5 @@ | ||||||||||||||
| # from rich import print | ||||||||||||||
| import datetime | ||||||||||||||
| import logging | ||||||||||||||
| import time | ||||||||||||||
|
|
||||||||||||||
|
|
@@ -406,4 +407,94 @@ def get_ml_job_subtask_details(task_name, job): | |||||||||||||
| self.assertEqual(job.status, JobState.SUCCESS.value) | ||||||||||||||
| self.assertEqual(job.progress.summary.progress, 1) | ||||||||||||||
| self.assertEqual(job.progress.summary.status, JobState.SUCCESS) | ||||||||||||||
| job.save() | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
| class TestStaleMLJob(TransactionTestCase): | ||||||||||||||
| def setUp(self): | ||||||||||||||
| self.project = Project.objects.first() # get the original test project | ||||||||||||||
| assert self.project | ||||||||||||||
| self.source_image_collection = self.project.sourceimage_collections.get(name="Test Source Image Collection") | ||||||||||||||
| self.pipeline = Pipeline.objects.get(slug="constant") | ||||||||||||||
|
|
||||||||||||||
| # remove existing detections from the source image collection | ||||||||||||||
| for image in self.source_image_collection.images.all(): | ||||||||||||||
| image.detections.all().delete() | ||||||||||||||
| image.save() | ||||||||||||||
|
|
||||||||||||||
| def test_kill_dangling_ml_job(self): | ||||||||||||||
| """Test killing a dangling ML job.""" | ||||||||||||||
| from ami.ml.tasks import check_dangling_ml_jobs | ||||||||||||||
| from config import celery_app | ||||||||||||||
|
|
||||||||||||||
| job = Job.objects.create( | ||||||||||||||
| job_type_key=MLJob.key, | ||||||||||||||
| project=self.project, | ||||||||||||||
| name="Test dangling job", | ||||||||||||||
| delay=0, | ||||||||||||||
| pipeline=self.pipeline, | ||||||||||||||
| source_image_collection=self.source_image_collection, | ||||||||||||||
| ) | ||||||||||||||
|
|
||||||||||||||
| job.run() | ||||||||||||||
| connection.commit() | ||||||||||||||
| job.refresh_from_db() | ||||||||||||||
|
|
||||||||||||||
| # Simulate last_checked being older than 5 minutes | ||||||||||||||
| job.last_checked = datetime.datetime.now() - datetime.timedelta(minutes=10) | ||||||||||||||
| job.save(update_fields=["last_checked"]) | ||||||||||||||
|
|
||||||||||||||
| # Run the dangling job checker | ||||||||||||||
| check_dangling_ml_jobs() | ||||||||||||||
|
|
||||||||||||||
| # Refresh job from DB | ||||||||||||||
| job.refresh_from_db() | ||||||||||||||
|
|
||||||||||||||
| # Make sure no tasks are still in progress | ||||||||||||||
| for ml_task_record in job.ml_task_records.all(): | ||||||||||||||
| self.assertEqual(ml_task_record.status, MLSubtaskState.REVOKED.value) | ||||||||||||||
|
|
||||||||||||||
| # Also check celery queue to make sure all tasks have been revoked | ||||||||||||||
| task_id = ml_task_record.task_id | ||||||||||||||
|
|
||||||||||||||
| inspector = celery_app.control.inspect() | ||||||||||||||
| active = inspector.active() or {} | ||||||||||||||
| reserved = inspector.reserved() or {} | ||||||||||||||
|
|
||||||||||||||
| not_running = all( | ||||||||||||||
| task_id not in [t["id"] for w in active.values() for t in w] for w in active.values() | ||||||||||||||
| ) and all(task_id not in [t["id"] for w in reserved.values() for t in w] for w in reserved.values()) | ||||||||||||||
|
Comment on lines
+466
to
+468
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix redundant loop in task presence check. The logic for verifying that Apply this diff to simplify and correct the logic: - not_running = all(
- task_id not in [t["id"] for w in active.values() for t in w] for w in active.values()
- ) and all(task_id not in [t["id"] for w in reserved.values() for t in w] for w in reserved.values())
+ active_task_ids = [t["id"] for tasks in active.values() for t in tasks]
+ reserved_task_ids = [t["id"] for tasks in reserved.values() for t in tasks]
+ not_running = task_id not in active_task_ids and task_id not in reserved_task_ids📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||
|
|
||||||||||||||
| self.assertTrue(not_running) | ||||||||||||||
|
|
||||||||||||||
| self.assertEqual(job.status, JobState.REVOKED.value) | ||||||||||||||
|
|
||||||||||||||
| def test_kill_task_prevents_execution(self): | ||||||||||||||
| from ami.jobs.models import Job, MLSubtaskNames, MLTaskRecord | ||||||||||||||
| from ami.ml.models.pipeline import process_pipeline_request | ||||||||||||||
| from config import celery_app | ||||||||||||||
|
|
||||||||||||||
| logger.info("Testing that killing a task prevents its execution.") | ||||||||||||||
| result = process_pipeline_request.apply_async(args=[{}, 1], countdown=5) | ||||||||||||||
| logger.info(f"Scheduled task with id {result.id} to run in 5 seconds.") | ||||||||||||||
| task_id = result.id | ||||||||||||||
|
|
||||||||||||||
| job = Job.objects.create( | ||||||||||||||
| job_type_key=MLJob.key, | ||||||||||||||
| project=self.project, | ||||||||||||||
| name="Test killing job tasks", | ||||||||||||||
| delay=0, | ||||||||||||||
| pipeline=self.pipeline, | ||||||||||||||
| source_image_collection=self.source_image_collection, | ||||||||||||||
| ) | ||||||||||||||
|
|
||||||||||||||
| ml_task_record = MLTaskRecord.objects.create( | ||||||||||||||
| job=job, task_name=MLSubtaskNames.process_pipeline_request.value, task_id=task_id | ||||||||||||||
| ) | ||||||||||||||
| logger.info(f"Killing task {task_id} immediately.") | ||||||||||||||
| ml_task_record.kill_task() | ||||||||||||||
|
|
||||||||||||||
| async_result = celery_app.AsyncResult(task_id) | ||||||||||||||
| time.sleep(5) # the REVOKED STATUS isn't visible until the task is actually run after the delay | ||||||||||||||
|
|
||||||||||||||
| self.assertIn(async_result.state, ["REVOKED"]) | ||||||||||||||
| self.assertEqual(ml_task_record.status, "REVOKED") | ||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -141,6 +141,11 @@ def check_ml_job_status(ml_job_id: int): | |
| job.update_status(JobState.FAILURE) | ||
| job.finished_at = datetime.datetime.now() | ||
| job.save() | ||
|
|
||
| # Remove remaining tasks from the queue | ||
| for ml_task_record in job.ml_task_records.all(): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we have a Cancel action for the parent Job. Can you move or add this logic there? |
||
| ml_task_record.kill_task() | ||
|
|
||
| raise Exception(error_msg) | ||
|
|
||
|
|
||
|
|
@@ -171,5 +176,8 @@ def check_dangling_ml_jobs(): | |
| job.update_status(JobState.REVOKED) | ||
| job.finished_at = datetime.datetime.now() | ||
| job.save() | ||
|
|
||
| for ml_task_record in job.ml_task_records.all(): | ||
| ml_task_record.kill_task() | ||
| else: | ||
| logger.info(f"Job {job.pk} is active. Last checked at {last_checked}.") | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use a timezone-aware default for
last_checked. DateTimeField(default=datetime.datetime.now) returns naïve datetimes when USE_TZ=True, triggering warnings and risking incorrect conversions. Please switch to django.utils.timezone.now, which returns an aware datetime. Suggested fix:
🤖 Prompt for AI Agents