Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions data_rentgen/db/repositories/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
ScalarResult,
Select,
SQLColumnExpression,
bindparam,
func,
select,
)
Expand All @@ -24,6 +25,8 @@

Model = TypeVar("Model", bound=Base)

# Pre-built, module-level statement acquiring a transaction-scoped PostgreSQL
# advisory lock. The 64-bit lock key is not baked in; it is supplied at
# execution time through the "key" bind parameter (see Repository._lock below),
# so the same compiled statement is reused for every lock acquisition.
advisory_lock_statement = select(func.pg_advisory_xact_lock(bindparam("key")))


class Repository(ABC, Generic[Model]):
def __init__(
Expand Down Expand Up @@ -81,5 +84,4 @@ async def _lock(
digest = sha1(data.encode("utf-8"), usedforsecurity=False).digest()
# sha1 returns 160bit hash, we need only first 64 bits
lock_key = int.from_bytes(digest[:8], byteorder="big", signed=True)
statement = select(func.pg_advisory_xact_lock(lock_key))
await self._session.execute(statement)
await self._session.execute(advisory_lock_statement, {"key": lock_key})
86 changes: 51 additions & 35 deletions data_rentgen/db/repositories/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
String,
any_,
asc,
bindparam,
cast,
desc,
distinct,
Expand All @@ -28,25 +29,55 @@
from data_rentgen.db.utils.search import make_tsquery, ts_match, ts_rank
from data_rentgen.dto import DatasetDTO, PaginationDTO

# Module-level, pre-built queries with late-bound parameters (bindparam),
# built once at import time and reused across calls; values are passed at
# execution time as a parameter dict.

# Fetch all Dataset rows matching any (location_id, lower(name)) pair.
# The two bound arrays are zipped row-wise by unnest() into a derived table,
# which the (location_id, lower(name)) tuple is matched against via IN.
fetch_bulk_query = select(Dataset).where(
    tuple_(Dataset.location_id, func.lower(Dataset.name)).in_(
        select(
            func.unnest(
                cast(bindparam("location_ids"), ARRAY(Integer())),
                cast(bindparam("names_lower"), ARRAY(String())),
            )
            .table_valued("location_id", "name_lower")
            .render_derived(),
        ),
    ),
)

# Fetch Datasets by primary key (any of the bound "dataset_ids" array),
# eagerly loading location addresses and tag values to avoid N+1 queries.
get_list_query = (
    select(Dataset)
    .where(Dataset.id == any_(bindparam("dataset_ids")))
    .options(selectinload(Dataset.location).selectinload(Location.addresses))
    .options(selectinload(Dataset.tag_values).selectinload(TagValue.tag))
)

# Fetch a single Dataset by (location_id, case-insensitive name).
get_one_query = select(Dataset).where(
    Dataset.location_id == bindparam("location_id"),
    func.lower(Dataset.name) == bindparam("name_lower"),
)

# Per-location dataset counts for the bound "location_ids" array;
# yields rows of (location_id, total_datasets).
get_stats_query = (
    select(
        Dataset.location_id.label("location_id"),
        func.count(Dataset.id.distinct()).label("total_datasets"),
    )
    .where(
        Dataset.location_id == any_(bindparam("location_ids")),
    )
    .group_by(Dataset.location_id)
)


class DatasetRepository(Repository[Dataset]):
async def fetch_bulk(self, datasets_dto: list[DatasetDTO]) -> list[tuple[DatasetDTO, Dataset | None]]:
if not datasets_dto:
return []

location_ids = [dataset_dto.location.id for dataset_dto in datasets_dto]
names = [dataset_dto.name.lower() for dataset_dto in datasets_dto]
pairs = (
func.unnest(
cast(location_ids, ARRAY(Integer())),
cast(names, ARRAY(String())),
)
.table_valued("location_id", "name")
.render_derived()
scalars = await self._session.scalars(
fetch_bulk_query,
{
"location_ids": [item.location.id for item in datasets_dto],
"names_lower": [item.name.lower() for item in datasets_dto],
},
)

statement = select(Dataset).where(tuple_(Dataset.location_id, func.lower(Dataset.name)).in_(select(pairs)))
scalars = await self._session.scalars(statement)
existing = {(dataset.location_id, dataset.name.lower()): dataset for dataset in scalars.all()}
return [
(
Expand Down Expand Up @@ -139,39 +170,24 @@ async def paginate(
async def list_by_ids(self, dataset_ids: Collection[int]) -> list[Dataset]:
if not dataset_ids:
return []
query = (
select(Dataset)
.where(Dataset.id == any_(list(dataset_ids))) # type: ignore[arg-type]
.options(selectinload(Dataset.location).selectinload(Location.addresses))
.options(selectinload(Dataset.tag_values).selectinload(TagValue.tag))
)
result = await self._session.scalars(query)
result = await self._session.scalars(get_list_query, {"dataset_ids": list(dataset_ids)})
return list(result.all())

async def get_stats_by_location_ids(self, location_ids: Collection[int]) -> dict[int, Row]:
if not location_ids:
return {}

query = (
select(
Dataset.location_id.label("location_id"),
func.count(Dataset.id.distinct()).label("total_datasets"),
)
.where(
Dataset.location_id == any_(list(location_ids)), # type: ignore[arg-type]
)
.group_by(Dataset.location_id)
)

query_result = await self._session.execute(query)
query_result = await self._session.execute(get_stats_query, {"location_ids": list(location_ids)})
return {row.location_id: row for row in query_result.all()}

async def _get(self, dataset: DatasetDTO) -> Dataset | None:
statement = select(Dataset).where(
Dataset.location_id == dataset.location.id,
func.lower(Dataset.name) == dataset.name.lower(),
return await self._session.scalar(
get_one_query,
{
"location_id": dataset.location.id,
"name_lower": dataset.name.lower(),
},
)
return await self._session.scalar(statement)

async def _create(self, dataset: DatasetDTO) -> Dataset:
result = Dataset(location_id=dataset.location.id, name=dataset.name)
Expand Down
67 changes: 39 additions & 28 deletions data_rentgen/db/repositories/dataset_symlink.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,37 @@

from collections.abc import Collection

from sqlalchemy import ARRAY, BindParameter, Integer, any_, bindparam, cast, func, or_, select, tuple_
from sqlalchemy import ARRAY, Integer, any_, bindparam, cast, func, or_, select, tuple_

from data_rentgen.db.models.dataset_symlink import DatasetSymlink, DatasetSymlinkType
from data_rentgen.db.repositories.base import Repository
from data_rentgen.dto import DatasetSymlinkDTO

# Module-level, pre-built queries with late-bound parameters (bindparam),
# constructed once at import time; parameter values are supplied per call.

# Fetch all DatasetSymlink rows matching any (from_dataset_id, to_dataset_id)
# pair. The two bound arrays are zipped row-wise by unnest() into a derived
# table that the column tuple is matched against via IN.
# NOTE(review): the derived-table column names are plural here
# ("from_dataset_ids"/"to_dataset_ids") while dataset.py uses singular names
# for the analogous query — cosmetic only (they are just labels), but
# inconsistent; consider aligning.
fetch_bulk_query = select(DatasetSymlink).where(
    tuple_(DatasetSymlink.from_dataset_id, DatasetSymlink.to_dataset_id).in_(
        select(
            func.unnest(
                cast(bindparam("from_dataset_ids"), ARRAY(Integer())),
                cast(bindparam("to_dataset_ids"), ARRAY(Integer())),
            )
            .table_valued("from_dataset_ids", "to_dataset_ids")
            .render_derived(),
        ),
    ),
)

# Fetch symlinks touching any of the bound "dataset_ids" on either side;
# both comparisons share the single "dataset_ids" bind parameter.
get_list_query = select(DatasetSymlink).where(
    or_(
        DatasetSymlink.from_dataset_id == any_(bindparam("dataset_ids")),
        DatasetSymlink.to_dataset_id == any_(bindparam("dataset_ids")),
    ),
)

# Fetch a single symlink by its (from_dataset_id, to_dataset_id) pair.
get_one_query = select(DatasetSymlink).where(
    DatasetSymlink.from_dataset_id == bindparam("from_dataset_id"),
    DatasetSymlink.to_dataset_id == bindparam("to_dataset_id"),
)


class DatasetSymlinkRepository(Repository[DatasetSymlink]):
async def fetch_bulk(
Expand All @@ -18,22 +43,13 @@ async def fetch_bulk(
if not dataset_symlinks_dto:
return []

from_dataset_ids = [dataset_symlink_dto.from_dataset.id for dataset_symlink_dto in dataset_symlinks_dto]
to_dataset_ids = [dataset_symlink_dto.to_dataset.id for dataset_symlink_dto in dataset_symlinks_dto]

pairs = (
func.unnest(
cast(from_dataset_ids, ARRAY(Integer())),
cast(to_dataset_ids, ARRAY(Integer())),
)
.table_valued("from_dataset_ids", "to_dataset_ids")
.render_derived()
scalars = await self._session.scalars(
fetch_bulk_query,
{
"from_dataset_ids": [item.from_dataset.id for item in dataset_symlinks_dto],
"to_dataset_ids": [item.to_dataset.id for item in dataset_symlinks_dto],
},
)

statement = select(DatasetSymlink).where(
tuple_(DatasetSymlink.from_dataset_id, DatasetSymlink.to_dataset_id).in_(select(pairs)),
)
scalars = await self._session.scalars(statement)
existing = {(item.from_dataset_id, item.to_dataset_id): item for item in scalars.all()}
return [
(
Expand All @@ -52,22 +68,17 @@ async def list_by_dataset_ids(self, dataset_ids: Collection[int]) -> list[Datase
if not dataset_ids:
return []

param: BindParameter[list[int]] = bindparam("dataset_ids")
query = select(DatasetSymlink).where(
or_(
DatasetSymlink.from_dataset_id == any_(param),
DatasetSymlink.to_dataset_id == any_(param),
),
)
scalars = await self._session.scalars(query, {"dataset_ids": list(dataset_ids)})
scalars = await self._session.scalars(get_list_query, {"dataset_ids": list(dataset_ids)})
return list(scalars.all())

async def _get(self, dataset_symlink: DatasetSymlinkDTO) -> DatasetSymlink | None:
query = select(DatasetSymlink).where(
DatasetSymlink.from_dataset_id == dataset_symlink.from_dataset.id,
DatasetSymlink.to_dataset_id == dataset_symlink.to_dataset.id,
return await self._session.scalar(
get_one_query,
{
"from_dataset_id": dataset_symlink.from_dataset.id,
"to_dataset_id": dataset_symlink.to_dataset.id,
},
)
return await self._session.scalar(query)

async def _create(self, dataset_symlink: DatasetSymlinkDTO) -> DatasetSymlink:
result = DatasetSymlink(
Expand Down
24 changes: 12 additions & 12 deletions data_rentgen/db/repositories/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,17 @@
extract_timestamp_from_uuid,
)

# Module-level PostgreSQL upsert for Input rows, built once at import time.
# `excluded` refers to the row proposed for insertion; on a (created_at, id)
# conflict the existing row's counters are raised to the maximum of the old
# and new values via GREATEST, so repeated events never shrink the metrics.
insert_statement = insert(Input)
inserted_row = insert_statement.excluded
insert_statement = insert_statement.on_conflict_do_update(
    index_elements=[Input.created_at, Input.id],
    set_={
        "num_bytes": func.greatest(inserted_row.num_bytes, Input.num_bytes),
        "num_rows": func.greatest(inserted_row.num_rows, Input.num_rows),
        "num_files": func.greatest(inserted_row.num_files, Input.num_files),
    },
)


@dataclass
class InputRow:
Expand All @@ -37,19 +48,8 @@ async def create_or_update_bulk(self, inputs: list[InputDTO]) -> None:
if not inputs:
return

insert_statement = insert(Input)
new_row = insert_statement.excluded
statement = insert_statement.on_conflict_do_update(
index_elements=[Input.created_at, Input.id],
set_={
"num_bytes": func.greatest(new_row.num_bytes, Input.num_bytes),
"num_rows": func.greatest(new_row.num_rows, Input.num_rows),
"num_files": func.greatest(new_row.num_files, Input.num_files),
},
)

await self._session.execute(
statement,
insert_statement,
[
{
"id": item.generate_id(),
Expand Down
Loading
Loading