@@ -137,10 +137,6 @@ def get(self):
137137 return self .session ()
138138
139139
140- def in_ (column : Column , values ):
141- return or_ (* [column == value for value in values ])
142-
143-
144140class SqlAlchemyDatasetRepository (DatasetRepository ):
145141 def __init__ (self , session_provider : SqlAlchemySessionProvider ):
146142 self .session_provider = session_provider
@@ -194,7 +190,19 @@ def _filter_query(
194190 # return an empty DatasetCollection
195191 return DatasetCollection ()
196192
197- query = query .filter (in_ (dataset_table .c .dataset_id , dataset_id ))
193+ dataset_ids_cte = union_all (
194+ * [
195+ select (literal (dataset_id ).label ("dataset_id" ))
196+ for dataset_id in set (dataset_id )
197+ ]
198+ ).cte ("dataset_ids" )
199+
200+ query = query .select_from (
201+ dataset_table .join (
202+ dataset_ids_cte ,
203+ dataset_ids_cte .c .dataset_id == dataset_table .c .dataset_id ,
204+ )
205+ )
198206 else :
199207 query = query .filter (dataset_table .c .dataset_id == dataset_id )
200208
@@ -265,15 +273,30 @@ def _load_datasets(self, dataset_ids: list[str]) -> list[Dataset]:
265273 if not dataset_ids :
266274 return []
267275
276+ dataset_ids_cte = union_all (
277+ * [
278+ select (literal (dataset_id ).label ("dataset_id" ))
279+ for dataset_id in set (dataset_ids )
280+ ]
281+ ).cte ("dataset_ids" )
282+
268283 dataset_rows = list (
269- self .session .query (dataset_table ).filter (
270- in_ (dataset_table .c .dataset_id , dataset_ids )
284+ self .session .query (dataset_table ).select_from (
285+ dataset_table .join (
286+ dataset_ids_cte ,
287+ dataset_ids_cte .c .dataset_id == dataset_table .c .dataset_id ,
288+ )
271289 )
272290 )
273291 revisions_per_dataset = {}
274292 rows = (
275293 self .session .query (revision_table )
276- .filter (in_ (revision_table .c .dataset_id , dataset_ids ))
294+ .select_from (
295+ revision_table .join (
296+ dataset_ids_cte ,
297+ dataset_ids_cte .c .dataset_id == revision_table .c .dataset_id ,
298+ )
299+ )
277300 .order_by (revision_table .c .dataset_id )
278301 )
279302
@@ -285,7 +308,12 @@ def _load_datasets(self, dataset_ids: list[str]) -> list[Dataset]:
285308 files_per_revision = {}
286309 rows = (
287310 self .session .query (file_table )
288- .filter (in_ (file_table .c .dataset_id , dataset_ids ))
311+ .select_from (
312+ file_table .join (
313+ dataset_ids_cte ,
314+ dataset_ids_cte .c .dataset_id == file_table .c .dataset_id ,
315+ )
316+ )
289317 .order_by (file_table .c .dataset_id , file_table .c .revision_id )
290318 )
291319
0 commit comments