11# SPDX-FileCopyrightText: 2024-2025 MTS PJSC
22# SPDX-License-Identifier: Apache-2.0
3-
43from collections .abc import Sequence
4+ from dataclasses import dataclass
55from datetime import datetime , timezone
66from typing import Literal
7- from uuid import UUID
87
9- from sqlalchemy import ColumnElement , Row , Select , any_ , func , literal_column , select
8+ from sqlalchemy import ColumnElement , Row , any_ , func , literal_column , select
109from sqlalchemy .dialects .postgresql import insert
10+ from uuid6 import UUID
1111
12- from data_rentgen .db .models import Input
12+ from data_rentgen .db .models import Input , Schema
1313from data_rentgen .db .repositories .base import Repository
1414from data_rentgen .db .utils .uuid import (
1515 extract_timestamp_from_uuid ,
1818from data_rentgen .dto import InputDTO
1919
2020
@dataclass
class InputRow:
    """Lightweight result row for input queries.

    Unlike the ``Input`` ORM model, this can represent *aggregated* inputs
    (per run or per job), where ``operation_id``/``run_id`` may be NULL and
    the byte/row/file counters are sums over the group.
    """

    created_at: datetime
    operation_id: UUID
    run_id: UUID
    job_id: int
    dataset_id: int
    num_bytes: int | None
    num_rows: int | None
    num_files: int | None
    # Schema info is optional: aggregated rows may have no (single) schema.
    schema_id: int | None = None
    # "EXACT_MATCH"  -> every input in the group had this exact schema.
    # "LATEST_KNOWN" -> schemas differed; this is the most recent one.
    schema_relevance_type: Literal["EXACT_MATCH", "LATEST_KNOWN"] | None = None
    # Resolved Schema object, filled in later by callers if needed.
    schema: Schema | None = None
34+
35+
2136class InputRepository (Repository [Input ]):
2237 def get_id (self , input_ : InputDTO ) -> UUID :
2338 # `created_at' field of input should be the same as operation's,
@@ -71,7 +86,7 @@ async def list_by_operation_ids(
7186 self ,
7287 operation_ids : Sequence [UUID ],
7388 granularity : Literal ["JOB" , "RUN" , "OPERATION" ],
74- ) -> list [Input ]:
89+ ) -> list [InputRow ]:
7590 if not operation_ids :
7691 return []
7792
@@ -94,7 +109,7 @@ async def list_by_run_ids(
94109 since : datetime ,
95110 until : datetime | None ,
96111 granularity : Literal ["JOB" , "RUN" , "OPERATION" ],
97- ) -> list [Input ]:
112+ ) -> list [InputRow ]:
98113 if not run_ids :
99114 return []
100115
@@ -116,7 +131,7 @@ async def list_by_job_ids(
116131 since : datetime ,
117132 until : datetime | None ,
118133 granularity : Literal ["JOB" , "RUN" , "OPERATION" ],
119- ) -> list [Input ]:
134+ ) -> list [InputRow ]:
120135 if not job_ids :
121136 return []
122137
@@ -135,7 +150,7 @@ async def list_by_dataset_ids(
135150 since : datetime ,
136151 until : datetime | None ,
137152 granularity : Literal ["JOB" , "RUN" , "OPERATION" ],
138- ) -> list [Input ]:
153+ ) -> list [InputRow ]:
139154 if not dataset_ids :
140155 return []
141156
@@ -152,66 +167,115 @@ async def _get_inputs(
152167 self ,
153168 where : list [ColumnElement ],
154169 granularity : Literal ["JOB" , "RUN" , "OPERATION" ],
155- ) -> list [Input ]:
170+ ) -> list [InputRow ]:
156171 if granularity == "OPERATION" :
157172 # return Input as-is
158173 simple_query = select (Input ).where (* where )
159174 result = await self ._session .scalars (simple_query )
160- return list (result .all ())
175+ return [
176+ InputRow (
177+ created_at = row .created_at ,
178+ operation_id = row .operation_id ,
179+ run_id = row .run_id ,
180+ job_id = row .job_id ,
181+ dataset_id = row .dataset_id ,
182+ num_bytes = row .num_bytes ,
183+ num_rows = row .num_rows ,
184+ num_files = row .num_files ,
185+ schema_id = row .schema_id ,
186+ schema_relevance_type = "EXACT_MATCH" if row .schema_id else None ,
187+ )
188+ for row in result .all ()
189+ ]
161190
162191 # return an aggregated Input
163- query : Select [tuple ]
164192 if granularity == "RUN" :
193+ partition_by = [Input .run_id , Input .job_id , Input .dataset_id ]
194+ base_query = (
195+ select (
196+ Input ,
197+ func .first_value (Input .schema_id )
198+ .over (partition_by = partition_by , order_by = [Input .created_at , Input .schema_id ])
199+ .label ("oldest_schema_id" ),
200+ func .last_value (Input .schema_id )
201+ .over (partition_by = partition_by , order_by = [Input .created_at , Input .schema_id ])
202+ .label ("newest_schema_id" ),
203+ )
204+ .where (* where )
205+ .cte ()
206+ )
165207 query = select (
166- func .max (Input .created_at ).label ("created_at" ),
208+ func .max (base_query . c .created_at ).label ("created_at" ),
167209 literal_column ("NULL" ).label ("operation_id" ),
168- Input .run_id ,
169- Input .job_id ,
170- Input .dataset_id ,
171- func .sum (Input .num_bytes ).label ("sum_num_bytes" ),
172- func .sum (Input .num_rows ).label ("sum_num_rows" ),
173- func .sum (Input .num_files ).label ("sum_num_files" ),
174- func .min (Input . schema_id ).label ("min_schema_id" ),
175- func .max (Input . schema_id ).label ("max_schema_id" ),
210+ base_query . c .run_id ,
211+ base_query . c .job_id ,
212+ base_query . c .dataset_id ,
213+ func .sum (base_query . c .num_bytes ).label ("sum_num_bytes" ),
214+ func .sum (base_query . c .num_rows ).label ("sum_num_rows" ),
215+ func .sum (base_query . c .num_files ).label ("sum_num_files" ),
216+ func .min (base_query . c . oldest_schema_id ).label ("min_schema_id" ),
217+ func .max (base_query . c . newest_schema_id ).label ("max_schema_id" ),
176218 ).group_by (
177- Input .run_id ,
178- Input .job_id ,
179- Input .dataset_id ,
219+ base_query . c .run_id ,
220+ base_query . c .job_id ,
221+ base_query . c .dataset_id ,
180222 )
181223 else :
224+ partition_by = [Input .job_id , Input .dataset_id ]
225+ base_query = (
226+ select (
227+ Input ,
228+ func .first_value (Input .schema_id )
229+ .over (partition_by = partition_by , order_by = [Input .created_at , Input .schema_id ])
230+ .label ("oldest_schema_id" ),
231+ func .last_value (Input .schema_id )
232+ .over (partition_by = partition_by , order_by = [Input .created_at , Input .schema_id ])
233+ .label ("newest_schema_id" ),
234+ )
235+ .where (* where )
236+ .cte ()
237+ )
182238 query = select (
183- func .max (Input .created_at ).label ("created_at" ),
239+ func .max (base_query . c .created_at ).label ("created_at" ),
184240 literal_column ("NULL" ).label ("operation_id" ),
185241 literal_column ("NULL" ).label ("run_id" ),
186- Input .job_id ,
187- Input .dataset_id ,
188- func .sum (Input .num_bytes ).label ("sum_num_bytes" ),
189- func .sum (Input .num_rows ).label ("sum_num_rows" ),
190- func .sum (Input .num_files ).label ("sum_num_files" ),
191- func .min (Input . schema_id ).label ("min_schema_id" ),
192- func .max (Input . schema_id ).label ("max_schema_id" ),
242+ base_query . c .job_id ,
243+ base_query . c .dataset_id ,
244+ func .sum (base_query . c .num_bytes ).label ("sum_num_bytes" ),
245+ func .sum (base_query . c .num_rows ).label ("sum_num_rows" ),
246+ func .sum (base_query . c .num_files ).label ("sum_num_files" ),
247+ func .min (base_query . c . oldest_schema_id ).label ("min_schema_id" ),
248+ func .max (base_query . c . newest_schema_id ).label ("max_schema_id" ),
193249 ).group_by (
194- Input .job_id ,
195- Input .dataset_id ,
250+ base_query . c .job_id ,
251+ base_query . c .dataset_id ,
196252 )
197253
198- query = query .where (* where )
199254 query_result = await self ._session .execute (query )
200- return [
201- Input (
202- created_at = row .created_at ,
203- run_id = row .run_id ,
204- job_id = row .job_id ,
205- dataset_id = row .dataset_id ,
206- num_bytes = row .sum_num_bytes ,
207- num_rows = row .sum_num_rows ,
208- num_files = row .sum_num_files ,
209- # If all outputs within Dataset -> Run|Job have the same schema, save it.
210- # If not, it's impossible to merge.
211- schema_id = row .max_schema_id if row .min_schema_id == row .max_schema_id else None ,
255+
256+ results = []
257+ for row in query_result .all ():
258+ schema_relevance_type : Literal ["EXACT_MATCH" , "LATEST_KNOWN" ] | None
259+ if row .max_schema_id :
260+ schema_relevance_type = "EXACT_MATCH" if row .min_schema_id == row .max_schema_id else "LATEST_KNOWN"
261+ else :
262+ schema_relevance_type = None
263+
264+ results .append (
265+ InputRow (
266+ created_at = row .created_at ,
267+ operation_id = row .operation_id ,
268+ run_id = row .run_id ,
269+ job_id = row .job_id ,
270+ dataset_id = row .dataset_id ,
271+ num_bytes = row .sum_num_bytes ,
272+ num_rows = row .sum_num_rows ,
273+ num_files = row .sum_num_files ,
274+ schema_id = row .max_schema_id ,
275+ schema_relevance_type = schema_relevance_type ,
276+ ),
212277 )
213- for row in query_result .all ()
214- ]
278+ return results
215279
216280 async def get_stats_by_operation_ids (self , operation_ids : Sequence [UUID ]) -> dict [UUID , Row ]:
217281 if not operation_ids :
0 commit comments