55from typing import Literal , Sequence
66from uuid import UUID
77
8- from sqlalchemy import ColumnElement , Select , any_ , func , literal_column , select
8+ from sqlalchemy import ColumnElement , Row , Select , any_ , func , literal_column , select
99from sqlalchemy .dialects .postgresql import insert
1010
1111from data_rentgen .db .models import Input
@@ -162,7 +162,6 @@ async def _get_inputs(
162162 if granularity == "RUN" :
163163 query = select (
164164 func .max (Input .created_at ).label ("created_at" ),
165- literal_column ("NULL" ).label ("id" ),
166165 literal_column ("NULL" ).label ("operation_id" ),
167166 Input .run_id ,
168167 Input .job_id ,
@@ -180,7 +179,6 @@ async def _get_inputs(
180179 else :
181180 query = select (
182181 func .max (Input .created_at ).label ("created_at" ),
183- literal_column ("NULL" ).label ("id" ),
184182 literal_column ("NULL" ).label ("operation_id" ),
185183 literal_column ("NULL" ).label ("run_id" ),
186184 Input .job_id ,
@@ -212,3 +210,32 @@ async def _get_inputs(
212210 )
213211 for row in query_result .all ()
214212 ]
213+
214+ async def get_stats_by_operation_ids (self , operation_ids : Sequence [UUID ]) -> dict [UUID , Row ]:
215+ if not operation_ids :
216+ return {}
217+
218+ # Input created_at is always the same as operation's created_at
219+ # do not use `tuple_(Input.created_at, Input.operation_id).in_(...),
220+ # as this is too complex filter for Postgres to make an optimal query plan
221+ min_created_at = extract_timestamp_from_uuid (min (operation_ids ))
222+ max_created_at = extract_timestamp_from_uuid (max (operation_ids ))
223+
224+ query = (
225+ select (
226+ Input .operation_id .label ("operation_id" ),
227+ func .count (Input .dataset_id .distinct ()).label ("total_datasets" ),
228+ func .sum (Input .num_bytes ).label ("total_bytes" ),
229+ func .sum (Input .num_rows ).label ("total_rows" ),
230+ func .sum (Input .num_files ).label ("total_files" ),
231+ )
232+ .where (
233+ Input .created_at >= min_created_at ,
234+ Input .created_at <= max_created_at ,
235+ Input .operation_id == any_ (operation_ids ), # type: ignore[arg-type]
236+ )
237+ .group_by (Input .operation_id )
238+ )
239+
240+ query_result = await self ._session .execute (query )
241+ return {row .operation_id : row for row in query_result .all ()}
0 commit comments