Skip to content

Commit 774501a

Browse files
authored
[DOP-23711] Change returning schema in lineage response (#185)
* [DOP-23711] Update schema_id union for outputs and inputs
* [DOP-23711] Update schema_id union for outputs and inputs
* [DOP-23711] change response schema
* [DOP-23711] fix pre-commit
* [DOP-23711] rename field to relevance_type
1 parent c454d17 commit 774501a

File tree

12 files changed

+439
-149
lines changed

12 files changed

+439
-149
lines changed

data_rentgen/db/repositories/input.py

Lines changed: 112 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
# SPDX-FileCopyrightText: 2024-2025 MTS PJSC
22
# SPDX-License-Identifier: Apache-2.0
3-
43
from collections.abc import Sequence
4+
from dataclasses import dataclass
55
from datetime import datetime, timezone
66
from typing import Literal
7-
from uuid import UUID
87

9-
from sqlalchemy import ColumnElement, Row, Select, any_, func, literal_column, select
8+
from sqlalchemy import ColumnElement, Row, any_, func, literal_column, select
109
from sqlalchemy.dialects.postgresql import insert
10+
from uuid6 import UUID
1111

12-
from data_rentgen.db.models import Input
12+
from data_rentgen.db.models import Input, Schema
1313
from data_rentgen.db.repositories.base import Repository
1414
from data_rentgen.db.utils.uuid import (
1515
extract_timestamp_from_uuid,
@@ -18,6 +18,21 @@
1818
from data_rentgen.dto import InputDTO
1919

2020

21+
@dataclass
22+
class InputRow:
23+
created_at: datetime
24+
operation_id: UUID
25+
run_id: UUID
26+
job_id: int
27+
dataset_id: int
28+
num_bytes: int | None
29+
num_rows: int | None
30+
num_files: int | None
31+
schema_id: int | None = None
32+
schema_relevance_type: Literal["EXACT_MATCH", "LATEST_KNOWN"] | None = None
33+
schema: Schema | None = None
34+
35+
2136
class InputRepository(Repository[Input]):
2237
def get_id(self, input_: InputDTO) -> UUID:
2338
# `created_at' field of input should be the same as operation's,
@@ -71,7 +86,7 @@ async def list_by_operation_ids(
7186
self,
7287
operation_ids: Sequence[UUID],
7388
granularity: Literal["JOB", "RUN", "OPERATION"],
74-
) -> list[Input]:
89+
) -> list[InputRow]:
7590
if not operation_ids:
7691
return []
7792

@@ -94,7 +109,7 @@ async def list_by_run_ids(
94109
since: datetime,
95110
until: datetime | None,
96111
granularity: Literal["JOB", "RUN", "OPERATION"],
97-
) -> list[Input]:
112+
) -> list[InputRow]:
98113
if not run_ids:
99114
return []
100115

@@ -116,7 +131,7 @@ async def list_by_job_ids(
116131
since: datetime,
117132
until: datetime | None,
118133
granularity: Literal["JOB", "RUN", "OPERATION"],
119-
) -> list[Input]:
134+
) -> list[InputRow]:
120135
if not job_ids:
121136
return []
122137

@@ -135,7 +150,7 @@ async def list_by_dataset_ids(
135150
since: datetime,
136151
until: datetime | None,
137152
granularity: Literal["JOB", "RUN", "OPERATION"],
138-
) -> list[Input]:
153+
) -> list[InputRow]:
139154
if not dataset_ids:
140155
return []
141156

@@ -152,66 +167,115 @@ async def _get_inputs(
152167
self,
153168
where: list[ColumnElement],
154169
granularity: Literal["JOB", "RUN", "OPERATION"],
155-
) -> list[Input]:
170+
) -> list[InputRow]:
156171
if granularity == "OPERATION":
157172
# return Input as-is
158173
simple_query = select(Input).where(*where)
159174
result = await self._session.scalars(simple_query)
160-
return list(result.all())
175+
return [
176+
InputRow(
177+
created_at=row.created_at,
178+
operation_id=row.operation_id,
179+
run_id=row.run_id,
180+
job_id=row.job_id,
181+
dataset_id=row.dataset_id,
182+
num_bytes=row.num_bytes,
183+
num_rows=row.num_rows,
184+
num_files=row.num_files,
185+
schema_id=row.schema_id,
186+
schema_relevance_type="EXACT_MATCH" if row.schema_id else None,
187+
)
188+
for row in result.all()
189+
]
161190

162191
# return an aggregated Input
163-
query: Select[tuple]
164192
if granularity == "RUN":
193+
partition_by = [Input.run_id, Input.job_id, Input.dataset_id]
194+
base_query = (
195+
select(
196+
Input,
197+
func.first_value(Input.schema_id)
198+
.over(partition_by=partition_by, order_by=[Input.created_at, Input.schema_id])
199+
.label("oldest_schema_id"),
200+
func.last_value(Input.schema_id)
201+
.over(partition_by=partition_by, order_by=[Input.created_at, Input.schema_id])
202+
.label("newest_schema_id"),
203+
)
204+
.where(*where)
205+
.cte()
206+
)
165207
query = select(
166-
func.max(Input.created_at).label("created_at"),
208+
func.max(base_query.c.created_at).label("created_at"),
167209
literal_column("NULL").label("operation_id"),
168-
Input.run_id,
169-
Input.job_id,
170-
Input.dataset_id,
171-
func.sum(Input.num_bytes).label("sum_num_bytes"),
172-
func.sum(Input.num_rows).label("sum_num_rows"),
173-
func.sum(Input.num_files).label("sum_num_files"),
174-
func.min(Input.schema_id).label("min_schema_id"),
175-
func.max(Input.schema_id).label("max_schema_id"),
210+
base_query.c.run_id,
211+
base_query.c.job_id,
212+
base_query.c.dataset_id,
213+
func.sum(base_query.c.num_bytes).label("sum_num_bytes"),
214+
func.sum(base_query.c.num_rows).label("sum_num_rows"),
215+
func.sum(base_query.c.num_files).label("sum_num_files"),
216+
func.min(base_query.c.oldest_schema_id).label("min_schema_id"),
217+
func.max(base_query.c.newest_schema_id).label("max_schema_id"),
176218
).group_by(
177-
Input.run_id,
178-
Input.job_id,
179-
Input.dataset_id,
219+
base_query.c.run_id,
220+
base_query.c.job_id,
221+
base_query.c.dataset_id,
180222
)
181223
else:
224+
partition_by = [Input.job_id, Input.dataset_id]
225+
base_query = (
226+
select(
227+
Input,
228+
func.first_value(Input.schema_id)
229+
.over(partition_by=partition_by, order_by=[Input.created_at, Input.schema_id])
230+
.label("oldest_schema_id"),
231+
func.last_value(Input.schema_id)
232+
.over(partition_by=partition_by, order_by=[Input.created_at, Input.schema_id])
233+
.label("newest_schema_id"),
234+
)
235+
.where(*where)
236+
.cte()
237+
)
182238
query = select(
183-
func.max(Input.created_at).label("created_at"),
239+
func.max(base_query.c.created_at).label("created_at"),
184240
literal_column("NULL").label("operation_id"),
185241
literal_column("NULL").label("run_id"),
186-
Input.job_id,
187-
Input.dataset_id,
188-
func.sum(Input.num_bytes).label("sum_num_bytes"),
189-
func.sum(Input.num_rows).label("sum_num_rows"),
190-
func.sum(Input.num_files).label("sum_num_files"),
191-
func.min(Input.schema_id).label("min_schema_id"),
192-
func.max(Input.schema_id).label("max_schema_id"),
242+
base_query.c.job_id,
243+
base_query.c.dataset_id,
244+
func.sum(base_query.c.num_bytes).label("sum_num_bytes"),
245+
func.sum(base_query.c.num_rows).label("sum_num_rows"),
246+
func.sum(base_query.c.num_files).label("sum_num_files"),
247+
func.min(base_query.c.oldest_schema_id).label("min_schema_id"),
248+
func.max(base_query.c.newest_schema_id).label("max_schema_id"),
193249
).group_by(
194-
Input.job_id,
195-
Input.dataset_id,
250+
base_query.c.job_id,
251+
base_query.c.dataset_id,
196252
)
197253

198-
query = query.where(*where)
199254
query_result = await self._session.execute(query)
200-
return [
201-
Input(
202-
created_at=row.created_at,
203-
run_id=row.run_id,
204-
job_id=row.job_id,
205-
dataset_id=row.dataset_id,
206-
num_bytes=row.sum_num_bytes,
207-
num_rows=row.sum_num_rows,
208-
num_files=row.sum_num_files,
209-
# If all outputs within Dataset -> Run|Job have the same schema, save it.
210-
# If not, it's impossible to merge.
211-
schema_id=row.max_schema_id if row.min_schema_id == row.max_schema_id else None,
255+
256+
results = []
257+
for row in query_result.all():
258+
schema_relevance_type: Literal["EXACT_MATCH", "LATEST_KNOWN"] | None
259+
if row.max_schema_id:
260+
schema_relevance_type = "EXACT_MATCH" if row.min_schema_id == row.max_schema_id else "LATEST_KNOWN"
261+
else:
262+
schema_relevance_type = None
263+
264+
results.append(
265+
InputRow(
266+
created_at=row.created_at,
267+
operation_id=row.operation_id,
268+
run_id=row.run_id,
269+
job_id=row.job_id,
270+
dataset_id=row.dataset_id,
271+
num_bytes=row.sum_num_bytes,
272+
num_rows=row.sum_num_rows,
273+
num_files=row.sum_num_files,
274+
schema_id=row.max_schema_id,
275+
schema_relevance_type=schema_relevance_type,
276+
),
212277
)
213-
for row in query_result.all()
214-
]
278+
return results
215279

216280
async def get_stats_by_operation_ids(self, operation_ids: Sequence[UUID]) -> dict[UUID, Row]:
217281
if not operation_ids:

0 commit comments

Comments
 (0)