
Commit 0c2b58c

[DOP-28871] Use bulk DB fetch in consumer
1 parent 194a730 commit 0c2b58c

15 files changed: +462 −291 lines changed


data_rentgen/consumer/saver.py

Lines changed: 179 additions & 0 deletions
@@ -0,0 +1,179 @@
# SPDX-FileCopyrightText: 2024-2025 MTS PJSC
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from faststream import Logger
from sqlalchemy.exc import DatabaseError, IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

from data_rentgen.consumer.extractors import BatchExtractionResult
from data_rentgen.services.uow import UnitOfWork


class DatabaseSaver:
    def __init__(
        self,
        session: AsyncSession,
        logger: Logger,
    ) -> None:
        self.unit_of_work = UnitOfWork(session)
        self.logger = logger

    async def save(self, data: BatchExtractionResult):
        self.logger.info("Saving to database")

        await self.create_locations(data)
        await self.create_datasets(data)
        await self.create_dataset_symlinks(data)
        await self.create_job_types(data)
        await self.create_jobs(data)
        await self.create_users(data)
        await self.create_sql_queries(data)
        await self.create_schemas(data)

        try:
            await self.create_runs_bulk(data)
        except DatabaseError:
            await self.create_runs_one_by_one(data)

        await self.create_operations(data)
        await self.create_inputs(data)
        await self.create_outputs(data)
        await self.create_column_lineage(data)

        self.logger.info("Saved successfully")

    async def create_locations(self, data: BatchExtractionResult):
        self.logger.debug("Creating locations")
        # It's hard to fetch locations in bulk, and the number of locations is usually small,
        # so use a row-by-row approach.
        for location_dto in data.locations():
            async with self.unit_of_work:
                location = await self.unit_of_work.location.create_or_update(location_dto)
                location_dto.id = location.id

    # To avoid deadlocks when parallel consumer instances insert/update the same row,
    # commit changes for each row instead of committing the whole batch. Yes, this could be slow.
    # But most entities are unchanged after creation, so we can just fetch them and do nothing.
    async def create_datasets(self, data: BatchExtractionResult):
        self.logger.debug("Creating datasets")
        dataset_pairs = await self.unit_of_work.dataset.get_bulk(data.datasets())
        for dataset_dto, dataset in dataset_pairs:
            if not dataset:
                async with self.unit_of_work:
                    dataset = await self.unit_of_work.dataset.create(dataset_dto)  # noqa: PLW2901
            dataset_dto.id = dataset.id

    async def create_dataset_symlinks(self, data: BatchExtractionResult):
        self.logger.debug("Creating dataset symlinks")
        dataset_symlinks_pairs = await self.unit_of_work.dataset_symlink.fetch_bulk(data.dataset_symlinks())
        for dataset_symlink_dto, dataset_symlink in dataset_symlinks_pairs:
            if not dataset_symlink:
                async with self.unit_of_work:
                    dataset_symlink = await self.unit_of_work.dataset_symlink.create(dataset_symlink_dto)  # noqa: PLW2901
            dataset_symlink_dto.id = dataset_symlink.id

    async def create_job_types(self, data: BatchExtractionResult):
        self.logger.debug("Creating job types")
        job_type_pairs = await self.unit_of_work.job_type.get_bulk(data.job_types())
        for job_type_dto, job_type in job_type_pairs:
            if not job_type:
                async with self.unit_of_work:
                    job_type = await self.unit_of_work.job_type.create(job_type_dto)  # noqa: PLW2901
            job_type_dto.id = job_type.id

    async def create_jobs(self, data: BatchExtractionResult):
        self.logger.debug("Creating jobs")
        job_pairs = await self.unit_of_work.job.get_bulk(data.jobs())
        for job_dto, job in job_pairs:
            async with self.unit_of_work:
                if not job:
                    job = await self.unit_of_work.job.create_or_update(job_dto)  # noqa: PLW2901
                else:
                    job = await self.unit_of_work.job.update(job, job_dto)  # noqa: PLW2901
            job_dto.id = job.id

    async def create_users(self, data: BatchExtractionResult):
        self.logger.debug("Creating users")
        user_pairs = await self.unit_of_work.user.fetch_bulk(data.users())
        for user_dto, user in user_pairs:
            if not user:
                async with self.unit_of_work:
                    user = await self.unit_of_work.user.create(user_dto)  # noqa: PLW2901
            user_dto.id = user.id

    async def create_sql_queries(self, data: BatchExtractionResult):
        self.logger.debug("Creating sql queries")
        sql_query_ids = await self.unit_of_work.sql_query.fetch_known_ids(data.sql_queries())
        for sql_query_dto, sql_query_id in sql_query_ids:
            if not sql_query_id:
                async with self.unit_of_work:
                    sql_query = await self.unit_of_work.sql_query.create(sql_query_dto)
                sql_query_dto.id = sql_query.id
            else:
                sql_query_dto.id = sql_query_id

    async def create_schemas(self, data: BatchExtractionResult):
        self.logger.debug("Creating schemas")
        schema_ids = await self.unit_of_work.schema.fetch_known_ids(data.schemas())
        for schema_dto, schema_id in schema_ids:
            if not schema_id:
                async with self.unit_of_work:
                    schema = await self.unit_of_work.schema.create(schema_dto)
                schema_dto.id = schema.id
            else:
                schema_dto.id = schema_id

    # In most cases, the whole run tree created by some parent is sent to one
    # Kafka partition, and thus handled by just one worker.
    # Cross fingers and create all runs in one transaction.
    async def create_runs_bulk(self, data: BatchExtractionResult):
        self.logger.debug("Creating runs in bulk")
        async with self.unit_of_work:
            await self.unit_of_work.run.create_or_update_bulk(data.runs())

    # In case child and parent runs are in different partitions,
    # multiple workers may try to create/update the same run, leading to a deadlock.
    # Fall back to creating runs one by one.
    async def create_runs_one_by_one(self, data: BatchExtractionResult):
        self.logger.debug("Creating runs one by one")
        run_pairs = await self.unit_of_work.run.fetch_bulk(data.runs())
        for run_dto, run in run_pairs:
            try:
                async with self.unit_of_work:
                    if not run:
                        await self.unit_of_work.run.create(run_dto)
                    else:
                        await self.unit_of_work.run.update(run, run_dto)
            except IntegrityError:  # noqa: PERF203
                # A deadlock occurred, so the states in DB and RAM are out of sync,
                # and we have to fetch the run from the DB again.
                async with self.unit_of_work:
                    await self.unit_of_work.run.create_or_update(run_dto)

    # All events related to the same operation are always sent to the same Kafka partition,
    # so other workers never insert/update the same operation in parallel.
    # These rows can be inserted/updated in bulk, in one transaction.
    async def create_operations(self, data: BatchExtractionResult):
        async with self.unit_of_work:
            self.logger.debug("Creating operations")
            await self.unit_of_work.operation.create_or_update_bulk(data.operations())

    async def create_inputs(self, data: BatchExtractionResult):
        async with self.unit_of_work:
            self.logger.debug("Creating inputs")
            await self.unit_of_work.input.create_or_update_bulk(data.inputs())

    async def create_outputs(self, data: BatchExtractionResult):
        async with self.unit_of_work:
            self.logger.debug("Creating outputs")
            await self.unit_of_work.output.create_or_update_bulk(data.outputs())

    async def create_column_lineage(self, data: BatchExtractionResult):
        async with self.unit_of_work:
            self.logger.debug("Creating dataset column relations")
            await self.unit_of_work.dataset_column_relation.create_bulk_for_column_lineage(data.column_lineage())

            self.logger.debug("Creating column lineage")
            await self.unit_of_work.column_lineage.create_bulk(data.column_lineage())
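
The saver relies on new repository helpers (get_bulk, fetch_bulk, fetch_known_ids) that live in the other files of this commit and are not shown on this page. The snippet below is a rough, hypothetical sketch of the contract they appear to follow: one SELECT for the whole batch, returning (dto, existing_row_or_None) pairs so the caller only opens short per-row transactions for entities that are actually missing. The (location_id, name) natural key and the Dataset model import are assumptions, not taken from this diff.

# Hypothetical sketch only; the real implementations live elsewhere in this commit
# and may use different keys, models, and return types.
from sqlalchemy import select, tuple_
from sqlalchemy.ext.asyncio import AsyncSession

from data_rentgen.db.models import Dataset  # assumed import path


class DatasetRepositorySketch:
    def __init__(self, session: AsyncSession) -> None:
        self._session = session

    async def get_bulk(self, dtos):
        # Fetch every already-existing row in a single query instead of one query per DTO.
        if not dtos:
            return []
        keys = [(dto.location_id, dto.name) for dto in dtos]  # assumed natural key
        query = select(Dataset).where(tuple_(Dataset.location_id, Dataset.name).in_(keys))
        existing = (await self._session.scalars(query)).all()
        by_key = {(row.location_id, row.name): row for row in existing}
        # Pair each DTO with its existing row (or None), so DatabaseSaver.create_datasets
        # can create only the missing ones inside short per-row transactions.
        return [(dto, by_key.get((dto.location_id, dto.name))) for dto in dtos]

Compared to the previous get_or_create call per row, this turns N round trips into one SELECT plus writes only for genuinely new rows, which is the "bulk DB fetch" the commit title refers to.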

data_rentgen/consumer/subscribers.py

Lines changed: 4 additions & 95 deletions
@@ -14,10 +14,10 @@
 from pydantic import TypeAdapter
 from sqlalchemy.ext.asyncio import AsyncSession

-from data_rentgen.consumer.extractors import BatchExtractionResult, BatchExtractor
+from data_rentgen.consumer.extractors import BatchExtractor
+from data_rentgen.consumer.saver import DatabaseSaver
 from data_rentgen.dependencies.stub import Stub
 from data_rentgen.openlineage.run_event import OpenLineageRunEvent
-from data_rentgen.services.uow import UnitOfWork

 __all__ = [
     "runs_events_subscriber",
@@ -54,9 +54,8 @@ async def runs_events_subscriber(
     extracted = extractor.result
     logger.info("Got %r", extracted)

-    logger.info("Saving to database")
-    await save_to_db(extracted, session, logger)
-    logger.info("Saved successfully")
+    saver = DatabaseSaver(session, logger)
+    await saver.save(extracted)

     if malformed:
         logger.warning("Malformed messages: %d", len(malformed))
@@ -88,96 +87,6 @@ async def parse_messages(
         await asyncio.sleep(0)


-async def save_to_db(
-    data: BatchExtractionResult,
-    session: AsyncSession,
-    logger: Logger,
-) -> None:
-    # To avoid deadlocks when parallel consumer instances insert/update the same row,
-    # commit changes for each row instead of committing the whole batch. Yes, this cloud be slow.
-
-    unit_of_work = UnitOfWork(session)
-
-    logger.debug("Creating locations")
-    for location_dto in data.locations():
-        async with unit_of_work:
-            location = await unit_of_work.location.create_or_update(location_dto)
-            location_dto.id = location.id
-
-    logger.debug("Creating datasets")
-    for dataset_dto in data.datasets():
-        async with unit_of_work:
-            dataset = await unit_of_work.dataset.get_or_create(dataset_dto)
-            dataset_dto.id = dataset.id
-
-    logger.debug("Creating symlinks")
-    for dataset_symlink_dto in data.dataset_symlinks():
-        async with unit_of_work:
-            dataset_symlink = await unit_of_work.dataset_symlink.get_or_create(dataset_symlink_dto)
-            dataset_symlink_dto.id = dataset_symlink.id
-
-    logger.debug("Creating job types")
-    for job_type_dto in data.job_types():
-        async with unit_of_work:
-            job_type = await unit_of_work.job_type.get_or_create(job_type_dto)
-            job_type_dto.id = job_type.id
-
-    logger.debug("Creating jobs")
-    for job_dto in data.jobs():
-        async with unit_of_work:
-            job = await unit_of_work.job.create_or_update(job_dto)
-            job_dto.id = job.id
-
-    logger.debug("Creating sql queries")
-    for sql_query_dto in data.sql_queries():
-        async with unit_of_work:
-            sql_query = await unit_of_work.sql_query.get_or_create(sql_query_dto)
-            sql_query_dto.id = sql_query.id
-
-    logger.debug("Creating users")
-    for user_dto in data.users():
-        async with unit_of_work:
-            user = await unit_of_work.user.get_or_create(user_dto)
-            user_dto.id = user.id
-
-    logger.debug("Creating schemas")
-    for schema_dto in data.schemas():
-        async with unit_of_work:
-            schema = await unit_of_work.schema.get_or_create(schema_dto)
-            schema_dto.id = schema.id
-
-    # Some events related to specific run are send to the same Kafka partition,
-    # but at the same time we have parent_run which may be already inserted/updated by other worker
-    # (Kafka key maybe different for run and it's parent).
-    # In this case we cannot insert all the rows in one transaction, as it may lead to deadlocks.
-    logger.debug("Creating runs")
-    for run_dto in data.runs():
-        async with unit_of_work:
-            await unit_of_work.run.create_or_update(run_dto)
-
-    # All events related to same operation are always send to the same Kafka partition,
-    # so other workers never insert/update the same operation in parallel.
-    # These rows can be inserted/updated in bulk, in one transaction.
-    async with unit_of_work:
-        logger.debug("Creating operations")
-        await unit_of_work.operation.create_or_update_bulk(data.operations())
-
-        logger.debug("Creating inputs")
-        await unit_of_work.input.create_or_update_bulk(data.inputs())
-
-        logger.debug("Creating outputs")
-        await unit_of_work.output.create_or_update_bulk(data.outputs())
-
-    # If something went wrong here, at least we will have inputs/outputs
-    async with unit_of_work:
-        column_lineage = data.column_lineage()
-        logger.debug("Creating dataset column relations")
-        await unit_of_work.dataset_column_relation.create_bulk_for_column_lineage(column_lineage)
-
-        logger.debug("Creating column lineage")
-        await unit_of_work.column_lineage.create_bulk(column_lineage)
-
-
 async def report_malformed(
     messages: list[ConsumerRecord],
     message_id: str,
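
Since the subscriber now only constructs DatabaseSaver from the injected session and logger, the saver can also be driven outside FastStream, for example from a test or a backfill script. A minimal sketch, assuming SQLAlchemy 2.0-style async_sessionmaker, an illustrative asyncpg DSN, and a pre-built BatchExtractionResult (none of these come from this commit):

# Minimal sketch: reusing DatabaseSaver without the FastStream subscriber.
import logging

from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

from data_rentgen.consumer.extractors import BatchExtractionResult
from data_rentgen.consumer.saver import DatabaseSaver

engine = create_async_engine("postgresql+asyncpg://user:password@localhost/data_rentgen")  # assumed DSN
session_factory = async_sessionmaker(engine)


async def save_batch(extracted: BatchExtractionResult) -> None:
    # DatabaseSaver manages its own commits through UnitOfWork,
    # so the caller only provides a session and a logger.
    async with session_factory() as session:
        saver = DatabaseSaver(session, logging.getLogger(__name__))
        await saver.save(extracted)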

data_rentgen/db/repositories/column_lineage.py

Lines changed: 13 additions & 2 deletions
@@ -6,7 +6,7 @@
 from typing import NamedTuple
 from uuid import UUID

-from sqlalchemy import ColumnElement, any_, func, select, tuple_
+from sqlalchemy import ARRAY, ColumnElement, Integer, any_, cast, func, select, tuple_
 from sqlalchemy.dialects.postgresql import insert

 from data_rentgen.db.models import ColumnLineage, DatasetColumnRelation
@@ -123,9 +123,20 @@ async def list_by_dataset_pairs(
         if not dataset_ids_pairs:
             return []

+        source_dataset_ids = [pair[0] for pair in dataset_ids_pairs]
+        target_dataset_ids = [pair[1] for pair in dataset_ids_pairs]
+        pairs = (
+            func.unnest(
+                cast(source_dataset_ids, ARRAY(Integer())),
+                cast(target_dataset_ids, ARRAY(Integer())),
+            )
+            .table_valued("source_dataset_id", "target_dataset_id")
+            .render_derived()
+        )
+
         where = [
             ColumnLineage.created_at >= since,
-            tuple_(ColumnLineage.source_dataset_id, ColumnLineage.target_dataset_id).in_(dataset_ids_pairs),
+            tuple_(ColumnLineage.source_dataset_id, ColumnLineage.target_dataset_id).in_(select(pairs)),
         ]
         if until:
             where.append(
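
For context on the list_by_dataset_pairs change above: with a plain Python list, tuple_(...).in_(...) renders an expanding IN that binds two values per pair, so large pair lists mean long SQL and many bound parameters, while the unnest() table-valued form always sends exactly two array parameters and lets PostgreSQL expand them into rows server-side. Below is a standalone sketch that compiles both variants against the PostgreSQL dialect purely to inspect the generated SQL; the Table definition is a stand-in, not the project's model.

# Stand-in sketch for comparing the two filter styles; compiled only, never executed.
from sqlalchemy import ARRAY, Column, Integer, MetaData, Table, cast, func, select, tuple_
from sqlalchemy.dialects import postgresql

metadata = MetaData()
column_lineage = Table(
    "column_lineage",
    metadata,
    Column("source_dataset_id", Integer),
    Column("target_dataset_id", Integer),
)

dataset_ids_pairs = [(1, 10), (2, 20), (3, 30)]

# Old filter: an expanding IN over literal pairs, i.e. 2 * N bound values at execution time.
old_filter = tuple_(
    column_lineage.c.source_dataset_id,
    column_lineage.c.target_dataset_id,
).in_(dataset_ids_pairs)

# New filter: two array binds, zipped into (source_dataset_id, target_dataset_id) rows
# by PostgreSQL's multi-argument unnest() and exposed as a derived table
# via table_valued()/render_derived().
pairs = (
    func.unnest(
        cast([pair[0] for pair in dataset_ids_pairs], ARRAY(Integer())),
        cast([pair[1] for pair in dataset_ids_pairs], ARRAY(Integer())),
    )
    .table_valued("source_dataset_id", "target_dataset_id")
    .render_derived()
)
new_filter = tuple_(
    column_lineage.c.source_dataset_id,
    column_lineage.c.target_dataset_id,
).in_(select(pairs))

for label, clause in (("old", old_filter), ("new", new_filter)):
    print(label, select(column_lineage).where(clause).compile(dialect=postgresql.dialect()))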
