@@ -254,6 +254,7 @@ def create_job(
254254 name : str ,
255255 query : str ,
256256 query_type : JobQueryType = JobQueryType .PYTHON ,
257+ status : JobStatus = JobStatus .CREATED ,
257258 workers : int = 1 ,
258259 python_version : Optional [str ] = None ,
259260 params : Optional [dict [str , str ]] = None ,
@@ -264,33 +265,35 @@ def create_job(
264265 """
265266
266267 @abstractmethod
267- def set_job_status (
268+ def get_job (self , job_id : str ) -> Optional [Job ]:
269+ """Returns the job with the given ID."""
270+
271+ @abstractmethod
272+ def update_job (
268273 self ,
269274 job_id : str ,
270- status : JobStatus ,
275+ status : Optional [JobStatus ] = None ,
276+ exit_code : Optional [int ] = None ,
271277 error_message : Optional [str ] = None ,
272278 error_stack : Optional [str ] = None ,
279+ finished_at : Optional [datetime ] = None ,
273280 metrics : Optional [dict [str , Any ]] = None ,
274- ) -> None :
275- """Set the status of the given job."""
281+ ) -> Optional [ "Job" ] :
282+ """Updates job fields ."""
276283
277284 @abstractmethod
278- def get_job_status (self , job_id : str ) -> Optional [JobStatus ]:
279- """Returns the status of the given job."""
280-
281- @abstractmethod
282- def set_job_and_dataset_status (
285+ def set_job_status (
283286 self ,
284287 job_id : str ,
285- job_status : JobStatus ,
286- dataset_status : DatasetStatus ,
288+ status : JobStatus ,
289+ error_message : Optional [str ] = None ,
290+ error_stack : Optional [str ] = None ,
287291 ) -> None :
288- """Set the status of the given job and dataset ."""
292+ """Set the status of the given job."""
289293
290294 @abstractmethod
291- def get_job_dataset_versions (self , job_id : str ) -> list [tuple [str , int ]]:
292- """Returns dataset names and versions for the job."""
293- raise NotImplementedError
295+ def get_job_status (self , job_id : str ) -> Optional [JobStatus ]:
296+ """Returns the status of the given job."""
294297
295298
296299class AbstractDBMetastore (AbstractMetastore ):
@@ -651,30 +654,31 @@ def update_dataset_version(
651654 dataset_version = dataset .get_version (version )
652655
653656 values = {}
657+ version_values : dict = {}
654658 for field , value in kwargs .items ():
655659 if field in self ._dataset_version_fields [1 :]:
656660 if field == "schema" :
657- dataset_version .update (** {field : DatasetRecord .parse_schema (value )})
658661 values [field ] = json .dumps (value ) if value else None
662+ version_values [field ] = DatasetRecord .parse_schema (value )
659663 elif field == "feature_schema" :
660664 values [field ] = json .dumps (value ) if value else None
665+ version_values [field ] = value
661666 elif field == "preview" and isinstance (value , list ):
662667 values [field ] = json .dumps (value , cls = JSONSerialize )
668+ version_values [field ] = value
663669 else :
664670 values [field ] = value
665- dataset_version .update (** {field : value })
666-
667- if not values :
668- # Nothing to update
669- return dataset_version
671+ version_values [field ] = value
670672
671- dv = self ._datasets_versions
672- self .db .execute (
673- self ._datasets_versions_update ()
674- .where (dv .c .id == dataset_version .id )
675- .values (values ),
676- conn = conn ,
677- ) # type: ignore [attr-defined]
673+ if values :
674+ dv = self ._datasets_versions
675+ self .db .execute (
676+ self ._datasets_versions_update ()
677+ .where (dv .c .dataset_id == dataset .id and dv .c .version == version )
678+ .values (values ),
679+ conn = conn ,
680+ ) # type: ignore [attr-defined]
681+ dataset_version .update (** version_values )
678682
679683 return dataset_version
680684
@@ -702,7 +706,7 @@ def _get_dataset_query(
702706 dataset_fields : list [str ],
703707 dataset_version_fields : list [str ],
704708 isouter : bool = True ,
705- ):
709+ ) -> "Select" :
706710 if not (
707711 self .db .has_table (self ._datasets .name )
708712 and self .db .has_table (self ._datasets_versions .name )
@@ -719,12 +723,12 @@ def _get_dataset_query(
719723 j = d .join (dv , d .c .id == dv .c .dataset_id , isouter = isouter )
720724 return query .select_from (j )
721725
722- def _base_dataset_query (self ):
726+ def _base_dataset_query (self ) -> "Select" :
723727 return self ._get_dataset_query (
724728 self ._dataset_fields , self ._dataset_version_fields
725729 )
726730
727- def _base_list_datasets_query (self ):
731+ def _base_list_datasets_query (self ) -> "Select" :
728732 return self ._get_dataset_query (
729733 self ._dataset_list_fields , self ._dataset_list_version_fields , isouter = False
730734 )
@@ -1018,6 +1022,7 @@ def create_job(
10181022 name : str ,
10191023 query : str ,
10201024 query_type : JobQueryType = JobQueryType .PYTHON ,
1025+ status : JobStatus = JobStatus .CREATED ,
10211026 workers : int = 1 ,
10221027 python_version : Optional [str ] = None ,
10231028 params : Optional [dict [str , str ]] = None ,
@@ -1032,7 +1037,7 @@ def create_job(
10321037 self ._jobs_insert ().values (
10331038 id = job_id ,
10341039 name = name ,
1035- status = JobStatus . CREATED ,
1040+ status = status ,
10361041 created_at = datetime .now (timezone .utc ),
10371042 query = query ,
10381043 query_type = query_type .value ,
@@ -1047,25 +1052,65 @@ def create_job(
10471052 )
10481053 return job_id
10491054
1055+ def get_job (self , job_id : str , conn = None ) -> Optional [Job ]:
1056+ """Returns the job with the given ID."""
1057+ query = self ._jobs_select (self ._jobs ).where (self ._jobs .c .id == job_id )
1058+ results = list (self .db .execute (query , conn = conn ))
1059+ if not results :
1060+ return None
1061+ return self ._parse_job (results [0 ])
1062+
1063+ def update_job (
1064+ self ,
1065+ job_id : str ,
1066+ status : Optional [JobStatus ] = None ,
1067+ exit_code : Optional [int ] = None ,
1068+ error_message : Optional [str ] = None ,
1069+ error_stack : Optional [str ] = None ,
1070+ finished_at : Optional [datetime ] = None ,
1071+ metrics : Optional [dict [str , Any ]] = None ,
1072+ conn : Optional [Any ] = None ,
1073+ ) -> Optional ["Job" ]:
1074+ """Updates job fields."""
1075+ values : dict = {}
1076+ if status is not None :
1077+ values ["status" ] = status
1078+ if exit_code is not None :
1079+ values ["exit_code" ] = exit_code
1080+ if error_message is not None :
1081+ values ["error_message" ] = error_message
1082+ if error_stack is not None :
1083+ values ["error_stack" ] = error_stack
1084+ if finished_at is not None :
1085+ values ["finished_at" ] = finished_at
1086+ if metrics :
1087+ values ["metrics" ] = json .dumps (metrics )
1088+
1089+ if values :
1090+ j = self ._jobs
1091+ self .db .execute (
1092+ self ._jobs_update ().where (j .c .id == job_id ).values (** values ),
1093+ conn = conn ,
1094+ ) # type: ignore [attr-defined]
1095+
1096+ return self .get_job (job_id , conn = conn )
1097+
10501098 def set_job_status (
10511099 self ,
10521100 job_id : str ,
10531101 status : JobStatus ,
10541102 error_message : Optional [str ] = None ,
10551103 error_stack : Optional [str ] = None ,
1056- metrics : Optional [dict [str , Any ]] = None ,
10571104 conn : Optional [Any ] = None ,
10581105 ) -> None :
10591106 """Set the status of the given job."""
1060- values : dict = {"status" : status . value }
1061- if status . value in JobStatus .finished ():
1107+ values : dict = {"status" : status }
1108+ if status in JobStatus .finished ():
10621109 values ["finished_at" ] = datetime .now (timezone .utc )
10631110 if error_message :
10641111 values ["error_message" ] = error_message
10651112 if error_stack :
10661113 values ["error_stack" ] = error_stack
1067- if metrics :
1068- values ["metrics" ] = json .dumps (metrics )
10691114 self .db .execute (
10701115 self ._jobs_update (self ._jobs .c .id == job_id ).values (** values ),
10711116 conn = conn ,
@@ -1086,37 +1131,3 @@ def get_job_status(
10861131 if not results :
10871132 return None
10881133 return results [0 ][0 ]
1089-
1090- def set_job_and_dataset_status (
1091- self ,
1092- job_id : str ,
1093- job_status : JobStatus ,
1094- dataset_status : DatasetStatus ,
1095- ) -> None :
1096- """Set the status of the given job and dataset."""
1097- with self .db .transaction () as conn :
1098- self .set_job_status (job_id , status = job_status , conn = conn )
1099- dv = self ._datasets_versions
1100- query = (
1101- self ._datasets_versions_update ()
1102- .where (
1103- (dv .c .job_id == job_id ) & (dv .c .status != DatasetStatus .COMPLETE )
1104- )
1105- .values (status = dataset_status )
1106- )
1107- self .db .execute (query , conn = conn ) # type: ignore[attr-defined]
1108-
1109- def get_job_dataset_versions (self , job_id : str ) -> list [tuple [str , int ]]:
1110- """Returns dataset names and versions for the job."""
1111- dv = self ._datasets_versions
1112- ds = self ._datasets
1113-
1114- join_condition = dv .c .dataset_id == ds .c .id
1115-
1116- query = (
1117- self ._datasets_versions_select (ds .c .name , dv .c .version )
1118- .select_from (dv .join (ds , join_condition ))
1119- .where (dv .c .job_id == job_id )
1120- )
1121-
1122- return list (self .db .execute (query ))
0 commit comments