Skip to content

Commit 9091dec

Browse files
committed
[DOP-32998] Extract dataset tags
1 parent 9cc2df4 commit 9091dec

File tree

9 files changed

+156
-23
lines changed

9 files changed

+156
-23
lines changed

data_rentgen/consumer/extractors/batch_extraction_result.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def add_location(self, location: LocationDTO):
119119

120120
def add_dataset(self, dataset: DatasetDTO):
    """Register a dataset DTO, canonicalizing its nested location and tag values first."""
    dataset.location = self.add_location(dataset.location)
    canonical_tags = set()
    for tag_value in dataset.tag_values:
        canonical_tags.add(self.add_tag_value(tag_value))
    dataset.tag_values = canonical_tags
    return self._add(self._datasets, dataset)
123124

124125
def add_dataset_symlink(self, dataset_symlink: DatasetSymlinkDTO):
@@ -207,6 +208,7 @@ def get_tag_value(self, tag_value_key: tuple) -> TagValueDTO:
207208
def get_dataset(self, dataset_key: tuple) -> DatasetDTO:
    """Return the stored dataset for *dataset_key*, with nested location and tag values re-resolved."""
    found = self._datasets[dataset_key]
    found.location = self.get_location(found.location.unique_key)
    resolved_tags = set()
    for tag_value in found.tag_values:
        resolved_tags.add(self.get_tag_value(tag_value.unique_key))
    found.tag_values = resolved_tags
    return found
211213

212214
def get_dataset_symlink(self, dataset_symlink_key: tuple) -> DatasetSymlinkDTO:

data_rentgen/consumer/extractors/generic/dataset.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
DatasetSymlinkDTO,
1010
DatasetSymlinkTypeDTO,
1111
LocationDTO,
12+
TagDTO,
13+
TagValueDTO,
1214
)
1315
from data_rentgen.openlineage.dataset import (
1416
OpenLineageDataset,
@@ -28,7 +30,8 @@ def extract_dataset(self, dataset: OpenLineageDataset) -> DatasetDTO:
2830
"""
2931
Extract DatasetDTO from input or output OpenLineageDataset
3032
"""
31-
return self._extract_dataset_ref(dataset)
33+
dataset_dto = self._extract_dataset_ref(dataset)
34+
return self._enrich_dataset_tags(dataset_dto, dataset)
3235

3336
def _extract_dataset_ref(
3437
self,
@@ -108,3 +111,15 @@ def _connect_dataset_with_symlinks(
108111
)
109112

110113
return sorted(result, key=lambda x: x.type)
114+
115+
def _enrich_dataset_tags(self, dataset_dto: DatasetDTO, dataset: OpenLineageDataset) -> DatasetDTO:
    """Copy tags from the OpenLineage tags facet (when present) onto the dataset DTO."""
    tags_facet = dataset.facets.tags
    if not tags_facet:
        return dataset_dto

    dataset_dto.tag_values.update(
        TagValueDTO(
            # normalize tag names: lowercase, spaces replaced with underscores
            tag=TagDTO(name=raw_tag.key.lower().replace(" ", "_")),
            value=raw_tag.value,
        )
        for raw_tag in tags_facet.tags
    )
    return dataset_dto

data_rentgen/consumer/saver.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,11 @@ async def create_datasets(self, data: BatchExtractionResult):
6262
self.logger.debug("Creating datasets")
6363
dataset_pairs = await self.unit_of_work.dataset.fetch_bulk(data.datasets())
6464
for dataset_dto, dataset in dataset_pairs:
65-
if not dataset:
66-
async with self.unit_of_work:
67-
dataset = await self.unit_of_work.dataset.create(dataset_dto) # noqa: PLW2901
65+
async with self.unit_of_work:
66+
if not dataset:
67+
dataset = await self.unit_of_work.dataset.create_or_update(dataset_dto) # noqa: PLW2901
68+
else:
69+
dataset = await self.unit_of_work.dataset.update(dataset, dataset_dto) # noqa: PLW2901
6870
dataset_dto.id = dataset.id
6971

7072
async def create_dataset_symlinks(self, data: BatchExtractionResult):

data_rentgen/db/repositories/dataset.py

Lines changed: 58 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,11 @@
2222
tuple_,
2323
union,
2424
)
25+
from sqlalchemy.dialects.postgresql import insert
2526
from sqlalchemy.orm import selectinload
2627

2728
from data_rentgen.db.models import Address, Dataset, Location, TagValue
29+
from data_rentgen.db.models.dataset import DatasetTagValue
2830
from data_rentgen.db.repositories.base import Repository
2931
from data_rentgen.db.utils.search import make_tsquery, ts_match, ts_rank
3032
from data_rentgen.dto import DatasetDTO, PaginationDTO
@@ -65,6 +67,17 @@
6567
.group_by(Dataset.location_id)
6668
)
6769

70+
# Link a dataset to a tag value; duplicate (dataset_id, tag_value_id) pairs
# are silently skipped via ON CONFLICT DO NOTHING, making the insert idempotent.
insert_tag_value_query = (
    insert(DatasetTagValue)
    .values(dataset_id=bindparam("dataset_id"), tag_value_id=bindparam("tag_value_id"))
    .on_conflict_do_nothing(index_elements=["dataset_id", "tag_value_id"])
)
80+
6881

6982
class DatasetRepository(Repository[Dataset]):
7083
async def fetch_bulk(self, datasets_dto: list[DatasetDTO]) -> list[tuple[DatasetDTO, Dataset | None]]:
@@ -87,10 +100,51 @@ async def fetch_bulk(self, datasets_dto: list[DatasetDTO]) -> list[tuple[Dataset
87100
for dto in datasets_dto
88101
]
89102

90-
async def create(self, dataset: DatasetDTO) -> Dataset:
91-
# if another worker already created the same row, just use it. if not - create with holding the lock.
92-
await self._lock(dataset.location.id, dataset.name.lower())
93-
return await self._get(dataset) or await self._create(dataset)
103+
async def create_or_update(self, dataset: DatasetDTO) -> Dataset:
    """Return the existing dataset row or create it, then apply updates (tag values)."""
    existing = await self._get(dataset)
    if existing:
        return await self.update(existing, dataset)

    # Not found: retry the lookup while holding a lock, so that a row created
    # concurrently by another worker is reused instead of duplicated; only
    # create the row if it is still missing under the lock.
    await self._lock(dataset.location.id, dataset.name.lower())
    existing = await self._get(dataset) or await self._create(dataset)
    return await self.update(existing, dataset)
114+
115+
async def _get(self, dataset: DatasetDTO) -> Dataset | None:
    """Fetch the dataset row matching the DTO's location and case-insensitive name, or None."""
    params = {
        "location_id": dataset.location.id,
        "name_lower": dataset.name.lower(),
    }
    return await self._session.scalar(get_one_query, params)
123+
124+
async def _create(self, dataset: DatasetDTO) -> Dataset:
    """Insert a new dataset row and flush it so the generated id is available."""
    row = Dataset(location_id=dataset.location.id, name=dataset.name)
    self._session.add(row)
    await self._session.flush([row])
    return row
129+
130+
async def update(self, existing: Dataset, new: DatasetDTO) -> Dataset:
    """Attach the tag values carried by *new* to the existing dataset row.

    Returns *existing* unchanged when the DTO has no tag values — the common
    case — so no INSERT statements are issued.
    """
    if not new.tag_values:
        # in most cases datasets have no tag values, so we can avoid INSERT statements
        return existing

    # Lock to prevent inserting the same rows from multiple workers.
    # Use the lowercased name so the lock key matches the one taken in
    # create_or_update(); otherwise the two code paths would lock on
    # different keys for mixed-case dataset names. Duplicate pairs are
    # still tolerated by the ON CONFLICT DO NOTHING in insert_tag_value_query.
    await self._lock(existing.location_id, existing.name.lower())
    await self._session.execute(
        insert_tag_value_query,
        [
            {
                "dataset_id": existing.id,
                "tag_value_id": tag_value_dto.id,
            }
            for tag_value_dto in new.tag_values
        ],
    )
    return existing
94148

95149
async def paginate(
96150
self,
@@ -184,18 +238,3 @@ async def get_stats_by_location_ids(self, location_ids: Collection[int]) -> dict
184238

185239
query_result = await self._session.execute(get_stats_query, {"location_ids": list(location_ids)})
186240
return {row.location_id: row for row in query_result.all()}
187-
188-
async def _get(self, dataset: DatasetDTO) -> Dataset | None:
189-
return await self._session.scalar(
190-
get_one_query,
191-
{
192-
"location_id": dataset.location.id,
193-
"name_lower": dataset.name.lower(),
194-
},
195-
)
196-
197-
async def _create(self, dataset: DatasetDTO) -> Dataset:
198-
result = Dataset(location_id=dataset.location.id, name=dataset.name)
199-
self._session.add(result)
200-
await self._session.flush([result])
201-
return result

data_rentgen/dto/dataset.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66
from dataclasses import dataclass, field
77

88
from data_rentgen.dto.location import LocationDTO
9+
from data_rentgen.dto.tag import TagValueDTO
910

1011

1112
@dataclass(slots=True)
1213
class DatasetDTO:
1314
location: LocationDTO
1415
name: str
16+
tag_values: set[TagValueDTO] = field(default_factory=set)
1517
id: int | None = field(default=None, compare=False)
1618

1719
@property
@@ -21,4 +23,5 @@ def unique_key(self) -> tuple:
2123
def merge(self, new: DatasetDTO) -> DatasetDTO:
    """Fold *new* into this DTO in place and return self."""
    self.location.merge(new.location)
    # a truthy id from the newer DTO wins; otherwise keep the current one
    self.id = new.id or self.id
    self.tag_values |= new.tag_values
    return self

data_rentgen/openlineage/dataset_facets/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@
1313
OpenLineageColumnLineageDatasetFacetFieldRef,
1414
OpenLineageColumnLineageDatasetFacetFieldTransformation,
1515
)
16+
from data_rentgen.openlineage.dataset_facets.dataset_tags import (
17+
OpenLineageDatasetTagsFacet,
18+
OpenLineageDatasetTagsFacetField,
19+
)
1620
from data_rentgen.openlineage.dataset_facets.documentation import (
1721
OpenLineageDocumentationDatasetFacet,
1822
)
@@ -46,6 +50,8 @@
4650
"OpenLineageDatasetFacets",
4751
"OpenLineageDatasetLifecycleStateChange",
4852
"OpenLineageDatasetPreviousIdentifier",
53+
"OpenLineageDatasetTagsFacet",
54+
"OpenLineageDatasetTagsFacetField",
4955
"OpenLineageDocumentationDatasetFacet",
5056
"OpenLineageInputDatasetFacets",
5157
"OpenLineageInputStatisticsInputDatasetFacet",
@@ -70,6 +76,7 @@ class OpenLineageDatasetFacets(OpenLineageBase):
7076
datasetSchema: OpenLineageSchemaDatasetFacet | None = Field(default=None, alias="schema")
7177
symlinks: OpenLineageSymlinksDatasetFacet | None = None
7278
columnLineage: OpenLineageColumnLineageDatasetFacet | None = None
79+
tags: OpenLineageDatasetTagsFacet | None = None
7380

7481

7582
class OpenLineageInputDatasetFacets(OpenLineageBase):
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# SPDX-FileCopyrightText: 2024-present MTS PJSC
# SPDX-License-Identifier: Apache-2.0

from pydantic import BaseModel, Field

from data_rentgen.openlineage.dataset_facets.base import OpenLineageDatasetFacet


class OpenLineageDatasetTagsFacetField(BaseModel):
    """Dataset tags field type.
    See [TagsDatasetFacet](https://github.com/OpenLineage/OpenLineage/blob/main/spec/facets/TagsDatasetFacet.json).
    """

    key: str = Field(description="Key that identifies the tag")
    value: str = Field(description="The value of the field")
    source: str | None = Field(default=None, description="The source of the tag. INTEGRATION|USER|DBT CORE|SPARK|etc.")
    field: str | None = Field(default=None, description="Identifies the field in a dataset if a tag applies to one")


class OpenLineageDatasetTagsFacet(OpenLineageDatasetFacet):
    """Dataset facet describing dataset tags.
    See [TagsDatasetFacet](https://github.com/OpenLineage/OpenLineage/blob/main/spec/facets/TagsDatasetFacet.json).
    """

    tags: list[OpenLineageDatasetTagsFacetField] = Field(
        default_factory=list,
        description="The tags applied to the dataset facet",
    )
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Extract dataset tags provided by OpenLineage integrations.

tests/test_consumer/test_extractors/test_extractors_dataset.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,16 @@
55
DatasetSymlinkDTO,
66
DatasetSymlinkTypeDTO,
77
LocationDTO,
8+
TagDTO,
9+
TagValueDTO,
810
)
911
from data_rentgen.openlineage.dataset import (
1012
OpenLineageDataset,
1113
)
1214
from data_rentgen.openlineage.dataset_facets import (
1315
OpenLineageDatasetFacets,
16+
OpenLineageDatasetTagsFacet,
17+
OpenLineageDatasetTagsFacetField,
1418
OpenLineageSymlinkIdentifier,
1519
OpenLineageSymlinksDatasetFacet,
1620
OpenLineageSymlinkType,
@@ -293,3 +297,35 @@ def test_extractors_extract_dataset_unknown():
293297
name="some.name",
294298
)
295299
assert symlinks_dto == []
300+
301+
302+
def test_extractors_extract_dataset_with_tags():
    """Tags from the OpenLineage tags facet become TagValueDTOs; the facet's
    `source` and `field` attributes are not carried over into the DTOs."""
    ol_dataset = OpenLineageDataset(
        namespace="postgres://192.168.1.1:5432",
        name="mydb.myschema.mytable",
        facets=OpenLineageDatasetFacets(
            tags=OpenLineageDatasetTagsFacet(
                tags=[
                    OpenLineageDatasetTagsFacetField(key="somekey", value="somevalue"),
                    OpenLineageDatasetTagsFacetField(key="somekey", value="othervalue", source="OTHER"),
                    OpenLineageDatasetTagsFacetField(key="anotherkey", value="anothervalue", source="ABC", field="abc"),
                ],
            ),
        ),
    )

    dataset_dto, symlinks_dto = GenericExtractor().extract_dataset_and_symlinks(ol_dataset)

    expected_location = LocationDTO(
        type="postgres",
        name="192.168.1.1:5432",
        addresses={"postgres://192.168.1.1:5432"},
    )
    expected_tag_values = {
        TagValueDTO(tag=TagDTO(name="somekey"), value="somevalue"),
        TagValueDTO(tag=TagDTO(name="somekey"), value="othervalue"),
        TagValueDTO(tag=TagDTO(name="anotherkey"), value="anothervalue"),
    }
    assert dataset_dto == DatasetDTO(
        location=expected_location,
        name="mydb.myschema.mytable",
        tag_values=expected_tag_values,
    )
    assert symlinks_dto == []

0 commit comments

Comments
 (0)