33__all__ = ["JobDB" ]
44
55from datetime import datetime , timezone
6- from typing import TYPE_CHECKING , Any , Iterable
6+ from typing import TYPE_CHECKING , Any , Iterable , cast
77
88from sqlalchemy import bindparam , case , delete , insert , select , update
99
1010if TYPE_CHECKING :
1111 from sqlalchemy .sql .elements import BindParameter
12+ from sqlalchemy import Table
1213
1314from diracx .core .exceptions import InvalidQueryError
1415from diracx .core .models import JobCommand , SearchSpec , SortSpec
@@ -73,7 +74,7 @@ async def search(
7374 async def create_job (self , compressed_original_jdl : str ):
7475 """Used to insert a new job with original JDL. Returns inserted job id."""
7576 result = await self .conn .execute (
76- JobJDLs .__table__ .insert ().values (
77+ cast ( "Table" , JobJDLs .__table__ ) .insert ().values (
7778 JDL = "" ,
7879 JobRequirements = "" ,
7980 OriginalJDL = compressed_original_jdl ,
@@ -89,7 +90,7 @@ async def delete_jobs(self, job_ids: list[int]):
8990 async def insert_input_data (self , lfns : dict [int , list [str ]]):
9091 """Insert input data for jobs."""
9192 await self .conn .execute (
92- InputData .__table__ .insert (),
93+ cast ( "Table" , InputData .__table__ ) .insert (),
9394 [
9495 {
9596 "JobID" : job_id ,
@@ -103,7 +104,7 @@ async def insert_input_data(self, lfns: dict[int, list[str]]):
103104 async def insert_job_attributes (self , jobs_to_update : dict [int , dict ]):
104105 """Insert the job attributes."""
105106 await self .conn .execute (
106- Jobs .__table__ .insert (),
107+ cast ( "Table" , Jobs .__table__ ) .insert (),
107108 [
108109 {
109110 "JobID" : job_id ,
@@ -116,7 +117,7 @@ async def insert_job_attributes(self, jobs_to_update: dict[int, dict]):
116117 async def update_job_jdls (self , jdls_to_update : dict [int , str ]):
117118 """Used to update the JDL, typically just after inserting the original JDL, or rescheduling, for example."""
118119 await self .conn .execute (
119- JobJDLs .__table__ .update ().where (
120+ cast ( "Table" , JobJDLs .__table__ ) .update ().where (
120121 JobJDLs .__table__ .c .JobID == bindparam ("b_JobID" )
121122 ),
122123 [
@@ -171,7 +172,7 @@ async def set_job_attributes(self, job_data):
171172 }
172173
173174 stmt = (
174- Jobs .__table__ .update ()
175+ cast ( "Table" , Jobs .__table__ ) .update ()
175176 .values (** case_expressions )
176177 .where (Jobs .__table__ .c .JobID .in_ (job_data .keys ()))
177178 )
@@ -228,7 +229,7 @@ async def set_properties(
228229 required_parameters = list (required_parameters_set )[0 ]
229230 update_parameters = [{"job_id" : k , ** v } for k , v in properties .items ()]
230231
231- columns = _get_columns (Jobs .__table__ , required_parameters )
232+ columns = _get_columns (cast ( "Table" , Jobs .__table__ ), list ( required_parameters ) )
232233 values : dict [str , BindParameter [Any ] | datetime ] = {
233234 c .name : bindparam (c .name ) for c in columns
234235 }
@@ -267,7 +268,7 @@ async def add_heartbeat_data(
267268 }
268269 for key , value in dynamic_data .items ()
269270 ]
270- await self .conn .execute (HeartBeatLoggingInfo .__table__ .insert ().values (values ))
271+ await self .conn .execute (cast ( "Table" , HeartBeatLoggingInfo .__table__ ) .insert ().values (values ))
271272
272273 async def get_job_commands (self , job_ids : Iterable [int ]) -> list [JobCommand ]:
273274 """Get a command to be passed to the job together with the next heartbeat.
0 commit comments