22"""
33
44import uuid
5-
65from typing import Annotated , Any , Tuple , Dict
76from fastapi import APIRouter , Depends , HTTPException , status
87from pydantic import BaseModel
3029 FileTable ,
3130 InstTable ,
3231 ModelTable ,
32+ JobTable ,
3333)
3434
3535from ..gcsutil import StorageControl
@@ -59,6 +59,7 @@ class ModelCreationRequest(BaseModel):
5959class ModelInfo (BaseModel ):
6060 """The model object that's returned."""
6161
62+ # The model id is unique for every instance of the model (e.g. model name + version id pair)
6263 m_id : str
6364 name : str
6465 inst_id : str
@@ -76,15 +77,23 @@ class RunInfo(BaseModel):
7677 run_id : str
7778 vers_id : int = 0
7879 inst_id : str
79- m_id : str
80+ m_name : str
8081 # user id of the person who executed this run.
8182 created_by : str | None = None
83+ # Time the run info was triggered if it was triggered in the webapp
84+ triggered_at : datetime | None = None
85+ # Batch used for the run
86+ batch_name : str | None = None
87+ # output file name
88+ output_filename : str | None = None
89+ completed : bool | None = None
90+ err_msg : str | None = None
8291
8392
8493class InferenceRunRequest (BaseModel ):
8594 """Parameters for an inference run."""
8695
87- batch_id : str
96+ batch_name : str
8897
8998
9099# Model related operations. Or model specific data.
@@ -357,18 +366,21 @@ def read_inst_model_outputs(
357366 detail = "Multiple models of the same version found, this should not have happened." ,
358367 )
359368 res = []
360- if query_result [0 ][0 ].run_ids :
361- for elem in query_result [0 ][0 ].run_ids :
362- res .append (
363- {
364- # xxxx
365- "run_id" : "placeholder" ,
366- "inst_id" : uuid_to_str (query_result [0 ][0 ].inst_id ),
367- "m_id" : uuid_to_str (query_result [0 ][0 ].id ),
368- "created_by" : "placeholder" ,
369- "vers_id" : vers_id ,
370- }
371- )
369+ for elem in query_result [0 ][0 ].jobs or []:
370+ # if not elem.output_filename:
371+ # TODO make a query to databricks to retrieve status.
372+ res .append (
373+ {
374+ "vers_id" : vers_id ,
375+ "inst_id" : uuid_to_str (query_result [0 ][0 ].inst_id ),
376+ "m_name" : query_result [0 ][0 ].name ,
377+ "run_id" : elem .id ,
378+ "created_by" : uuid_to_str (elem .created_by ),
379+ "triggered_at" : elem .triggered_at ,
380+ "batch_name" : elem .batch_name ,
381+ "output_filename" : elem .output_filename ,
382+ }
383+ )
372384 return res
373385
374386
@@ -417,21 +429,27 @@ def read_inst_model_output(
417429 status_code = status .HTTP_500_INTERNAL_SERVER_ERROR ,
418430 detail = "Multiple models of the same version found, this should not have happened." ,
419431 )
420- if query_result [0 ][0 ].run_ids :
421- # TODO xxx: check the run id is present, then make a query to Databricks
422- return {
423- "run_id" : "placeholder" ,
424- "inst_id" : uuid_to_str (query_result [0 ][0 ].inst_id ),
425- "m_id" : uuid_to_str (query_result [0 ][0 ].id ),
426- "created_by" : "placeholder" ,
427- "vers_id" : vers_id ,
428- }
432+
433+ for elem in query_result [0 ][0 ].jobs or []:
434+ if elem .id == run_id :
435+ # TODO xxx: if the output_filename is empty make a query to Databricks
436+ return {
437+ "vers_id" : vers_id ,
438+ "inst_id" : uuid_to_str (query_result [0 ][0 ].inst_id ),
439+ "m_name" : query_result [0 ][0 ].name ,
440+ "run_id" : elem .id ,
441+ "created_by" : uuid_to_str (elem .created_by ),
442+ "triggered_at" : elem .triggered_at ,
443+ "batch_name" : elem .batch_name ,
444+ "output_filename" : elem .output_filename ,
445+ }
429446 raise HTTPException (
430447 status_code = status .HTTP_404_NOT_FOUND ,
431448 detail = "Run not found." ,
432449 )
433450
434451
452+ # TODO: xxx update the run info returned items.
435453@router .post (
436454 "/{inst_id}/models/{model_name}/vers/{vers_id}/run-inference" ,
437455 response_model = RunInfo ,
@@ -476,13 +494,28 @@ def trigger_inference_run(
476494 status_code = status .HTTP_500_INTERNAL_SERVER_ERROR ,
477495 detail = "Multiple models of the same version found, this should not have happened." ,
478496 )
479- # TODO issue Databricks call
480- return { # xxxx
481- "run_id" : "placeholder" ,
482- "inst_id" : uuid_to_str (query_result [0 ][0 ].inst_id ),
483- "m_id" : uuid_to_str (query_result [0 ][0 ].id ),
484- "created_by" : "placeholder" ,
497+ # TODO issue Databricks call then use the result to populate the below
498+
499+ triggered_timestamp = datetime .now ()
500+ # Add an entry to the jobs table with the job id
501+ """
502+ job = JobTable(
503+ id=12345,
504+ triggered_at=triggered_timestamp,
505+ created_by=str_to_uuid(current_user.user_id),
506+ batch_name=req.batch_name,
507+ model_id=query_result[0][0].id,
508+ )
509+ local_session.get().add(job)
510+ """
511+ return {
485512 "vers_id" : vers_id ,
513+ "inst_id" : uuid_to_str (query_result [0 ][0 ].inst_id ),
514+ "m_name" : query_result [0 ][0 ].name ,
515+ "run_id" : "placeholder" ,
516+ "created_by" : current_user .user_id ,
517+ "triggered_at" : triggered_timestamp ,
518+ "batch_name" : req .batch_name ,
486519 }
487520
488521
0 commit comments