9
9
InferenceJobStatus ,
10
10
)
11
11
from inference .utils .batch import BatchProcessor
12
+ from inference .utils .classification_utils import update_url_with_classification_results
12
13
from inference .utils .inference_api_client import InferenceAPIClient
13
14
14
15
@@ -167,6 +168,11 @@ def initiate(self, inference_api_url=settings.INFERENCE_API_URL) -> None:
167
168
168
169
if not created_batch :
169
170
self .log_error_and_set_status_failed ("No external jobs created" )
171
+ self .status = InferenceJobStatus .FAILED
172
+ self .updated_at = timezone .now ()
173
+ self .completed_at = timezone .now ()
174
+ self .save ()
175
+ return
170
176
171
177
self .status = InferenceJobStatus .PENDING
172
178
self .save ()
@@ -184,17 +190,37 @@ def refresh_external_jobs_status_and_store_results(self) -> None:
184
190
def reevaluate_progress_and_update_status (self ) -> None :
185
191
"""Evaluate overall job status and handle completion"""
186
192
193
+ if self .status == InferenceJobStatus .QUEUED :
194
+ return
195
+
196
+ if not self .external_jobs .exists () and self .status == InferenceJobStatus .PENDING :
197
+ self .status = InferenceJobStatus .FAILED
198
+ self .error_message = "No external jobs created for pending job"
199
+ self .completed_at = timezone .now ()
200
+ self .save ()
201
+ return
202
+
187
203
if self .get_ongoing_external_jobs ().exists ():
188
204
self .status = InferenceJobStatus .PENDING
205
+ self .updated_at = timezone .now ()
189
206
else :
190
207
if self .get_failed_external_jobs ().exists ():
191
208
self .status = InferenceJobStatus .FAILED
209
+ self .updated_at = timezone .now ()
192
210
else :
193
211
self .status = InferenceJobStatus .COMPLETED
212
+ self .updated_at = timezone .now ()
194
213
self .completed_at = timezone .now ()
195
214
self .unload_model ()
196
215
self .save ()
197
216
217
+ # If job is completed or failed, check if all classifications are done
218
+ # if self.status in [InferenceJobStatus.COMPLETED, InferenceJobStatus.FAILED]:
219
+ # self.collection.check_classifications_complete_and_finish_migration()
220
+
221
+ if self .status in [InferenceJobStatus .COMPLETED ]:
222
+ self .collection .check_classifications_complete_and_finish_migration ()
223
+
198
224
def unload_model (self ) -> None :
199
225
"""
200
226
Check that no other jobs are using the loaded model
@@ -247,6 +273,25 @@ def store_results(self, results) -> None:
247
273
"""Store results and mark as completed"""
248
274
try :
249
275
self .results = results
276
+ if results :
277
+ collection = self .inference_job .collection
278
+
279
+ for idx , url_id in enumerate (self .url_ids ):
280
+ if idx < len (results ):
281
+ try :
282
+ dump_url = collection .dump_urls .get (id = url_id )
283
+ result = results [idx ]
284
+ # print(f"Processing result {idx}: {result}")
285
+ if isinstance (result , dict ) and "confidence" in result :
286
+ # Ensure confidence is float
287
+ result ["confidence" ] = float (result ["confidence" ])
288
+
289
+ update_url_with_classification_results (dump_url , results [idx ])
290
+ # tdamm_tags = update_url_with_classification_results(dump_url, results[idx])
291
+ # print(f"tdamm_tags added: {tdamm_tags}")
292
+ except collection .dump_urls .model .DoesNotExist :
293
+ continue
294
+
250
295
self .mark_completed ()
251
296
252
297
except Exception as e :
@@ -256,7 +301,8 @@ def refresh_status_and_store_results(self) -> None:
256
301
"""Process this external job and update status/results"""
257
302
try :
258
303
api_client = InferenceAPIClient ()
259
- model_version = ModelVersion .objects .get (classification_type = self .inference_job .classification_type )
304
+ # model_version = ModelVersion.objects.get(classification_type=self.inference_job.classification_type)
305
+ model_version = self .inference_job .model_version
260
306
261
307
response = api_client .get_job_status (model_version .api_identifier , self .external_job_id )
262
308
@@ -268,7 +314,7 @@ def refresh_status_and_store_results(self) -> None:
268
314
# Handle completion or failure
269
315
if new_status == ExternalJobStatus .COMPLETED :
270
316
self .store_results (response .get ("results" ))
271
- self .completed_at = timezone .now ()
317
+ # self.completed_at = timezone.now() # completed in mark_completed called in store_results
272
318
self .save ()
273
319
274
320
except Exception as e :
0 commit comments