
Commit f196969

[DOP-28871] Use bulk DB fetch in consumer
1 parent a6bb651 commit f196969

File tree

17 files changed, +488 −289 lines


data_rentgen/consumer/saver.py

Lines changed: 173 additions & 0 deletions
(new file)

# SPDX-FileCopyrightText: 2024-2025 MTS PJSC
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from faststream import Logger
from sqlalchemy.exc import DatabaseError, IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

from data_rentgen.consumer.extractors import BatchExtractionResult
from data_rentgen.services.uow import UnitOfWork


class DatabaseSaver:
    def __init__(
        self,
        session: AsyncSession,
        logger: Logger,
    ) -> None:
        self.unit_of_work = UnitOfWork(session)
        self.logger = logger

    async def save(self, data: BatchExtractionResult):
        self.logger.info("Saving to database")

        await self.create_locations(data)
        await self.create_datasets(data)
        await self.create_dataset_symlinks(data)
        await self.create_job_types(data)
        await self.create_jobs(data)
        await self.create_users(data)
        await self.create_sql_queries(data)
        await self.create_schemas(data)

        try:
            await self.create_runs_bulk(data)
        except DatabaseError:
            await self.create_runs_one_by_one(data)

        await self.create_operations(data)
        await self.create_inputs(data)
        await self.create_outputs(data)
        await self.create_column_lineage(data)

        self.logger.info("Saved successfully")

    async def create_locations(self, data: BatchExtractionResult):
        self.logger.debug("Creating locations")
        # It's hard to fetch locations in bulk, and the number of locations is usually small,
        # so use a row-by-row approach.
        for location_dto in data.locations():
            async with self.unit_of_work:
                location = await self.unit_of_work.location.create_or_update(location_dto)
                location_dto.id = location.id

    # To avoid deadlocks when parallel consumer instances insert/update the same row,
    # commit changes for each row instead of committing the whole batch. Yes, this could be slow.
    # But most entities are unchanged after creation, so we can just fetch them and do nothing.
    async def create_datasets(self, data: BatchExtractionResult):
        self.logger.debug("Creating datasets")
        dataset_pairs = await self.unit_of_work.dataset.get_bulk(data.datasets())
        for dataset_dto, dataset in dataset_pairs:
            if not dataset:
                async with self.unit_of_work:
                    dataset = await self.unit_of_work.dataset.create(dataset_dto)  # noqa: PLW2901
            dataset_dto.id = dataset.id

    async def create_dataset_symlinks(self, data: BatchExtractionResult):
        self.logger.debug("Creating dataset symlinks")
        dataset_symlinks_pairs = await self.unit_of_work.dataset_symlink.fetch_bulk(data.dataset_symlinks())
        for dataset_symlink_dto, dataset_symlink in dataset_symlinks_pairs:
            if not dataset_symlink:
                async with self.unit_of_work:
                    dataset_symlink = await self.unit_of_work.dataset_symlink.create(dataset_symlink_dto)  # noqa: PLW2901
            dataset_symlink_dto.id = dataset_symlink.id

    async def create_job_types(self, data: BatchExtractionResult):
        self.logger.debug("Creating job types")
        job_type_pairs = await self.unit_of_work.job_type.get_bulk(data.job_types())
        for job_type_dto, job_type in job_type_pairs:
            if not job_type:
                async with self.unit_of_work:
                    job_type = await self.unit_of_work.job_type.create(job_type_dto)  # noqa: PLW2901
            job_type_dto.id = job_type.id

    async def create_jobs(self, data: BatchExtractionResult):
        self.logger.debug("Creating jobs")
        job_pairs = await self.unit_of_work.job.get_bulk(data.jobs())
        for job_dto, job in job_pairs:
            async with self.unit_of_work:
                if not job:
                    job = await self.unit_of_work.job.create_or_update(job_dto)  # noqa: PLW2901
                else:
                    job = await self.unit_of_work.job.update(job, job_dto)  # noqa: PLW2901
                job_dto.id = job.id

    async def create_users(self, data: BatchExtractionResult):
        self.logger.debug("Creating users")
        user_pairs = await self.unit_of_work.user.fetch_bulk(data.users())
        for user_dto, user in user_pairs:
            if not user:
                async with self.unit_of_work:
                    user = await self.unit_of_work.user.create(user_dto)  # noqa: PLW2901
            user_dto.id = user.id

    async def create_sql_queries(self, data: BatchExtractionResult):
        self.logger.debug("Creating sql queries")
        sql_query_pairs = await self.unit_of_work.sql_query.fetch_bulk(data.sql_queries())
        for sql_query_dto, sql_query in sql_query_pairs:
            if not sql_query:
                async with self.unit_of_work:
                    sql_query = await self.unit_of_work.sql_query.create(sql_query_dto)  # noqa: PLW2901
            sql_query_dto.id = sql_query.id

    async def create_schemas(self, data: BatchExtractionResult):
        self.logger.debug("Creating schemas")
        schema_pairs = await self.unit_of_work.schema.fetch_bulk(data.schemas())
        for schema_dto, schema in schema_pairs:
            if not schema:
                async with self.unit_of_work:
                    schema = await self.unit_of_work.schema.create(schema_dto)  # noqa: PLW2901
            schema_dto.id = schema.id

    # In most cases, the whole run tree created by some parent is sent to one
    # Kafka partition, and thus handled by just one worker.
    # Cross fingers and create all runs in one transaction.
    async def create_runs_bulk(self, data: BatchExtractionResult):
        self.logger.debug("Creating runs in bulk")
        async with self.unit_of_work:
            await self.unit_of_work.run.create_or_update_bulk(data.runs())

    # In case child and parent runs are in different partitions,
    # multiple workers may try to create/update the same run, leading to a deadlock.
    # Fall back to creating runs one by one.
    async def create_runs_one_by_one(self, data: BatchExtractionResult):
        self.logger.debug("Creating runs one by one")
        run_pairs = await self.unit_of_work.run.fetch_bulk(data.runs())
        for run_dto, run in run_pairs:
            try:
                async with self.unit_of_work:
                    if not run:
                        await self.unit_of_work.run.create(run_dto)
                    else:
                        await self.unit_of_work.run.update(run, run_dto)
            except IntegrityError:  # noqa: PERF203
                async with self.unit_of_work:
                    await self.unit_of_work.run.create_or_update(run_dto)

    # All events related to the same operation are always sent to the same Kafka partition,
    # so other workers never insert/update the same operation in parallel.
    # These rows can be inserted/updated in bulk, in one transaction.
    async def create_operations(self, data: BatchExtractionResult):
        async with self.unit_of_work:
            self.logger.debug("Creating operations")
            await self.unit_of_work.operation.create_or_update_bulk(data.operations())

    async def create_inputs(self, data: BatchExtractionResult):
        async with self.unit_of_work:
            self.logger.debug("Creating inputs")
            await self.unit_of_work.input.create_or_update_bulk(data.inputs())

    async def create_outputs(self, data: BatchExtractionResult):
        async with self.unit_of_work:
            self.logger.debug("Creating outputs")
            await self.unit_of_work.output.create_or_update_bulk(data.outputs())

    async def create_column_lineage(self, data: BatchExtractionResult):
        async with self.unit_of_work:
            self.logger.debug("Creating dataset column relations")
            await self.unit_of_work.dataset_column_relation.create_bulk_for_column_lineage(data.column_lineage())

            self.logger.debug("Creating column lineage")
            await self.unit_of_work.column_lineage.create_bulk(data.column_lineage())
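DatabaseSaver assumes every repository exposes a bulk lookup that pairs each incoming DTO with the already-persisted row (or None if it is missing), so only the missing rows need their own short insert transaction. Of those lookups, only DatasetRepository.get_bulk appears later in this commit; the sketch below illustrates that assumed contract and the fetch-then-create loop with hypothetical names (SupportsBulkFetch, ensure_all), and is not code from the repository.

# Hypothetical sketch of the bulk-lookup contract DatabaseSaver relies on;
# names here are illustrative only.
from __future__ import annotations

from typing import Protocol, TypeVar

DTO = TypeVar("DTO")
Model = TypeVar("Model")


class SupportsBulkFetch(Protocol[DTO, Model]):
    async def fetch_bulk(self, dtos: list[DTO]) -> list[tuple[DTO, Model | None]]:
        """Match all DTOs against existing rows in one SELECT; None means 'not persisted yet'."""
        ...

    async def create(self, dto: DTO) -> Model: ...


async def ensure_all(repo, dtos, unit_of_work):
    # One bulk read up front, then a short write transaction only for the rows that are missing.
    for dto, existing in await repo.fetch_bulk(dtos):
        if existing is None:
            async with unit_of_work:
                existing = await repo.create(dto)
        dto.id = existing.id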

data_rentgen/consumer/subscribers.py

Lines changed: 4 additions & 95 deletions
@@ -14,10 +14,10 @@
 from pydantic import TypeAdapter
 from sqlalchemy.ext.asyncio import AsyncSession
 
-from data_rentgen.consumer.extractors import BatchExtractionResult, BatchExtractor
+from data_rentgen.consumer.extractors import BatchExtractor
+from data_rentgen.consumer.saver import DatabaseSaver
 from data_rentgen.dependencies.stub import Stub
 from data_rentgen.openlineage.run_event import OpenLineageRunEvent
-from data_rentgen.services.uow import UnitOfWork
 
 __all__ = [
     "runs_events_subscriber",
@@ -54,9 +54,8 @@ async def runs_events_subscriber(
     extracted = extractor.result
     logger.info("Got %r", extracted)
 
-    logger.info("Saving to database")
-    await save_to_db(extracted, session, logger)
-    logger.info("Saved successfully")
+    saver = DatabaseSaver(session, logger)
+    await saver.save(extracted)
 
     if malformed:
         logger.warning("Malformed messages: %d", len(malformed))
@@ -88,96 +87,6 @@ async def parse_messages(
         await asyncio.sleep(0)
 
 
-async def save_to_db(
-    data: BatchExtractionResult,
-    session: AsyncSession,
-    logger: Logger,
-) -> None:
-    # To avoid deadlocks when parallel consumer instances insert/update the same row,
-    # commit changes for each row instead of committing the whole batch. Yes, this could be slow.
-
-    unit_of_work = UnitOfWork(session)
-
-    logger.debug("Creating locations")
-    for location_dto in data.locations():
-        async with unit_of_work:
-            location = await unit_of_work.location.create_or_update(location_dto)
-            location_dto.id = location.id
-
-    logger.debug("Creating datasets")
-    for dataset_dto in data.datasets():
-        async with unit_of_work:
-            dataset = await unit_of_work.dataset.get_or_create(dataset_dto)
-            dataset_dto.id = dataset.id
-
-    logger.debug("Creating symlinks")
-    for dataset_symlink_dto in data.dataset_symlinks():
-        async with unit_of_work:
-            dataset_symlink = await unit_of_work.dataset_symlink.get_or_create(dataset_symlink_dto)
-            dataset_symlink_dto.id = dataset_symlink.id
-
-    logger.debug("Creating job types")
-    for job_type_dto in data.job_types():
-        async with unit_of_work:
-            job_type = await unit_of_work.job_type.get_or_create(job_type_dto)
-            job_type_dto.id = job_type.id
-
-    logger.debug("Creating jobs")
-    for job_dto in data.jobs():
-        async with unit_of_work:
-            job = await unit_of_work.job.create_or_update(job_dto)
-            job_dto.id = job.id
-
-    logger.debug("Creating sql queries")
-    for sql_query_dto in data.sql_queries():
-        async with unit_of_work:
-            sql_query = await unit_of_work.sql_query.get_or_create(sql_query_dto)
-            sql_query_dto.id = sql_query.id
-
-    logger.debug("Creating users")
-    for user_dto in data.users():
-        async with unit_of_work:
-            user = await unit_of_work.user.get_or_create(user_dto)
-            user_dto.id = user.id
-
-    logger.debug("Creating schemas")
-    for schema_dto in data.schemas():
-        async with unit_of_work:
-            schema = await unit_of_work.schema.get_or_create(schema_dto)
-            schema_dto.id = schema.id
-
-    # Some events related to a specific run are sent to the same Kafka partition,
-    # but at the same time there is a parent_run which may already be inserted/updated by another worker
-    # (the Kafka key may be different for a run and its parent).
-    # In this case we cannot insert all the rows in one transaction, as it may lead to deadlocks.
-    logger.debug("Creating runs")
-    for run_dto in data.runs():
-        async with unit_of_work:
-            await unit_of_work.run.create_or_update(run_dto)
-
-    # All events related to the same operation are always sent to the same Kafka partition,
-    # so other workers never insert/update the same operation in parallel.
-    # These rows can be inserted/updated in bulk, in one transaction.
-    async with unit_of_work:
-        logger.debug("Creating operations")
-        await unit_of_work.operation.create_or_update_bulk(data.operations())
-
-        logger.debug("Creating inputs")
-        await unit_of_work.input.create_or_update_bulk(data.inputs())
-
-        logger.debug("Creating outputs")
-        await unit_of_work.output.create_or_update_bulk(data.outputs())
-
-    # If something went wrong here, at least we will have inputs/outputs
-    async with unit_of_work:
-        column_lineage = data.column_lineage()
-        logger.debug("Creating dataset column relations")
-        await unit_of_work.dataset_column_relation.create_bulk_for_column_lineage(column_lineage)
-
-        logger.debug("Creating column lineage")
-        await unit_of_work.column_lineage.create_bulk(column_lineage)
-
-
 async def report_malformed(
     messages: list[ConsumerRecord],
     message_id: str,

data_rentgen/db/repositories/column_lineage.py

Lines changed: 13 additions & 2 deletions
@@ -6,7 +6,7 @@
 from typing import NamedTuple
 from uuid import UUID
 
-from sqlalchemy import ColumnElement, any_, func, select, tuple_
+from sqlalchemy import ARRAY, ColumnElement, Integer, any_, cast, func, select, tuple_
 from sqlalchemy.dialects.postgresql import insert
 
 from data_rentgen.db.models import ColumnLineage, DatasetColumnRelation
@@ -123,9 +123,20 @@ async def list_by_dataset_pairs(
         if not dataset_ids_pairs:
             return []
 
+        source_dataset_ids = [pair[0] for pair in dataset_ids_pairs]
+        target_dataset_ids = [pair[1] for pair in dataset_ids_pairs]
+        pairs = (
+            func.unnest(
+                cast(source_dataset_ids, ARRAY(Integer())),
+                cast(target_dataset_ids, ARRAY(Integer())),
+            )
+            .table_valued("source_dataset_id", "target_dataset_id")
+            .render_derived()
+        )
+
         where = [
             ColumnLineage.created_at >= since,
-            tuple_(ColumnLineage.source_dataset_id, ColumnLineage.target_dataset_id).in_(dataset_ids_pairs),
+            tuple_(ColumnLineage.source_dataset_id, ColumnLineage.target_dataset_id).in_(select(pairs)),
         ]
         if until:
            where.append(
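The unnest() construct above trades an IN list with one bound parameter per pair for two array parameters that PostgreSQL expands into a derived table. A minimal standalone sketch of what this builds; the values and the rendered alias name are illustrative only:

# Standalone illustration of the unnest()/table_valued()/render_derived()
# pattern used above; example values, not repository code.
from sqlalchemy import ARRAY, Integer, cast, func, select
from sqlalchemy.dialects import postgresql

source_ids = [1, 2, 3]
target_ids = [10, 20, 30]

pairs = (
    func.unnest(
        cast(source_ids, ARRAY(Integer())),
        cast(target_ids, ARRAY(Integer())),
    )
    .table_valued("source_dataset_id", "target_dataset_id")
    .render_derived()
)

# Renders roughly as:
#   SELECT anon_1.source_dataset_id, anon_1.target_dataset_id
#   FROM unnest(CAST(%(param_1)s AS INTEGER[]), CAST(%(param_2)s AS INTEGER[]))
#        AS anon_1(source_dataset_id, target_dataset_id)
print(select(pairs).compile(dialect=postgresql.dialect()))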

data_rentgen/db/repositories/dataset.py

Lines changed: 34 additions & 9 deletions
@@ -3,17 +3,22 @@
 from collections.abc import Collection
 
 from sqlalchemy import (
+    ARRAY,
     ColumnElement,
     CompoundSelect,
+    Integer,
     Row,
     Select,
     SQLColumnExpression,
+    String,
     any_,
     asc,
+    cast,
     desc,
     distinct,
     func,
     select,
+    tuple_,
     union,
 )
 from sqlalchemy.orm import selectinload
@@ -25,19 +30,39 @@
 
 
 class DatasetRepository(Repository[Dataset]):
-    async def get_or_create(self, dataset: DatasetDTO) -> Dataset:
-        result = await self._get(dataset)
+    async def get_bulk(self, datasets_dto: list[DatasetDTO]) -> list[tuple[DatasetDTO, Dataset | None]]:
+        existing = await self._get_bulk(datasets_dto)
+        result = []
+        for dataset_dto in datasets_dto:
+            key = (dataset_dto.location.id, dataset_dto.name.lower())
+            dataset = existing.get(key)  # type: ignore[arg-type]
+            result.append((dataset_dto, dataset))
+        return result
 
-        if not result:
-            # try one more time, but with the lock acquired.
-            # if another worker already created the same row, just use it; if not, create while holding the lock.
-            await self._lock(dataset.location.id, dataset.name)
-            result = await self._get(dataset)
+    async def _get_bulk(self, datasets_dto: list[DatasetDTO]) -> dict[tuple[int, str], Dataset]:
+        location_ids = [dataset_dto.location.id for dataset_dto in datasets_dto]
+        names = [dataset_dto.name.lower() for dataset_dto in datasets_dto]
+        pairs = (
+            func.unnest(
+                cast(location_ids, ARRAY(Integer())),
+                cast(names, ARRAY(String())),
+            )
+            .table_valued("location_id", "name")
+            .render_derived()
+        )
 
-        if not result:
-            return await self._create(dataset)
+        statement = select(Dataset).where(tuple_(Dataset.location_id, func.lower(Dataset.name)).in_(select(pairs)))
+        scalars = await self._session.scalars(statement)
+        result = {}
+        for dataset in scalars.all():
+            result[(dataset.location_id, dataset.name.lower())] = dataset
         return result
 
+    async def create(self, dataset: DatasetDTO) -> Dataset:
+        # if another worker already created the same row, just use it; if not, create while holding the lock.
+        await self._lock(dataset.location.id, dataset.name.lower())
+        return await self._get(dataset) or await self._create(dataset)
+
     async def paginate(
         self,
         page: int,
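The new create() leans on a _lock() helper whose implementation is not part of this diff; judging by the comment, it serializes workers that try to create the same (location_id, name) pair. A purely hypothetical sketch of such a helper, assuming a transaction-scoped PostgreSQL advisory lock (the actual implementation may differ):

# Hypothetical _lock() implementation: NOT from this commit, just one way to
# serialize workers on the same (location_id, lowercased name) key.
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession


async def lock_dataset_key(session: AsyncSession, location_id: int, name: str) -> None:
    # hashtextextended() folds the composite key into a bigint; the advisory
    # lock is released automatically when the surrounding transaction ends.
    key = func.hashtextextended(f"{location_id}:{name.lower()}", 0)
    await session.execute(select(func.pg_advisory_xact_lock(key)))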
