|
13 | 13 | String, |
14 | 14 | any_, |
15 | 15 | asc, |
| 16 | + bindparam, |
16 | 17 | cast, |
17 | 18 | desc, |
18 | 19 | distinct, |
|
28 | 29 | from data_rentgen.db.utils.search import make_tsquery, ts_match, ts_rank |
29 | 30 | from data_rentgen.dto import DatasetDTO, PaginationDTO |
30 | 31 |
|
| 32 | +fetch_bulk_query = select(Dataset).where( |
| 33 | + tuple_(Dataset.location_id, func.lower(Dataset.name)).in_( |
| 34 | + select( |
| 35 | + func.unnest( |
| 36 | + cast(bindparam("location_ids"), ARRAY(Integer())), |
| 37 | + cast(bindparam("names_lower"), ARRAY(String())), |
| 38 | + ) |
| 39 | + .table_valued("location_id", "name_lower") |
| 40 | + .render_derived(), |
| 41 | + ), |
| 42 | + ), |
| 43 | +) |
| 44 | + |
| 45 | +get_list_query = ( |
| 46 | + select(Dataset) |
| 47 | + .where(Dataset.id == any_(bindparam("dataset_ids"))) |
| 48 | + .options(selectinload(Dataset.location).selectinload(Location.addresses)) |
| 49 | + .options(selectinload(Dataset.tag_values).selectinload(TagValue.tag)) |
| 50 | +) |
| 51 | + |
| 52 | +get_one_query = select(Dataset).where( |
| 53 | + Dataset.location_id == bindparam("location_id"), |
| 54 | + func.lower(Dataset.name) == bindparam("name_lower"), |
| 55 | +) |
| 56 | + |
| 57 | +get_stats_query = ( |
| 58 | + select( |
| 59 | + Dataset.location_id.label("location_id"), |
| 60 | + func.count(Dataset.id.distinct()).label("total_datasets"), |
| 61 | + ) |
| 62 | + .where( |
| 63 | + Dataset.location_id == any_(bindparam("location_ids")), |
| 64 | + ) |
| 65 | + .group_by(Dataset.location_id) |
| 66 | +) |
| 67 | + |
31 | 68 |
|
32 | 69 | class DatasetRepository(Repository[Dataset]): |
33 | 70 | async def fetch_bulk(self, datasets_dto: list[DatasetDTO]) -> list[tuple[DatasetDTO, Dataset | None]]: |
34 | 71 | if not datasets_dto: |
35 | 72 | return [] |
36 | 73 |
|
37 | | - location_ids = [dataset_dto.location.id for dataset_dto in datasets_dto] |
38 | | - names = [dataset_dto.name.lower() for dataset_dto in datasets_dto] |
39 | | - pairs = ( |
40 | | - func.unnest( |
41 | | - cast(location_ids, ARRAY(Integer())), |
42 | | - cast(names, ARRAY(String())), |
43 | | - ) |
44 | | - .table_valued("location_id", "name") |
45 | | - .render_derived() |
| 74 | + scalars = await self._session.scalars( |
| 75 | + fetch_bulk_query, |
| 76 | + { |
| 77 | + "location_ids": [item.location.id for item in datasets_dto], |
| 78 | + "names_lower": [item.name.lower() for item in datasets_dto], |
| 79 | + }, |
46 | 80 | ) |
47 | | - |
48 | | - statement = select(Dataset).where(tuple_(Dataset.location_id, func.lower(Dataset.name)).in_(select(pairs))) |
49 | | - scalars = await self._session.scalars(statement) |
50 | 81 | existing = {(dataset.location_id, dataset.name.lower()): dataset for dataset in scalars.all()} |
51 | 82 | return [ |
52 | 83 | ( |
@@ -139,39 +170,24 @@ async def paginate( |
139 | 170 | async def list_by_ids(self, dataset_ids: Collection[int]) -> list[Dataset]: |
140 | 171 | if not dataset_ids: |
141 | 172 | return [] |
142 | | - query = ( |
143 | | - select(Dataset) |
144 | | - .where(Dataset.id == any_(list(dataset_ids))) # type: ignore[arg-type] |
145 | | - .options(selectinload(Dataset.location).selectinload(Location.addresses)) |
146 | | - .options(selectinload(Dataset.tag_values).selectinload(TagValue.tag)) |
147 | | - ) |
148 | | - result = await self._session.scalars(query) |
| 173 | + result = await self._session.scalars(get_list_query, {"dataset_ids": list(dataset_ids)}) |
149 | 174 | return list(result.all()) |
150 | 175 |
|
151 | 176 | async def get_stats_by_location_ids(self, location_ids: Collection[int]) -> dict[int, Row]: |
152 | 177 | if not location_ids: |
153 | 178 | return {} |
154 | 179 |
|
155 | | - query = ( |
156 | | - select( |
157 | | - Dataset.location_id.label("location_id"), |
158 | | - func.count(Dataset.id.distinct()).label("total_datasets"), |
159 | | - ) |
160 | | - .where( |
161 | | - Dataset.location_id == any_(list(location_ids)), # type: ignore[arg-type] |
162 | | - ) |
163 | | - .group_by(Dataset.location_id) |
164 | | - ) |
165 | | - |
166 | | - query_result = await self._session.execute(query) |
| 180 | + query_result = await self._session.execute(get_stats_query, {"location_ids": list(location_ids)}) |
167 | 181 | return {row.location_id: row for row in query_result.all()} |
168 | 182 |
|
169 | 183 | async def _get(self, dataset: DatasetDTO) -> Dataset | None: |
170 | | - statement = select(Dataset).where( |
171 | | - Dataset.location_id == dataset.location.id, |
172 | | - func.lower(Dataset.name) == dataset.name.lower(), |
| 184 | + return await self._session.scalar( |
| 185 | + get_one_query, |
| 186 | + { |
| 187 | + "location_id": dataset.location.id, |
| 188 | + "name_lower": dataset.name.lower(), |
| 189 | + }, |
173 | 190 | ) |
174 | | - return await self._session.scalar(statement) |
175 | 191 |
|
176 | 192 | async def _create(self, dataset: DatasetDTO) -> Dataset: |
177 | 193 | result = Dataset(location_id=dataset.location.id, name=dataset.name) |
|
0 commit comments