Commit a646e32

Merge pull request #1248 from NASA-IMPACT/integrate_classification_queue
Integrate classification queue
2 parents 5ff3e26 + b7f6edf commit a646e32

16 files changed: +1443, -106 lines

config/settings/base.py

Lines changed: 1 addition & 0 deletions
@@ -349,3 +349,4 @@
 LRM_DEV_TOKEN = env("LRM_DEV_TOKEN")
 XLI_TOKEN = env("XLI_TOKEN")
 INFERENCE_API_URL = env("INFERENCE_API_URL")
+TDAMM_CLASSIFICATION_THRESHOLD = env("TDAMM_CLASSIFICATION_THRESHOLD", default="0.5")
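The hunk above only adds the setting; where it is consumed is not shown on this page. As a minimal sketch (an assumption, not code from this PR), note that `env()` with a string default returns a string, so downstream code would cast before comparing against a model confidence score:

```python
from django.conf import settings


def passes_threshold(confidence: float) -> bool:
    """Return True when a predicted label's confidence clears the configured cutoff."""
    # env() returns strings, so the default "0.5" arrives as a str and needs casting.
    threshold = float(settings.TDAMM_CLASSIFICATION_THRESHOLD)
    return float(confidence) >= threshold
```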

inference/models/inference.py

Lines changed: 48 additions & 2 deletions
@@ -9,6 +9,7 @@
     InferenceJobStatus,
 )
 from inference.utils.batch import BatchProcessor
+from inference.utils.classification_utils import update_url_with_classification_results
 from inference.utils.inference_api_client import InferenceAPIClient
 
 
@@ -167,6 +168,11 @@ def initiate(self, inference_api_url=settings.INFERENCE_API_URL) -> None:
 
         if not created_batch:
             self.log_error_and_set_status_failed("No external jobs created")
+            self.status = InferenceJobStatus.FAILED
+            self.updated_at = timezone.now()
+            self.completed_at = timezone.now()
+            self.save()
+            return
 
         self.status = InferenceJobStatus.PENDING
         self.save()
@@ -184,17 +190,37 @@ def refresh_external_jobs_status_and_store_results(self) -> None:
     def reevaluate_progress_and_update_status(self) -> None:
         """Evaluate overall job status and handle completion"""
 
+        if self.status == InferenceJobStatus.QUEUED:
+            return
+
+        if not self.external_jobs.exists() and self.status == InferenceJobStatus.PENDING:
+            self.status = InferenceJobStatus.FAILED
+            self.error_message = "No external jobs created for pending job"
+            self.completed_at = timezone.now()
+            self.save()
+            return
+
         if self.get_ongoing_external_jobs().exists():
             self.status = InferenceJobStatus.PENDING
+            self.updated_at = timezone.now()
         else:
             if self.get_failed_external_jobs().exists():
                 self.status = InferenceJobStatus.FAILED
+                self.updated_at = timezone.now()
             else:
                 self.status = InferenceJobStatus.COMPLETED
+                self.updated_at = timezone.now()
                 self.completed_at = timezone.now()
                 self.unload_model()
         self.save()
 
+        # If job is completed or failed, check if all classifications are done
+        # if self.status in [InferenceJobStatus.COMPLETED, InferenceJobStatus.FAILED]:
+        #     self.collection.check_classifications_complete_and_finish_migration()
+
+        if self.status in [InferenceJobStatus.COMPLETED]:
+            self.collection.check_classifications_complete_and_finish_migration()
+
     def unload_model(self) -> None:
         """
         Check that no other jobs are using the loaded model
@@ -247,6 +273,25 @@ def store_results(self, results) -> None:
         """Store results and mark as completed"""
         try:
             self.results = results
+            if results:
+                collection = self.inference_job.collection
+
+                for idx, url_id in enumerate(self.url_ids):
+                    if idx < len(results):
+                        try:
+                            dump_url = collection.dump_urls.get(id=url_id)
+                            result = results[idx]
+                            # print(f"Processing result {idx}: {result}")
+                            if isinstance(result, dict) and "confidence" in result:
+                                # Ensure confidence is float
+                                result["confidence"] = float(result["confidence"])
+
+                            update_url_with_classification_results(dump_url, results[idx])
+                            # tdamm_tags = update_url_with_classification_results(dump_url, results[idx])
+                            # print(f"tdamm_tags added: {tdamm_tags}")
+                        except collection.dump_urls.model.DoesNotExist:
+                            continue
+
             self.mark_completed()
 
         except Exception as e:
@@ -256,7 +301,8 @@ def refresh_status_and_store_results(self) -> None:
         """Process this external job and update status/results"""
         try:
             api_client = InferenceAPIClient()
-            model_version = ModelVersion.objects.get(classification_type=self.inference_job.classification_type)
+            # model_version = ModelVersion.objects.get(classification_type=self.inference_job.classification_type)
+            model_version = self.inference_job.model_version
 
             response = api_client.get_job_status(model_version.api_identifier, self.external_job_id)
 
@@ -268,7 +314,7 @@ def refresh_status_and_store_results(self) -> None:
             # Handle completion or failure
             if new_status == ExternalJobStatus.COMPLETED:
                 self.store_results(response.get("results"))
-                self.completed_at = timezone.now()
+                # self.completed_at = timezone.now()  # completed in mark_completed called in store_results
                 self.save()
 
         except Exception as e:
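The newly imported helper `update_url_with_classification_results` lives in `inference/utils/classification_utils.py`, which is one of the 16 changed files but is not among the hunks shown on this page. A hypothetical sketch of the contract `store_results` appears to rely on; the field name and result shape here are illustrative assumptions, not code from the PR:

```python
# Hypothetical sketch only; the real helper is defined in
# inference/utils/classification_utils.py and is not shown in this page's hunks.
# The "tdamm_tag" field name and the {"labels": {...}} result shape are assumptions.
from django.conf import settings


def update_url_with_classification_results(dump_url, result):
    """Attach labels whose confidence clears TDAMM_CLASSIFICATION_THRESHOLD to the URL."""
    threshold = float(settings.TDAMM_CLASSIFICATION_THRESHOLD)
    labels = result.get("labels", {}) if isinstance(result, dict) else {}
    accepted = [label for label, score in labels.items() if float(score) >= threshold]
    dump_url.tdamm_tag = accepted
    dump_url.save()
    return accepted
```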

inference/tasks.py

Lines changed: 6 additions & 8 deletions
@@ -5,13 +5,6 @@
 from inference.utils.advisory_lock import AdvisoryLock
 
 
-def generate_inference_job(collection, classification_type):
-    """Creates a new inference job for a collection."""
-    return InferenceJob.objects.create(
-        collection=collection, classification_type=classification_type, status=InferenceJobStatus.QUEUED
-    )
-
-
 @shared_task
 def process_inference_job_queue():
     """
@@ -25,13 +18,18 @@ def process_inference_job_queue():
         return "Queue processing already in progress"
 
     try:
+        # Reevaluate progress and update status of all inference jobs that are not currently queued
+        # for job in InferenceJob.objects.exclude(status=InferenceJobStatus.QUEUED):
+        #     job.reevaluate_progress_and_update_status()
+
         # Look for pending jobs first
         pending_jobs = InferenceJob.objects.filter(status=InferenceJobStatus.PENDING)
 
         if pending_jobs.exists():
-            # Process pending jobs
+            # Refresh and process pending jobs
             for job in pending_jobs:
                 job.refresh_external_jobs_status_and_store_results()
+                job.reevaluate_progress_and_update_status()
         else:
             # If no pending jobs, try to initiate a queued job
             queued_job = (
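`process_inference_job_queue` is a periodic sweep task; how it is scheduled is not part of the hunks shown here. A minimal sketch, assuming Celery beat drives it from the Django settings (the entry name and interval are illustrative):

```python
# Illustrative only: the actual schedule is not shown in this diff.
CELERY_BEAT_SCHEDULE = {
    "process-inference-job-queue": {
        "task": "inference.tasks.process_inference_job_queue",
        "schedule": 60.0,  # seconds between queue sweeps (illustrative interval)
    },
}
```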

inference/tests/local_test_inference_api_client.py

Lines changed: 2 additions & 1 deletion
@@ -1,4 +1,5 @@
-# inference/tests/test_inference_api_client.py
+# inference/tests/local_test_inference_api_client.py
+# docker-compose -f local.yml run --rm django pytest inference/tests/local_test_inference_api_client.py
 
 """
 This is a test designed to be run on a local machine which has the inference pipeline running

inference/tests/test_batch.py

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 # inference/tests/test_batch.py
+# docker-compose -f local.yml run --rm django pytest inference/tests/test_batch.py
 from unittest.mock import MagicMock, Mock, patch
 
 import pytest
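The added header comments document how to run each test module through the local compose stack. The tests themselves are not shown on this page; as an illustrative-only sketch, the `MagicMock`/`patch` imports suggest the inference API client is stubbed so batch and queue logic can be exercised without a running inference service:

```python
# Illustrative sketch only; not a test from this PR. Shows the unittest.mock
# pattern the imports in test_batch.py suggest: patching InferenceAPIClient so
# no real inference service is needed.
from unittest.mock import patch


def test_get_job_status_can_be_stubbed():
    with patch("inference.utils.inference_api_client.InferenceAPIClient") as mock_cls:
        mock_cls.return_value.get_job_status.return_value = {"status": "COMPLETED", "results": []}
        client = mock_cls()  # real tests would exercise code that constructs the client
        assert client.get_job_status("model-id", "external-job-id")["status"] == "COMPLETED"
```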
