Commit 60cf775

[DOP-28871] Improve SQLAlchemy query caching
1 parent 9c9d646 commit 60cf775

File tree

13 files changed, 431 additions and 273 deletions

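The repository modules shown below all follow the same refactoring: statements that used to be built inside each method call, with literal values baked in through cast(...) or any_(list(...)), are now constructed once at module import time using bindparam() placeholders, and the per-call values are passed to session.execute() / scalars() / scalar() as a parameters dict. With a stable, shared statement object, SQLAlchemy no longer rebuilds the query construct (and re-derives its cache key) on every call, so repeated executions can reuse the already-compiled SQL. A minimal sketch of the pattern, using hypothetical names (Thing, myapp.models) that are not part of this repository:

from sqlalchemy import bindparam, select
from sqlalchemy.ext.asyncio import AsyncSession

from myapp.models import Thing  # hypothetical ORM model, for illustration only

# Built once at import time: placeholders instead of literal values.
get_one_query = select(Thing).where(Thing.name == bindparam("name"))


async def get_thing(session: AsyncSession, name: str) -> Thing | None:
    # Values are supplied per call; the statement object (and its compiled SQL)
    # stays the same, so the compilation cache gets a hit on every execution
    # after the first.
    return await session.scalar(get_one_query, {"name": name})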

data_rentgen/db/repositories/dataset.py

Lines changed: 51 additions & 35 deletions
@@ -13,6 +13,7 @@
     String,
     any_,
     asc,
+    bindparam,
     cast,
     desc,
     distinct,
@@ -28,25 +29,55 @@
 from data_rentgen.db.utils.search import make_tsquery, ts_match, ts_rank
 from data_rentgen.dto import DatasetDTO, PaginationDTO
 
+fetch_bulk_query = select(Dataset).where(
+    tuple_(Dataset.location_id, func.lower(Dataset.name)).in_(
+        select(
+            func.unnest(
+                cast(bindparam("location_ids"), ARRAY(Integer())),
+                cast(bindparam("names"), ARRAY(String())),
+            )
+            .table_valued("location_id", "name")
+            .render_derived(),
+        ),
+    ),
+)
+
+get_list_query = (
+    select(Dataset)
+    .where(Dataset.id == any_(bindparam("dataset_ids")))
+    .options(selectinload(Dataset.location).selectinload(Location.addresses))
+    .options(selectinload(Dataset.tag_values).selectinload(TagValue.tag))
+)
+
+get_one_query = select(Dataset).where(
+    Dataset.location_id == bindparam("location_id"),
+    func.lower(Dataset.name) == bindparam("name"),
+)
+
+get_stats_query = (
+    select(
+        Dataset.location_id.label("location_id"),
+        func.count(Dataset.id.distinct()).label("total_datasets"),
+    )
+    .where(
+        Dataset.location_id == any_(bindparam("location_ids")),
+    )
+    .group_by(Dataset.location_id)
+)
+
 
 class DatasetRepository(Repository[Dataset]):
     async def fetch_bulk(self, datasets_dto: list[DatasetDTO]) -> list[tuple[DatasetDTO, Dataset | None]]:
         if not datasets_dto:
             return []
 
-        location_ids = [dataset_dto.location.id for dataset_dto in datasets_dto]
-        names = [dataset_dto.name.lower() for dataset_dto in datasets_dto]
-        pairs = (
-            func.unnest(
-                cast(location_ids, ARRAY(Integer())),
-                cast(names, ARRAY(String())),
-            )
-            .table_valued("location_id", "name")
-            .render_derived()
+        scalars = await self._session.scalars(
+            fetch_bulk_query,
+            {
+                "location_ids": [item.location.id for item in datasets_dto],
+                "names": [item.name.lower() for item in datasets_dto],
+            },
         )
-
-        statement = select(Dataset).where(tuple_(Dataset.location_id, func.lower(Dataset.name)).in_(select(pairs)))
-        scalars = await self._session.scalars(statement)
         existing = {(dataset.location_id, dataset.name.lower()): dataset for dataset in scalars.all()}
         return [
             (
@@ -139,39 +170,24 @@ async def paginate(
     async def list_by_ids(self, dataset_ids: Collection[int]) -> list[Dataset]:
         if not dataset_ids:
             return []
-        query = (
-            select(Dataset)
-            .where(Dataset.id == any_(list(dataset_ids)))  # type: ignore[arg-type]
-            .options(selectinload(Dataset.location).selectinload(Location.addresses))
-            .options(selectinload(Dataset.tag_values).selectinload(TagValue.tag))
-        )
-        result = await self._session.scalars(query)
+        result = await self._session.scalars(get_list_query, {"dataset_ids": list(dataset_ids)})
         return list(result.all())
 
     async def get_stats_by_location_ids(self, location_ids: Collection[int]) -> dict[int, Row]:
         if not location_ids:
             return {}
 
-        query = (
-            select(
-                Dataset.location_id.label("location_id"),
-                func.count(Dataset.id.distinct()).label("total_datasets"),
-            )
-            .where(
-                Dataset.location_id == any_(list(location_ids)),  # type: ignore[arg-type]
-            )
-            .group_by(Dataset.location_id)
-        )
-
-        query_result = await self._session.execute(query)
+        query_result = await self._session.execute(get_stats_query, {"location_ids": list(location_ids)})
         return {row.location_id: row for row in query_result.all()}
 
     async def _get(self, dataset: DatasetDTO) -> Dataset | None:
-        statement = select(Dataset).where(
-            Dataset.location_id == dataset.location.id,
-            func.lower(Dataset.name) == dataset.name.lower(),
+        return await self._session.scalar(
+            get_one_query,
+            {
+                "location_id": dataset.location.id,
+                "name": dataset.name.lower(),
+            },
         )
-        return await self._session.scalar(statement)
 
     async def _create(self, dataset: DatasetDTO) -> Dataset:
         result = Dataset(location_id=dataset.location.id, name=dataset.name)
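A detail worth noting in get_list_query and get_stats_query above: Dataset.id == any_(bindparam("dataset_ids")) ships the whole id list as a single array-typed bind parameter, so the rendered SQL stays identical no matter how many ids are passed (unlike an IN clause with one parameter per element). A rough, self-contained sketch of what this compiles to on the PostgreSQL dialect; the throwaway dataset table object below is illustrative, not the real ORM model:

from sqlalchemy import Integer, any_, bindparam, column, select, table
from sqlalchemy.dialects import postgresql

# Lightweight stand-in for the real Dataset model, just to show the rendering.
dataset = table("dataset", column("id", Integer))
stmt = select(dataset.c.id).where(dataset.c.id == any_(bindparam("dataset_ids")))

print(stmt.compile(dialect=postgresql.dialect()))
# Roughly: SELECT dataset.id FROM dataset WHERE dataset.id = ANY (%(dataset_ids)s)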

data_rentgen/db/repositories/dataset_symlink.py

Lines changed: 39 additions & 28 deletions
@@ -3,12 +3,37 @@
 
 from collections.abc import Collection
 
-from sqlalchemy import ARRAY, BindParameter, Integer, any_, bindparam, cast, func, or_, select, tuple_
+from sqlalchemy import ARRAY, Integer, any_, bindparam, cast, func, or_, select, tuple_
 
 from data_rentgen.db.models.dataset_symlink import DatasetSymlink, DatasetSymlinkType
 from data_rentgen.db.repositories.base import Repository
 from data_rentgen.dto import DatasetSymlinkDTO
 
+fetch_bulk_query = select(DatasetSymlink).where(
+    tuple_(DatasetSymlink.from_dataset_id, DatasetSymlink.to_dataset_id).in_(
+        select(
+            func.unnest(
+                cast(bindparam("from_dataset_ids"), ARRAY(Integer())),
+                cast(bindparam("to_dataset_ids"), ARRAY(Integer())),
+            )
+            .table_valued("from_dataset_ids", "to_dataset_ids")
+            .render_derived(),
+        ),
+    ),
+)
+
+get_list_query = select(DatasetSymlink).where(
+    or_(
+        DatasetSymlink.from_dataset_id == any_(bindparam("dataset_ids")),
+        DatasetSymlink.to_dataset_id == any_(bindparam("dataset_ids")),
+    ),
+)
+
+get_one_query = select(DatasetSymlink).where(
+    DatasetSymlink.from_dataset_id == bindparam("from_dataset_id"),
+    DatasetSymlink.to_dataset_id == bindparam("to_dataset_id"),
+)
+
 
 class DatasetSymlinkRepository(Repository[DatasetSymlink]):
     async def fetch_bulk(
@@ -18,22 +43,13 @@ async def fetch_bulk(
         if not dataset_symlinks_dto:
             return []
 
-        from_dataset_ids = [dataset_symlink_dto.from_dataset.id for dataset_symlink_dto in dataset_symlinks_dto]
-        to_dataset_ids = [dataset_symlink_dto.to_dataset.id for dataset_symlink_dto in dataset_symlinks_dto]
-
-        pairs = (
-            func.unnest(
-                cast(from_dataset_ids, ARRAY(Integer())),
-                cast(to_dataset_ids, ARRAY(Integer())),
-            )
-            .table_valued("from_dataset_ids", "to_dataset_ids")
-            .render_derived()
+        scalars = await self._session.scalars(
+            fetch_bulk_query,
+            {
+                "from_dataset_ids": [item.from_dataset.id for item in dataset_symlinks_dto],
+                "to_dataset_ids": [item.to_dataset.id for item in dataset_symlinks_dto],
+            },
         )
-
-        statement = select(DatasetSymlink).where(
-            tuple_(DatasetSymlink.from_dataset_id, DatasetSymlink.to_dataset_id).in_(select(pairs)),
-        )
-        scalars = await self._session.scalars(statement)
         existing = {(item.from_dataset_id, item.to_dataset_id): item for item in scalars.all()}
         return [
             (
@@ -52,22 +68,17 @@ async def list_by_dataset_ids(self, dataset_ids: Collection[int]) -> list[DatasetSymlink]:
         if not dataset_ids:
             return []
 
-        param: BindParameter[list[int]] = bindparam("dataset_ids")
-        query = select(DatasetSymlink).where(
-            or_(
-                DatasetSymlink.from_dataset_id == any_(param),
-                DatasetSymlink.to_dataset_id == any_(param),
-            ),
-        )
-        scalars = await self._session.scalars(query, {"dataset_ids": list(dataset_ids)})
+        scalars = await self._session.scalars(get_list_query, {"dataset_ids": list(dataset_ids)})
         return list(scalars.all())
 
     async def _get(self, dataset_symlink: DatasetSymlinkDTO) -> DatasetSymlink | None:
-        query = select(DatasetSymlink).where(
-            DatasetSymlink.from_dataset_id == dataset_symlink.from_dataset.id,
-            DatasetSymlink.to_dataset_id == dataset_symlink.to_dataset.id,
+        return await self._session.scalar(
+            get_one_query,
+            {
+                "from_dataset_id": dataset_symlink.from_dataset.id,
+                "to_dataset_id": dataset_symlink.to_dataset.id,
+            },
         )
-        return await self._session.scalar(query)
 
     async def _create(self, dataset_symlink: DatasetSymlinkDTO) -> DatasetSymlink:
         result = DatasetSymlink(
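fetch_bulk_query here (and in the dataset and job repositories) relies on PostgreSQL's unnest() over two array bind parameters: the arrays are zipped row-wise into a derived table of key pairs, which the tuple_(...).in_(...) clause probes in a single statement, while the arrays themselves remain ordinary bind parameters that never change the SQL text. A hedged, standalone sketch of that building block; the names from_ids and to_ids are illustrative:

from sqlalchemy import ARRAY, Integer, bindparam, cast, func

# Derived table of (from_id, to_id) rows built from two equal-length arrays.
pairs = (
    func.unnest(
        cast(bindparam("from_ids"), ARRAY(Integer())),
        cast(bindparam("to_ids"), ARRAY(Integer())),
    )
    .table_valued("from_id", "to_id")
    .render_derived()
)
# select(pairs) can then be used as the right-hand side of
# tuple_(Model.from_id, Model.to_id).in_(...), as fetch_bulk_query does above.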

data_rentgen/db/repositories/input.py

Lines changed: 12 additions & 12 deletions
@@ -16,6 +16,17 @@
     extract_timestamp_from_uuid,
 )
 
+insert_statement = insert(Input)
+inserted_row = insert_statement.excluded
+insert_statement = insert_statement.on_conflict_do_update(
+    index_elements=[Input.created_at, Input.id],
+    set_={
+        "num_bytes": func.greatest(inserted_row.num_bytes, Input.num_bytes),
+        "num_rows": func.greatest(inserted_row.num_rows, Input.num_rows),
+        "num_files": func.greatest(inserted_row.num_files, Input.num_files),
+    },
+)
+
 
 @dataclass
 class InputRow:
@@ -37,19 +48,8 @@ async def create_or_update_bulk(self, inputs: list[InputDTO]) -> None:
         if not inputs:
             return
 
-        insert_statement = insert(Input)
-        new_row = insert_statement.excluded
-        statement = insert_statement.on_conflict_do_update(
-            index_elements=[Input.created_at, Input.id],
-            set_={
-                "num_bytes": func.greatest(new_row.num_bytes, Input.num_bytes),
-                "num_rows": func.greatest(new_row.num_rows, Input.num_rows),
-                "num_files": func.greatest(new_row.num_files, Input.num_files),
-            },
-        )
-
         await self._session.execute(
-            statement,
+            insert_statement,
             [
                 {
                     "id": item.generate_id(),

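The upsert statement above can safely live at module level because nothing about it depends on the incoming rows: the conflict action references only the excluded pseudo-table and the target columns, and the row values arrive at execute() time as a list of parameter dicts, driving one bulk execution of the same statement. A hedged sketch of the same upsert shape against a hypothetical metrics table (not a data_rentgen model):

from sqlalchemy import BigInteger, Column, MetaData, Table, func
from sqlalchemy.dialects.postgresql import insert

metadata = MetaData()
metrics = Table(
    "metrics",
    metadata,
    Column("id", BigInteger, primary_key=True),
    Column("num_rows", BigInteger),
)

upsert = insert(metrics)
excluded = upsert.excluded  # values of the row that failed to insert
upsert = upsert.on_conflict_do_update(
    index_elements=[metrics.c.id],
    # Keep the larger of the stored and incoming values, as input.py does.
    set_={"num_rows": func.greatest(excluded.num_rows, metrics.c.num_rows)},
)

# Later, inside an async repository method:
#     await self._session.execute(upsert, [{"id": 1, "num_rows": 10}, {"id": 2, "num_rows": 7}])
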
data_rentgen/db/repositories/job.py

Lines changed: 54 additions & 34 deletions
@@ -13,6 +13,7 @@
     String,
     any_,
     asc,
+    bindparam,
     cast,
     desc,
     func,
@@ -30,6 +31,44 @@
 UNKNOWN_JOB_TYPE = 0
 
 
+fetch_bulk_query = select(Job).where(
+    tuple_(Job.location_id, func.lower(Job.name)).in_(
+        select(
+            func.unnest(
+                cast(bindparam("location_ids"), ARRAY(Integer())),
+                cast(bindparam("names"), ARRAY(String())),
+            )
+            .table_valued("location_id", "name")
+            .render_derived(),
+        ),
+    ),
+)
+
+get_one_query = select(Job).where(
+    Job.location_id == bindparam("location_id"),
+    func.lower(Job.name) == bindparam("name"),
+)
+
+get_list_query = (
+    select(Job)
+    .where(
+        Job.id == any_(bindparam("job_ids")),
+    )
+    .options(selectinload(Job.location).selectinload(Location.addresses))
+)
+
+get_stats_query = (
+    select(
+        Job.location_id.label("location_id"),
+        func.count(Job.id.distinct()).label("total_jobs"),
+    )
+    .where(
+        Job.location_id == any_(bindparam("location_ids")),
+    )
+    .group_by(Job.location_id)
+)
+
+
 class JobRepository(Repository[Job]):
     async def paginate(
         self,
@@ -90,19 +129,13 @@ async def fetch_bulk(self, jobs_dto: list[JobDTO]) -> list[tuple[JobDTO, Job | None]]:
         if not jobs_dto:
             return []
 
-        location_ids = [job_dto.location.id for job_dto in jobs_dto]
-        names = [job_dto.name.lower() for job_dto in jobs_dto]
-        pairs = (
-            func.unnest(
-                cast(location_ids, ARRAY(Integer())),
-                cast(names, ARRAY(String())),
-            )
-            .table_valued("location_id", "name")
-            .render_derived()
+        scalars = await self._session.scalars(
+            fetch_bulk_query,
+            {
+                "location_ids": [item.location.id for item in jobs_dto],
+                "names": [item.name.lower() for item in jobs_dto],
+            },
         )
-
-        statement = select(Job).where(tuple_(Job.location_id, func.lower(Job.name)).in_(select(pairs)))
-        scalars = await self._session.scalars(statement)
         existing = {(job.location_id, job.name.lower()): job for job in scalars.all()}
         return [
             (
@@ -121,11 +154,13 @@ async def create_or_update(self, job: JobDTO) -> Job:
         return await self.update(result, job)
 
     async def _get(self, job: JobDTO) -> Job | None:
-        statement = select(Job).where(
-            Job.location_id == job.location.id,
-            func.lower(Job.name) == job.name.lower(),
+        return await self._session.scalar(
+            get_one_query,
+            {
+                "location_id": job.location.id,
+                "name": job.name.lower(),
+            },
         )
-        return await self._session.scalar(statement)
 
     async def _create(self, job: JobDTO) -> Job:
         result = Job(
@@ -147,28 +182,13 @@ async def update(self, existing: Job, new: JobDTO) -> Job:
     async def list_by_ids(self, job_ids: Collection[int]) -> list[Job]:
         if not job_ids:
             return []
-        query = (
-            select(Job)
-            .where(Job.id == any_(list(job_ids)))  # type: ignore[arg-type]
-            .options(selectinload(Job.location).selectinload(Location.addresses))
-        )
-        result = await self._session.scalars(query)
+
+        result = await self._session.scalars(get_list_query, {"job_ids": list(job_ids)})
         return list(result.all())
 
     async def get_stats_by_location_ids(self, location_ids: Collection[int]) -> dict[int, Row]:
         if not location_ids:
             return {}
 
-        query = (
-            select(
-                Job.location_id.label("location_id"),
-                func.count(Job.id.distinct()).label("total_jobs"),
-            )
-            .where(
-                Job.location_id == any_(list(location_ids)),  # type: ignore[arg-type]
-            )
-            .group_by(Job.location_id)
-        )
-
-        query_result = await self._session.execute(query)
+        query_result = await self._session.execute(get_stats_query, {"location_ids": list(location_ids)})
         return {row.location_id: row for row in query_result.all()}
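A quick way to confirm the refactoring pays off (not part of this commit): with echo=True, SQLAlchemy's engine logging tags each executed statement with "[generated in ...]" the first time it is compiled and "[cached since ...]" when the compiled form is reused, so repeated calls through these module-level statements should show cache hits after the first execution. A hedged sketch, with a placeholder DSN and parameter values:

from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine

from data_rentgen.db.repositories.job import get_one_query

# Placeholder connection string, for illustration only.
engine = create_async_engine("postgresql+asyncpg://user:pass@localhost/data_rentgen", echo=True)


async def check_cache() -> None:
    async with AsyncSession(engine) as session:
        # The first call should log "[generated in ...]"; the second, reusing
        # the same statement object, should log "[cached since ...]".
        await session.scalar(get_one_query, {"location_id": 1, "name": "some_job"})
        await session.scalar(get_one_query, {"location_id": 1, "name": "another_job"})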
