Skip to content

Commit 0c4f83a

Browse files
Copilotchrisburr
andcommitted
Fix type annotations for SQLAlchemy DeclarativeBase migration
Co-authored-by: chrisburr <5220533+chrisburr@users.noreply.github.com>
1 parent 737aac6 commit 0c4f83a

File tree

3 files changed

+24
-18
lines changed

3 files changed

+24
-18
lines changed

diracx-db/src/diracx/db/sql/job/db.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@
33
__all__ = ["JobDB"]
44

55
from datetime import datetime, timezone
6-
from typing import TYPE_CHECKING, Any, Iterable
6+
from typing import TYPE_CHECKING, Any, Iterable, cast
77

88
from sqlalchemy import bindparam, case, delete, insert, select, update
99

1010
if TYPE_CHECKING:
1111
from sqlalchemy.sql.elements import BindParameter
12+
from sqlalchemy import Table
1213

1314
from diracx.core.exceptions import InvalidQueryError
1415
from diracx.core.models import JobCommand, SearchSpec, SortSpec
@@ -73,7 +74,7 @@ async def search(
7374
async def create_job(self, compressed_original_jdl: str):
7475
"""Used to insert a new job with original JDL. Returns inserted job id."""
7576
result = await self.conn.execute(
76-
JobJDLs.__table__.insert().values(
77+
cast("Table", JobJDLs.__table__).insert().values(
7778
JDL="",
7879
JobRequirements="",
7980
OriginalJDL=compressed_original_jdl,
@@ -89,7 +90,7 @@ async def delete_jobs(self, job_ids: list[int]):
8990
async def insert_input_data(self, lfns: dict[int, list[str]]):
9091
"""Insert input data for jobs."""
9192
await self.conn.execute(
92-
InputData.__table__.insert(),
93+
cast("Table", InputData.__table__).insert(),
9394
[
9495
{
9596
"JobID": job_id,
@@ -103,7 +104,7 @@ async def insert_input_data(self, lfns: dict[int, list[str]]):
103104
async def insert_job_attributes(self, jobs_to_update: dict[int, dict]):
104105
"""Insert the job attributes."""
105106
await self.conn.execute(
106-
Jobs.__table__.insert(),
107+
cast("Table", Jobs.__table__).insert(),
107108
[
108109
{
109110
"JobID": job_id,
@@ -116,7 +117,7 @@ async def insert_job_attributes(self, jobs_to_update: dict[int, dict]):
116117
async def update_job_jdls(self, jdls_to_update: dict[int, str]):
117118
"""Used to update the JDL, typically just after inserting the original JDL, or rescheduling, for example."""
118119
await self.conn.execute(
119-
JobJDLs.__table__.update().where(
120+
cast("Table", JobJDLs.__table__).update().where(
120121
JobJDLs.__table__.c.JobID == bindparam("b_JobID")
121122
),
122123
[
@@ -171,7 +172,7 @@ async def set_job_attributes(self, job_data):
171172
}
172173

173174
stmt = (
174-
Jobs.__table__.update()
175+
cast("Table", Jobs.__table__).update()
175176
.values(**case_expressions)
176177
.where(Jobs.__table__.c.JobID.in_(job_data.keys()))
177178
)
@@ -228,7 +229,7 @@ async def set_properties(
228229
required_parameters = list(required_parameters_set)[0]
229230
update_parameters = [{"job_id": k, **v} for k, v in properties.items()]
230231

231-
columns = _get_columns(Jobs.__table__, required_parameters)
232+
columns = _get_columns(cast("Table", Jobs.__table__), list(required_parameters))
232233
values: dict[str, BindParameter[Any] | datetime] = {
233234
c.name: bindparam(c.name) for c in columns
234235
}
@@ -267,7 +268,7 @@ async def add_heartbeat_data(
267268
}
268269
for key, value in dynamic_data.items()
269270
]
270-
await self.conn.execute(HeartBeatLoggingInfo.__table__.insert().values(values))
271+
await self.conn.execute(cast("Table", HeartBeatLoggingInfo.__table__).insert().values(values))
271272

272273
async def get_job_commands(self, job_ids: Iterable[int]) -> list[JobCommand]:
273274
"""Get a command to be passed to the job together with the next heartbeat.

diracx-db/src/diracx/db/sql/job_logging/db.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@
22

33
from collections import defaultdict
44
from datetime import datetime, timezone
5-
from typing import Iterable
5+
from typing import Iterable, cast, TYPE_CHECKING
66

77
from sqlalchemy import delete, func, select
88

9+
if TYPE_CHECKING:
10+
from sqlalchemy import Table
11+
912
from diracx.core.models import JobLoggingRecord, JobStatusReturn
1013

1114
from ..utils import BaseSQLDB
@@ -56,7 +59,7 @@ async def insert_records(
5659
seqnums[record.job_id] = seqnums[record.job_id] + 1
5760

5861
await self.conn.execute(
59-
LoggingInfo.__table__.insert(),
62+
cast("Table", LoggingInfo.__table__).insert(),
6063
values,
6164
)
6265

diracx-db/src/diracx/db/sql/utils/base.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from typing import Any, Self, cast, TYPE_CHECKING
1212

1313
from pydantic import TypeAdapter
14-
from sqlalchemy import DateTime, MetaData, func, select
14+
from sqlalchemy import DateTime, MetaData, Table, func, select
1515
from sqlalchemy.exc import OperationalError
1616
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
1717

@@ -247,12 +247,13 @@ async def _search(
247247
) -> tuple[int, list[dict[str, Any]]]:
248248
"""Search for elements in a table."""
249249
# Find which columns to select
250-
columns = _get_columns(table.__table__, parameters)
250+
table_obj = cast(Table, table.__table__)
251+
columns = _get_columns(table_obj, parameters)
251252

252253
stmt = select(*columns)
253254

254-
stmt = apply_search_filters(table.__table__.columns.__getitem__, stmt, search)
255-
stmt = apply_sort_constraints(table.__table__.columns.__getitem__, stmt, sorts)
255+
stmt = apply_search_filters(table_obj.columns.__getitem__, stmt, search)
256+
stmt = apply_sort_constraints(table_obj.columns.__getitem__, stmt, sorts)
256257

257258
if distinct:
258259
stmt = stmt.distinct()
@@ -279,17 +280,18 @@ async def _summary(
279280
self, table: type[DeclarativeBase], group_by: list[str], search: list[SearchSpec]
280281
) -> list[dict[str, str | int]]:
281282
"""Get a summary of the elements of a table."""
282-
columns = _get_columns(table.__table__, group_by)
283+
table_obj = cast(Table, table.__table__)
284+
columns = _get_columns(table_obj, group_by)
283285

284-
pk_columns = list(table.__table__.primary_key.columns)
286+
pk_columns = list(table_obj.primary_key.columns)
285287
if not pk_columns:
286288
raise ValueError(
287289
"Model has no primary key and no count_column was provided."
288290
)
289291
count_col = pk_columns[0]
290292

291293
stmt = select(*columns, func.count(count_col).label("count"))
292-
stmt = apply_search_filters(table.__table__.columns.__getitem__, stmt, search)
294+
stmt = apply_search_filters(table_obj.columns.__getitem__, stmt, search)
293295
stmt = stmt.group_by(*columns)
294296

295297
# Execute the query
@@ -330,7 +332,7 @@ def find_time_resolution(value):
330332
raise InvalidQueryError(f"Cannot parse {value=}")
331333

332334

333-
def _get_columns(table, parameters):
335+
def _get_columns(table: Table, parameters: list[str] | None):
334336
columns = [x for x in table.columns]
335337
if parameters:
336338
if unrecognised_parameters := set(parameters) - set(table.columns.keys()):

0 commit comments

Comments
 (0)