Skip to content

Commit a914979

Browse files
committed
[DOP-28871] Improve SQLAlchemy query caching
1 parent 9c9d646 commit a914979

File tree

14 files changed

+435
-275
lines changed

14 files changed

+435
-275
lines changed

data_rentgen/db/repositories/base.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313
ScalarResult,
1414
Select,
1515
SQLColumnExpression,
16+
bindparam,
1617
func,
1718
select,
1819
)
@@ -24,6 +25,8 @@
2425

2526
Model = TypeVar("Model", bound=Base)
2627

28+
advisory_lock_statement = select(func.pg_advisory_xact_lock(bindparam("key")))
29+
2730

2831
class Repository(ABC, Generic[Model]):
2932
def __init__(
@@ -81,5 +84,4 @@ async def _lock(
8184
digest = sha1(data.encode("utf-8"), usedforsecurity=False).digest()
8285
# sha1 returns 160bit hash, we need only first 64 bits
8386
lock_key = int.from_bytes(digest[:8], byteorder="big", signed=True)
84-
statement = select(func.pg_advisory_xact_lock(lock_key))
85-
await self._session.execute(statement)
87+
await self._session.execute(advisory_lock_statement, {"key": lock_key})

data_rentgen/db/repositories/dataset.py

Lines changed: 51 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313
String,
1414
any_,
1515
asc,
16+
bindparam,
1617
cast,
1718
desc,
1819
distinct,
@@ -28,25 +29,55 @@
2829
from data_rentgen.db.utils.search import make_tsquery, ts_match, ts_rank
2930
from data_rentgen.dto import DatasetDTO, PaginationDTO
3031

32+
fetch_bulk_query = select(Dataset).where(
33+
tuple_(Dataset.location_id, func.lower(Dataset.name)).in_(
34+
select(
35+
func.unnest(
36+
cast(bindparam("location_ids"), ARRAY(Integer())),
37+
cast(bindparam("names_lower"), ARRAY(String())),
38+
)
39+
.table_valued("location_id", "name_lower")
40+
.render_derived(),
41+
),
42+
),
43+
)
44+
45+
get_list_query = (
46+
select(Dataset)
47+
.where(Dataset.id == any_(bindparam("dataset_ids")))
48+
.options(selectinload(Dataset.location).selectinload(Location.addresses))
49+
.options(selectinload(Dataset.tag_values).selectinload(TagValue.tag))
50+
)
51+
52+
get_one_query = select(Dataset).where(
53+
Dataset.location_id == bindparam("location_id"),
54+
func.lower(Dataset.name) == bindparam("name_lower"),
55+
)
56+
57+
get_stats_query = (
58+
select(
59+
Dataset.location_id.label("location_id"),
60+
func.count(Dataset.id.distinct()).label("total_datasets"),
61+
)
62+
.where(
63+
Dataset.location_id == any_(bindparam("location_ids")),
64+
)
65+
.group_by(Dataset.location_id)
66+
)
67+
3168

3269
class DatasetRepository(Repository[Dataset]):
3370
async def fetch_bulk(self, datasets_dto: list[DatasetDTO]) -> list[tuple[DatasetDTO, Dataset | None]]:
3471
if not datasets_dto:
3572
return []
3673

37-
location_ids = [dataset_dto.location.id for dataset_dto in datasets_dto]
38-
names = [dataset_dto.name.lower() for dataset_dto in datasets_dto]
39-
pairs = (
40-
func.unnest(
41-
cast(location_ids, ARRAY(Integer())),
42-
cast(names, ARRAY(String())),
43-
)
44-
.table_valued("location_id", "name")
45-
.render_derived()
74+
scalars = await self._session.scalars(
75+
fetch_bulk_query,
76+
{
77+
"location_ids": [item.location.id for item in datasets_dto],
78+
"names_lower": [item.name.lower() for item in datasets_dto],
79+
},
4680
)
47-
48-
statement = select(Dataset).where(tuple_(Dataset.location_id, func.lower(Dataset.name)).in_(select(pairs)))
49-
scalars = await self._session.scalars(statement)
5081
existing = {(dataset.location_id, dataset.name.lower()): dataset for dataset in scalars.all()}
5182
return [
5283
(
@@ -139,39 +170,24 @@ async def paginate(
139170
async def list_by_ids(self, dataset_ids: Collection[int]) -> list[Dataset]:
140171
if not dataset_ids:
141172
return []
142-
query = (
143-
select(Dataset)
144-
.where(Dataset.id == any_(list(dataset_ids))) # type: ignore[arg-type]
145-
.options(selectinload(Dataset.location).selectinload(Location.addresses))
146-
.options(selectinload(Dataset.tag_values).selectinload(TagValue.tag))
147-
)
148-
result = await self._session.scalars(query)
173+
result = await self._session.scalars(get_list_query, {"dataset_ids": list(dataset_ids)})
149174
return list(result.all())
150175

151176
async def get_stats_by_location_ids(self, location_ids: Collection[int]) -> dict[int, Row]:
152177
if not location_ids:
153178
return {}
154179

155-
query = (
156-
select(
157-
Dataset.location_id.label("location_id"),
158-
func.count(Dataset.id.distinct()).label("total_datasets"),
159-
)
160-
.where(
161-
Dataset.location_id == any_(list(location_ids)), # type: ignore[arg-type]
162-
)
163-
.group_by(Dataset.location_id)
164-
)
165-
166-
query_result = await self._session.execute(query)
180+
query_result = await self._session.execute(get_stats_query, {"location_ids": list(location_ids)})
167181
return {row.location_id: row for row in query_result.all()}
168182

169183
async def _get(self, dataset: DatasetDTO) -> Dataset | None:
170-
statement = select(Dataset).where(
171-
Dataset.location_id == dataset.location.id,
172-
func.lower(Dataset.name) == dataset.name.lower(),
184+
return await self._session.scalar(
185+
get_one_query,
186+
{
187+
"location_id": dataset.location.id,
188+
"name_lower": dataset.name.lower(),
189+
},
173190
)
174-
return await self._session.scalar(statement)
175191

176192
async def _create(self, dataset: DatasetDTO) -> Dataset:
177193
result = Dataset(location_id=dataset.location.id, name=dataset.name)

data_rentgen/db/repositories/dataset_symlink.py

Lines changed: 39 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -3,12 +3,37 @@
33

44
from collections.abc import Collection
55

6-
from sqlalchemy import ARRAY, BindParameter, Integer, any_, bindparam, cast, func, or_, select, tuple_
6+
from sqlalchemy import ARRAY, Integer, any_, bindparam, cast, func, or_, select, tuple_
77

88
from data_rentgen.db.models.dataset_symlink import DatasetSymlink, DatasetSymlinkType
99
from data_rentgen.db.repositories.base import Repository
1010
from data_rentgen.dto import DatasetSymlinkDTO
1111

12+
fetch_bulk_query = select(DatasetSymlink).where(
13+
tuple_(DatasetSymlink.from_dataset_id, DatasetSymlink.to_dataset_id).in_(
14+
select(
15+
func.unnest(
16+
cast(bindparam("from_dataset_ids"), ARRAY(Integer())),
17+
cast(bindparam("to_dataset_ids"), ARRAY(Integer())),
18+
)
19+
.table_valued("from_dataset_ids", "to_dataset_ids")
20+
.render_derived(),
21+
),
22+
),
23+
)
24+
25+
get_list_query = select(DatasetSymlink).where(
26+
or_(
27+
DatasetSymlink.from_dataset_id == any_(bindparam("dataset_ids")),
28+
DatasetSymlink.to_dataset_id == any_(bindparam("dataset_ids")),
29+
),
30+
)
31+
32+
get_one_query = select(DatasetSymlink).where(
33+
DatasetSymlink.from_dataset_id == bindparam("from_dataset_id"),
34+
DatasetSymlink.to_dataset_id == bindparam("to_dataset_id"),
35+
)
36+
1237

1338
class DatasetSymlinkRepository(Repository[DatasetSymlink]):
1439
async def fetch_bulk(
@@ -18,22 +43,13 @@ async def fetch_bulk(
1843
if not dataset_symlinks_dto:
1944
return []
2045

21-
from_dataset_ids = [dataset_symlink_dto.from_dataset.id for dataset_symlink_dto in dataset_symlinks_dto]
22-
to_dataset_ids = [dataset_symlink_dto.to_dataset.id for dataset_symlink_dto in dataset_symlinks_dto]
23-
24-
pairs = (
25-
func.unnest(
26-
cast(from_dataset_ids, ARRAY(Integer())),
27-
cast(to_dataset_ids, ARRAY(Integer())),
28-
)
29-
.table_valued("from_dataset_ids", "to_dataset_ids")
30-
.render_derived()
46+
scalars = await self._session.scalars(
47+
fetch_bulk_query,
48+
{
49+
"from_dataset_ids": [item.from_dataset.id for item in dataset_symlinks_dto],
50+
"to_dataset_ids": [item.to_dataset.id for item in dataset_symlinks_dto],
51+
},
3152
)
32-
33-
statement = select(DatasetSymlink).where(
34-
tuple_(DatasetSymlink.from_dataset_id, DatasetSymlink.to_dataset_id).in_(select(pairs)),
35-
)
36-
scalars = await self._session.scalars(statement)
3753
existing = {(item.from_dataset_id, item.to_dataset_id): item for item in scalars.all()}
3854
return [
3955
(
@@ -52,22 +68,17 @@ async def list_by_dataset_ids(self, dataset_ids: Collection[int]) -> list[Datase
5268
if not dataset_ids:
5369
return []
5470

55-
param: BindParameter[list[int]] = bindparam("dataset_ids")
56-
query = select(DatasetSymlink).where(
57-
or_(
58-
DatasetSymlink.from_dataset_id == any_(param),
59-
DatasetSymlink.to_dataset_id == any_(param),
60-
),
61-
)
62-
scalars = await self._session.scalars(query, {"dataset_ids": list(dataset_ids)})
71+
scalars = await self._session.scalars(get_list_query, {"dataset_ids": list(dataset_ids)})
6372
return list(scalars.all())
6473

6574
async def _get(self, dataset_symlink: DatasetSymlinkDTO) -> DatasetSymlink | None:
66-
query = select(DatasetSymlink).where(
67-
DatasetSymlink.from_dataset_id == dataset_symlink.from_dataset.id,
68-
DatasetSymlink.to_dataset_id == dataset_symlink.to_dataset.id,
75+
return await self._session.scalar(
76+
get_one_query,
77+
{
78+
"from_dataset_id": dataset_symlink.from_dataset.id,
79+
"to_dataset_id": dataset_symlink.to_dataset.id,
80+
},
6981
)
70-
return await self._session.scalar(query)
7182

7283
async def _create(self, dataset_symlink: DatasetSymlinkDTO) -> DatasetSymlink:
7384
result = DatasetSymlink(

data_rentgen/db/repositories/input.py

Lines changed: 12 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,17 @@
1616
extract_timestamp_from_uuid,
1717
)
1818

19+
insert_statement = insert(Input)
20+
inserted_row = insert_statement.excluded
21+
insert_statement = insert_statement.on_conflict_do_update(
22+
index_elements=[Input.created_at, Input.id],
23+
set_={
24+
"num_bytes": func.greatest(inserted_row.num_bytes, Input.num_bytes),
25+
"num_rows": func.greatest(inserted_row.num_rows, Input.num_rows),
26+
"num_files": func.greatest(inserted_row.num_files, Input.num_files),
27+
},
28+
)
29+
1930

2031
@dataclass
2132
class InputRow:
@@ -37,19 +48,8 @@ async def create_or_update_bulk(self, inputs: list[InputDTO]) -> None:
3748
if not inputs:
3849
return
3950

40-
insert_statement = insert(Input)
41-
new_row = insert_statement.excluded
42-
statement = insert_statement.on_conflict_do_update(
43-
index_elements=[Input.created_at, Input.id],
44-
set_={
45-
"num_bytes": func.greatest(new_row.num_bytes, Input.num_bytes),
46-
"num_rows": func.greatest(new_row.num_rows, Input.num_rows),
47-
"num_files": func.greatest(new_row.num_files, Input.num_files),
48-
},
49-
)
50-
5151
await self._session.execute(
52-
statement,
52+
insert_statement,
5353
[
5454
{
5555
"id": item.generate_id(),

0 commit comments

Comments (0)