99 InferenceJobStatus ,
1010)
1111from inference .utils .batch import BatchProcessor
12+ from inference .utils .classification_utils import update_url_with_classification_results
1213from inference .utils .inference_api_client import InferenceAPIClient
1314
1415
@@ -167,6 +168,11 @@ def initiate(self, inference_api_url=settings.INFERENCE_API_URL) -> None:
167168
168169 if not created_batch :
169170 self .log_error_and_set_status_failed ("No external jobs created" )
171+ self .status = InferenceJobStatus .FAILED
172+ self .updated_at = timezone .now ()
173+ self .completed_at = timezone .now ()
174+ self .save ()
175+ return
170176
171177 self .status = InferenceJobStatus .PENDING
172178 self .save ()
@@ -184,17 +190,37 @@ def refresh_external_jobs_status_and_store_results(self) -> None:
184190 def reevaluate_progress_and_update_status (self ) -> None :
185191 """Evaluate overall job status and handle completion"""
186192
193+ if self .status == InferenceJobStatus .QUEUED :
194+ return
195+
196+ if not self .external_jobs .exists () and self .status == InferenceJobStatus .PENDING :
197+ self .status = InferenceJobStatus .FAILED
198+ self .error_message = "No external jobs created for pending job"
199+ self .completed_at = timezone .now ()
200+ self .save ()
201+ return
202+
187203 if self .get_ongoing_external_jobs ().exists ():
188204 self .status = InferenceJobStatus .PENDING
205+ self .updated_at = timezone .now ()
189206 else :
190207 if self .get_failed_external_jobs ().exists ():
191208 self .status = InferenceJobStatus .FAILED
209+ self .updated_at = timezone .now ()
192210 else :
193211 self .status = InferenceJobStatus .COMPLETED
212+ self .updated_at = timezone .now ()
194213 self .completed_at = timezone .now ()
195214 self .unload_model ()
196215 self .save ()
197216
217+ # If job is completed or failed, check if all classifications are done
218+ # if self.status in [InferenceJobStatus.COMPLETED, InferenceJobStatus.FAILED]:
219+ # self.collection.check_classifications_complete_and_finish_migration()
220+
221+ if self .status in [InferenceJobStatus .COMPLETED ]:
222+ self .collection .check_classifications_complete_and_finish_migration ()
223+
198224 def unload_model (self ) -> None :
199225 """
200226 Check that no other jobs are using the loaded model
@@ -247,6 +273,25 @@ def store_results(self, results) -> None:
247273 """Store results and mark as completed"""
248274 try :
249275 self .results = results
276+ if results :
277+ collection = self .inference_job .collection
278+
279+ for idx , url_id in enumerate (self .url_ids ):
280+ if idx < len (results ):
281+ try :
282+ dump_url = collection .dump_urls .get (id = url_id )
283+ result = results [idx ]
284+ # print(f"Processing result {idx}: {result}")
285+ if isinstance (result , dict ) and "confidence" in result :
286+ # Ensure confidence is float
287+ result ["confidence" ] = float (result ["confidence" ])
288+
289+ update_url_with_classification_results (dump_url , results [idx ])
290+ # tdamm_tags = update_url_with_classification_results(dump_url, results[idx])
291+ # print(f"tdamm_tags added: {tdamm_tags}")
292+ except collection .dump_urls .model .DoesNotExist :
293+ continue
294+
250295 self .mark_completed ()
251296
252297 except Exception as e :
@@ -256,7 +301,8 @@ def refresh_status_and_store_results(self) -> None:
256301 """Process this external job and update status/results"""
257302 try :
258303 api_client = InferenceAPIClient ()
259- model_version = ModelVersion .objects .get (classification_type = self .inference_job .classification_type )
304+ # model_version = ModelVersion.objects.get(classification_type=self.inference_job.classification_type)
305+ model_version = self .inference_job .model_version
260306
261307 response = api_client .get_job_status (model_version .api_identifier , self .external_job_id )
262308
@@ -268,7 +314,7 @@ def refresh_status_and_store_results(self) -> None:
268314 # Handle completion or failure
269315 if new_status == ExternalJobStatus .COMPLETED :
270316 self .store_results (response .get ("results" ))
271- self .completed_at = timezone .now ()
317+ # self.completed_at = timezone.now() # completed in mark_completed called in store_results
272318 self .save ()
273319
274320 except Exception as e :
0 commit comments