diff --git a/data_rentgen/db/repositories/base.py b/data_rentgen/db/repositories/base.py index bddc8c4c..e1b6f97b 100644 --- a/data_rentgen/db/repositories/base.py +++ b/data_rentgen/db/repositories/base.py @@ -13,6 +13,7 @@ ScalarResult, Select, SQLColumnExpression, + bindparam, func, select, ) @@ -24,6 +25,8 @@ Model = TypeVar("Model", bound=Base) +advisory_lock_statement = select(func.pg_advisory_xact_lock(bindparam("key"))) + class Repository(ABC, Generic[Model]): def __init__( @@ -81,5 +84,4 @@ async def _lock( digest = sha1(data.encode("utf-8"), usedforsecurity=False).digest() # sha1 returns 160bit hash, we need only first 64 bits lock_key = int.from_bytes(digest[:8], byteorder="big", signed=True) - statement = select(func.pg_advisory_xact_lock(lock_key)) - await self._session.execute(statement) + await self._session.execute(advisory_lock_statement, {"key": lock_key}) diff --git a/data_rentgen/db/repositories/dataset.py b/data_rentgen/db/repositories/dataset.py index 8f38e81c..128afe8f 100644 --- a/data_rentgen/db/repositories/dataset.py +++ b/data_rentgen/db/repositories/dataset.py @@ -13,6 +13,7 @@ String, any_, asc, + bindparam, cast, desc, distinct, @@ -28,25 +29,55 @@ from data_rentgen.db.utils.search import make_tsquery, ts_match, ts_rank from data_rentgen.dto import DatasetDTO, PaginationDTO +fetch_bulk_query = select(Dataset).where( + tuple_(Dataset.location_id, func.lower(Dataset.name)).in_( + select( + func.unnest( + cast(bindparam("location_ids"), ARRAY(Integer())), + cast(bindparam("names_lower"), ARRAY(String())), + ) + .table_valued("location_id", "name_lower") + .render_derived(), + ), + ), +) + +get_list_query = ( + select(Dataset) + .where(Dataset.id == any_(bindparam("dataset_ids"))) + .options(selectinload(Dataset.location).selectinload(Location.addresses)) + .options(selectinload(Dataset.tag_values).selectinload(TagValue.tag)) +) + +get_one_query = select(Dataset).where( + Dataset.location_id == bindparam("location_id"), + func.lower(Dataset.name) == bindparam("name_lower"), +) + +get_stats_query = ( + select( + Dataset.location_id.label("location_id"), + func.count(Dataset.id.distinct()).label("total_datasets"), + ) + .where( + Dataset.location_id == any_(bindparam("location_ids")), + ) + .group_by(Dataset.location_id) +) + class DatasetRepository(Repository[Dataset]): async def fetch_bulk(self, datasets_dto: list[DatasetDTO]) -> list[tuple[DatasetDTO, Dataset | None]]: if not datasets_dto: return [] - location_ids = [dataset_dto.location.id for dataset_dto in datasets_dto] - names = [dataset_dto.name.lower() for dataset_dto in datasets_dto] - pairs = ( - func.unnest( - cast(location_ids, ARRAY(Integer())), - cast(names, ARRAY(String())), - ) - .table_valued("location_id", "name") - .render_derived() + scalars = await self._session.scalars( + fetch_bulk_query, + { + "location_ids": [item.location.id for item in datasets_dto], + "names_lower": [item.name.lower() for item in datasets_dto], + }, ) - - statement = select(Dataset).where(tuple_(Dataset.location_id, func.lower(Dataset.name)).in_(select(pairs))) - scalars = await self._session.scalars(statement) existing = {(dataset.location_id, dataset.name.lower()): dataset for dataset in scalars.all()} return [ ( @@ -139,39 +170,24 @@ async def paginate( async def list_by_ids(self, dataset_ids: Collection[int]) -> list[Dataset]: if not dataset_ids: return [] - query = ( - select(Dataset) - .where(Dataset.id == any_(list(dataset_ids))) # type: ignore[arg-type] - .options(selectinload(Dataset.location).selectinload(Location.addresses)) - .options(selectinload(Dataset.tag_values).selectinload(TagValue.tag)) - ) - result = await self._session.scalars(query) + result = await self._session.scalars(get_list_query, {"dataset_ids": list(dataset_ids)}) return list(result.all()) async def get_stats_by_location_ids(self, location_ids: Collection[int]) -> dict[int, Row]: if not location_ids: return {} - query = ( - select( - Dataset.location_id.label("location_id"), - func.count(Dataset.id.distinct()).label("total_datasets"), - ) - .where( - Dataset.location_id == any_(list(location_ids)), # type: ignore[arg-type] - ) - .group_by(Dataset.location_id) - ) - - query_result = await self._session.execute(query) + query_result = await self._session.execute(get_stats_query, {"location_ids": list(location_ids)}) return {row.location_id: row for row in query_result.all()} async def _get(self, dataset: DatasetDTO) -> Dataset | None: - statement = select(Dataset).where( - Dataset.location_id == dataset.location.id, - func.lower(Dataset.name) == dataset.name.lower(), + return await self._session.scalar( + get_one_query, + { + "location_id": dataset.location.id, + "name_lower": dataset.name.lower(), + }, ) - return await self._session.scalar(statement) async def _create(self, dataset: DatasetDTO) -> Dataset: result = Dataset(location_id=dataset.location.id, name=dataset.name) diff --git a/data_rentgen/db/repositories/dataset_symlink.py b/data_rentgen/db/repositories/dataset_symlink.py index 4a8a6f45..714bad32 100644 --- a/data_rentgen/db/repositories/dataset_symlink.py +++ b/data_rentgen/db/repositories/dataset_symlink.py @@ -3,12 +3,37 @@ from collections.abc import Collection -from sqlalchemy import ARRAY, BindParameter, Integer, any_, bindparam, cast, func, or_, select, tuple_ +from sqlalchemy import ARRAY, Integer, any_, bindparam, cast, func, or_, select, tuple_ from data_rentgen.db.models.dataset_symlink import DatasetSymlink, DatasetSymlinkType from data_rentgen.db.repositories.base import Repository from data_rentgen.dto import DatasetSymlinkDTO +fetch_bulk_query = select(DatasetSymlink).where( + tuple_(DatasetSymlink.from_dataset_id, DatasetSymlink.to_dataset_id).in_( + select( + func.unnest( + cast(bindparam("from_dataset_ids"), ARRAY(Integer())), + cast(bindparam("to_dataset_ids"), ARRAY(Integer())), + ) + .table_valued("from_dataset_ids", "to_dataset_ids") + .render_derived(), + ), + ), +) + +get_list_query = select(DatasetSymlink).where( + or_( + DatasetSymlink.from_dataset_id == any_(bindparam("dataset_ids")), + DatasetSymlink.to_dataset_id == any_(bindparam("dataset_ids")), + ), +) + +get_one_query = select(DatasetSymlink).where( + DatasetSymlink.from_dataset_id == bindparam("from_dataset_id"), + DatasetSymlink.to_dataset_id == bindparam("to_dataset_id"), +) + class DatasetSymlinkRepository(Repository[DatasetSymlink]): async def fetch_bulk( @@ -18,22 +43,13 @@ async def fetch_bulk( if not dataset_symlinks_dto: return [] - from_dataset_ids = [dataset_symlink_dto.from_dataset.id for dataset_symlink_dto in dataset_symlinks_dto] - to_dataset_ids = [dataset_symlink_dto.to_dataset.id for dataset_symlink_dto in dataset_symlinks_dto] - - pairs = ( - func.unnest( - cast(from_dataset_ids, ARRAY(Integer())), - cast(to_dataset_ids, ARRAY(Integer())), - ) - .table_valued("from_dataset_ids", "to_dataset_ids") - .render_derived() + scalars = await self._session.scalars( + fetch_bulk_query, + { + "from_dataset_ids": [item.from_dataset.id for item in dataset_symlinks_dto], + "to_dataset_ids": [item.to_dataset.id for item in dataset_symlinks_dto], + }, ) - - statement = select(DatasetSymlink).where( - tuple_(DatasetSymlink.from_dataset_id, DatasetSymlink.to_dataset_id).in_(select(pairs)), - ) - scalars = await self._session.scalars(statement) existing = {(item.from_dataset_id, item.to_dataset_id): item for item in scalars.all()} return [ ( @@ -52,22 +68,17 @@ async def list_by_dataset_ids(self, dataset_ids: Collection[int]) -> list[Datase if not dataset_ids: return [] - param: BindParameter[list[int]] = bindparam("dataset_ids") - query = select(DatasetSymlink).where( - or_( - DatasetSymlink.from_dataset_id == any_(param), - DatasetSymlink.to_dataset_id == any_(param), - ), - ) - scalars = await self._session.scalars(query, {"dataset_ids": list(dataset_ids)}) + scalars = await self._session.scalars(get_list_query, {"dataset_ids": list(dataset_ids)}) return list(scalars.all()) async def _get(self, dataset_symlink: DatasetSymlinkDTO) -> DatasetSymlink | None: - query = select(DatasetSymlink).where( - DatasetSymlink.from_dataset_id == dataset_symlink.from_dataset.id, - DatasetSymlink.to_dataset_id == dataset_symlink.to_dataset.id, + return await self._session.scalar( + get_one_query, + { + "from_dataset_id": dataset_symlink.from_dataset.id, + "to_dataset_id": dataset_symlink.to_dataset.id, + }, ) - return await self._session.scalar(query) async def _create(self, dataset_symlink: DatasetSymlinkDTO) -> DatasetSymlink: result = DatasetSymlink( diff --git a/data_rentgen/db/repositories/input.py b/data_rentgen/db/repositories/input.py index 98cb36f5..2fc4b71d 100644 --- a/data_rentgen/db/repositories/input.py +++ b/data_rentgen/db/repositories/input.py @@ -16,6 +16,17 @@ extract_timestamp_from_uuid, ) +insert_statement = insert(Input) +inserted_row = insert_statement.excluded +insert_statement = insert_statement.on_conflict_do_update( + index_elements=[Input.created_at, Input.id], + set_={ + "num_bytes": func.greatest(inserted_row.num_bytes, Input.num_bytes), + "num_rows": func.greatest(inserted_row.num_rows, Input.num_rows), + "num_files": func.greatest(inserted_row.num_files, Input.num_files), + }, +) + @dataclass class InputRow: @@ -37,19 +48,8 @@ async def create_or_update_bulk(self, inputs: list[InputDTO]) -> None: if not inputs: return - insert_statement = insert(Input) - new_row = insert_statement.excluded - statement = insert_statement.on_conflict_do_update( - index_elements=[Input.created_at, Input.id], - set_={ - "num_bytes": func.greatest(new_row.num_bytes, Input.num_bytes), - "num_rows": func.greatest(new_row.num_rows, Input.num_rows), - "num_files": func.greatest(new_row.num_files, Input.num_files), - }, - ) - await self._session.execute( - statement, + insert_statement, [ { "id": item.generate_id(), diff --git a/data_rentgen/db/repositories/job.py b/data_rentgen/db/repositories/job.py index 66163080..75a475f9 100644 --- a/data_rentgen/db/repositories/job.py +++ b/data_rentgen/db/repositories/job.py @@ -13,6 +13,7 @@ String, any_, asc, + bindparam, cast, desc, func, @@ -30,6 +31,44 @@ UNKNOWN_JOB_TYPE = 0 +fetch_bulk_query = select(Job).where( + tuple_(Job.location_id, func.lower(Job.name)).in_( + select( + func.unnest( + cast(bindparam("location_ids"), ARRAY(Integer())), + cast(bindparam("names_lower"), ARRAY(String())), + ) + .table_valued("location_id", "name_lower") + .render_derived(), + ), + ), +) + +get_one_query = select(Job).where( + Job.location_id == bindparam("location_id"), + func.lower(Job.name) == bindparam("name_lower"), +) + +get_list_query = ( + select(Job) + .where( + Job.id == any_(bindparam("job_ids")), + ) + .options(selectinload(Job.location).selectinload(Location.addresses)) +) + +get_stats_query = ( + select( + Job.location_id.label("location_id"), + func.count(Job.id.distinct()).label("total_jobs"), + ) + .where( + Job.location_id == any_(bindparam("location_ids")), + ) + .group_by(Job.location_id) +) + + class JobRepository(Repository[Job]): async def paginate( self, @@ -90,19 +129,13 @@ async def fetch_bulk(self, jobs_dto: list[JobDTO]) -> list[tuple[JobDTO, Job | N if not jobs_dto: return [] - location_ids = [job_dto.location.id for job_dto in jobs_dto] - names = [job_dto.name.lower() for job_dto in jobs_dto] - pairs = ( - func.unnest( - cast(location_ids, ARRAY(Integer())), - cast(names, ARRAY(String())), - ) - .table_valued("location_id", "name") - .render_derived() + scalars = await self._session.scalars( + fetch_bulk_query, + { + "location_ids": [item.location.id for item in jobs_dto], + "names_lower": [item.name.lower() for item in jobs_dto], + }, ) - - statement = select(Job).where(tuple_(Job.location_id, func.lower(Job.name)).in_(select(pairs))) - scalars = await self._session.scalars(statement) existing = {(job.location_id, job.name.lower()): job for job in scalars.all()} return [ ( @@ -121,11 +154,13 @@ async def create_or_update(self, job: JobDTO) -> Job: return await self.update(result, job) async def _get(self, job: JobDTO) -> Job | None: - statement = select(Job).where( - Job.location_id == job.location.id, - func.lower(Job.name) == job.name.lower(), + return await self._session.scalar( + get_one_query, + { + "location_id": job.location.id, + "name_lower": job.name.lower(), + }, ) - return await self._session.scalar(statement) async def _create(self, job: JobDTO) -> Job: result = Job( @@ -147,28 +182,13 @@ async def update(self, existing: Job, new: JobDTO) -> Job: async def list_by_ids(self, job_ids: Collection[int]) -> list[Job]: if not job_ids: return [] - query = ( - select(Job) - .where(Job.id == any_(list(job_ids))) # type: ignore[arg-type] - .options(selectinload(Job.location).selectinload(Location.addresses)) - ) - result = await self._session.scalars(query) + + result = await self._session.scalars(get_list_query, {"job_ids": list(job_ids)}) return list(result.all()) async def get_stats_by_location_ids(self, location_ids: Collection[int]) -> dict[int, Row]: if not location_ids: return {} - query = ( - select( - Job.location_id.label("location_id"), - func.count(Job.id.distinct()).label("total_jobs"), - ) - .where( - Job.location_id == any_(list(location_ids)), # type: ignore[arg-type] - ) - .group_by(Job.location_id) - ) - - query_result = await self._session.execute(query) + query_result = await self._session.execute(get_stats_query, {"location_ids": list(location_ids)}) return {row.location_id: row for row in query_result.all()} diff --git a/data_rentgen/db/repositories/job_type.py b/data_rentgen/db/repositories/job_type.py index c74c4dcd..b66a7ae5 100644 --- a/data_rentgen/db/repositories/job_type.py +++ b/data_rentgen/db/repositories/job_type.py @@ -3,6 +3,7 @@ from sqlalchemy import ( any_, + bindparam, select, ) @@ -10,14 +11,26 @@ from data_rentgen.db.repositories.base import Repository from data_rentgen.dto import JobTypeDTO +fetch_bulk_query = select(JobType).where( + JobType.type == any_(bindparam("types")), +) + +get_one_query = select(JobType).where( + JobType.type == bindparam("type"), +) + class JobTypeRepository(Repository[JobType]): async def fetch_bulk(self, job_types_dto: list[JobTypeDTO]) -> list[tuple[JobTypeDTO, JobType | None]]: - unique_keys = [job_type_dto.type for job_type_dto in job_types_dto] - statement = select(JobType).where( - JobType.type == any_(unique_keys), # type: ignore[arg-type] + if not job_types_dto: + return [] + + scalars = await self._session.scalars( + fetch_bulk_query, + { + "types": [job_type_dto.type for job_type_dto in job_types_dto], + }, ) - scalars = await self._session.scalars(statement) existing = {job.type: job for job in scalars.all()} return [(job_type_dto, existing.get(job_type_dto.type)) for job_type_dto in job_types_dto] @@ -27,8 +40,7 @@ async def create(self, job_type_dto: JobTypeDTO) -> JobType: return await self._get(job_type_dto) or await self._create(job_type_dto) async def _get(self, job_type_dto: JobTypeDTO) -> JobType | None: - query = select(JobType).where(JobType.type == job_type_dto.type) - return await self._session.scalar(query) + return await self._session.scalar(get_one_query, {"type": job_type_dto.type}) async def _create(self, job_type_dto: JobTypeDTO) -> JobType: result = JobType(type=job_type_dto.type) diff --git a/data_rentgen/db/repositories/location.py b/data_rentgen/db/repositories/location.py index 4cd907c5..deb79e3d 100644 --- a/data_rentgen/db/repositories/location.py +++ b/data_rentgen/db/repositories/location.py @@ -9,6 +9,7 @@ SQLColumnExpression, any_, asc, + bindparam, desc, func, select, @@ -22,6 +23,26 @@ from data_rentgen.dto import LocationDTO, PaginationDTO from data_rentgen.exceptions.entity import EntityNotFoundError +get_one_by_name_query = select(Location).where( + Location.type == bindparam("type"), + Location.name == bindparam("name"), +) +get_one_by_addresses_query = ( + select(Location) + .join(Location.addresses) + .where( + Location.type == bindparam("type"), + Address.url == any_(bindparam("addresses")), + ) +) +get_one_query = ( + select(Location) + .from_statement( + get_one_by_name_query.union(get_one_by_addresses_query), + ) + .options(selectinload(Location.addresses)) +) + class LocationRepository(Repository[Location]): async def paginate( @@ -102,19 +123,14 @@ async def create_or_update(self, location: LocationDTO) -> Location: return result async def _get(self, location: LocationDTO) -> Location | None: - by_name = select(Location).where(Location.type == location.type, Location.name == location.name) - by_addresses = ( - select(Location) - .join(Location.addresses) - .where( - Location.type == location.type, - Address.url == any_(list(location.addresses)), # type: ignore[arg-type] - ) - ) - statement = ( - select(Location).from_statement(by_name.union(by_addresses)).options(selectinload(Location.addresses)) + return await self._session.scalar( + get_one_query, + { + "type": location.type, + "name": location.name, + "addresses": list(location.addresses), + }, ) - return await self._session.scalar(statement) async def _create(self, location: LocationDTO) -> Location: result = Location(type=location.type, name=location.name) diff --git a/data_rentgen/db/repositories/operation.py b/data_rentgen/db/repositories/operation.py index 3ecc59ce..ec88a67c 100644 --- a/data_rentgen/db/repositories/operation.py +++ b/data_rentgen/db/repositories/operation.py @@ -13,6 +13,35 @@ from data_rentgen.dto import OperationDTO, PaginationDTO from data_rentgen.utils.uuid import extract_timestamp_from_uuid +insert_statement = insert(Operation).on_conflict_do_nothing() +update_statement = update(Operation) + +get_list_by_run_ids_query = select(Operation).where( + Operation.created_at >= bindparam("since"), + Operation.run_id == any_(bindparam("run_ids")), +) + +# Do not use `tuple_(Operation.created_at, Operation.id).in_(...), +# as this is too complex filter for Postgres to make an optimal query plan. +# Primary key starts with id already, and created_at filter is used to select specific partitions +get_list_by_ids = select(Operation).where( + Operation.created_at >= bindparam("since"), + Operation.created_at <= bindparam("until"), + Operation.id == any_(bindparam("operation_ids")), +) + +get_stats_by_run_ids = ( + select( + Operation.run_id.label("run_id"), + func.count(Operation.id.distinct()).label("total_operations"), + ) + .where( + Operation.created_at >= bindparam("since"), + Operation.run_id == any_(bindparam("run_ids")), + ) + .group_by(Operation.run_id) +) + class OperationRepository(Repository[Operation]): async def create_or_update_bulk(self, operations: list[OperationDTO]) -> None: @@ -38,13 +67,13 @@ async def create_or_update_bulk(self, operations: list[OperationDTO]) -> None: # this replaces all null values with defaults await self._session.execute( - insert(Operation).on_conflict_do_nothing(), + insert_statement, data, ) # if value is still none, keep existing one await self._session.execute( - update(Operation).values( + update_statement.values( { "name": func.coalesce(bindparam("name"), Operation.name), "type": func.coalesce(bindparam("type"), Operation.type), @@ -113,46 +142,45 @@ async def list_by_run_ids( # All operations are created after run min_run_created_at = extract_timestamp_from_uuid(min(run_ids)) min_operation_created_at = max(min_run_created_at, since.astimezone(timezone.utc)) - query = select(Operation).where( - Operation.created_at >= min_operation_created_at, - Operation.run_id == any_(list(run_ids)), # type: ignore[arg-type] - ) + + query = get_list_by_run_ids_query if until: + # until is rarely used, avoid making query too complicated query = query.where(Operation.created_at <= until) - result = await self._session.scalars(query) + + result = await self._session.scalars( + query, + { + "since": min_operation_created_at, + "run_ids": list(run_ids), + }, + ) return list(result.all()) async def list_by_ids(self, operation_ids: Collection[UUID]) -> list[Operation]: if not operation_ids: return [] - # Do not use `tuple_(Operation.created_at, Operation.id).in_(...), - # as this is too complex filter for Postgres to make an optimal query plan - query = select(Operation).where( - Operation.created_at >= extract_timestamp_from_uuid(min(operation_ids)), - Operation.created_at <= extract_timestamp_from_uuid(max(operation_ids)), - Operation.id == any_(list(operation_ids)), # type: ignore[arg-type] + result = await self._session.scalars( + get_list_by_ids, + { + "since": extract_timestamp_from_uuid(min(operation_ids)), + "until": extract_timestamp_from_uuid(max(operation_ids)), + "operation_ids": list(operation_ids), + }, ) - result = await self._session.scalars(query) return list(result.all()) async def get_stats_by_run_ids(self, run_ids: Collection[UUID]) -> dict[UUID, Row]: if not run_ids: return {} - # unlike list_by_run_ids, we need to get all statistics for specific runs, regardless of time range - min_created_at = extract_timestamp_from_uuid(min(run_ids)) - query = ( - select( - Operation.run_id.label("run_id"), - func.count(Operation.id.distinct()).label("total_operations"), - ) - .where( - Operation.created_at >= min_created_at, - Operation.run_id == any_(list(run_ids)), # type: ignore[arg-type] - ) - .group_by(Operation.run_id) + query_result = await self._session.execute( + get_stats_by_run_ids, + { + # All operations are created after run + "since": extract_timestamp_from_uuid(min(run_ids)), + "run_ids": list(run_ids), + }, ) - - query_result = await self._session.execute(query) return {row.run_id: row for row in query_result.all()} diff --git a/data_rentgen/db/repositories/output.py b/data_rentgen/db/repositories/output.py index e9abf968..62051175 100644 --- a/data_rentgen/db/repositories/output.py +++ b/data_rentgen/db/repositories/output.py @@ -16,6 +16,18 @@ extract_timestamp_from_uuid, ) +insert_statement = insert(Output) +inserted_row = insert_statement.excluded +insert_statement = insert_statement.on_conflict_do_update( + index_elements=[Output.created_at, Output.id], + set_={ + "type": inserted_row.type.op("|")(Output.type), + "num_bytes": func.greatest(inserted_row.num_bytes, Output.num_bytes), + "num_rows": func.greatest(inserted_row.num_rows, Output.num_rows), + "num_files": func.greatest(inserted_row.num_files, Output.num_files), + }, +) + @dataclass class OutputRow: @@ -38,20 +50,8 @@ async def create_or_update_bulk(self, outputs: list[OutputDTO]) -> None: if not outputs: return - insert_statement = insert(Output) - new_row = insert_statement.excluded - statement = insert_statement.on_conflict_do_update( - index_elements=[Output.created_at, Output.id], - set_={ - "type": new_row.type.op("|")(Output.type), - "num_bytes": func.greatest(new_row.num_bytes, Output.num_bytes), - "num_rows": func.greatest(new_row.num_rows, Output.num_rows), - "num_files": func.greatest(new_row.num_files, Output.num_files), - }, - ) - await self._session.execute( - statement, + insert_statement, [ { "id": item.generate_id(), diff --git a/data_rentgen/db/repositories/personal_token.py b/data_rentgen/db/repositories/personal_token.py index 7c40cc82..86d1f365 100644 --- a/data_rentgen/db/repositories/personal_token.py +++ b/data_rentgen/db/repositories/personal_token.py @@ -5,7 +5,7 @@ from datetime import UTC, date, datetime from uuid import UUID -from sqlalchemy import any_, select +from sqlalchemy import any_, bindparam, select from sqlalchemy.exc import IntegrityError from data_rentgen.db.models import PersonalToken @@ -13,6 +13,12 @@ from data_rentgen.dto.pagination import PaginationDTO from data_rentgen.exceptions.entity import EntityAlreadyExistsError, EntityNotFoundError +get_by_id_query = select(PersonalToken).where( + PersonalToken.user_id == bindparam("user_id"), + PersonalToken.id == bindparam("token_id"), + PersonalToken.revoked_at.is_(None), +) + class PersonalTokenRepository(Repository[PersonalToken]): async def paginate( @@ -44,12 +50,10 @@ async def get_by_id( user_id: int, token_id: UUID, ) -> PersonalToken | None: - query = select(PersonalToken).where( - PersonalToken.user_id == user_id, - PersonalToken.id == token_id, - PersonalToken.revoked_at.is_(None), + return await self._session.scalar( + get_by_id_query, + {"user_id": user_id, "token_id": token_id}, ) - return await self._session.scalar(query) async def create( self, diff --git a/data_rentgen/db/repositories/run.py b/data_rentgen/db/repositories/run.py index d2cffac0..60ec8014 100644 --- a/data_rentgen/db/repositories/run.py +++ b/data_rentgen/db/repositories/run.py @@ -25,6 +25,62 @@ from data_rentgen.dto import PaginationDTO, RunDTO from data_rentgen.utils.uuid import extract_timestamp_from_uuid +# Do not use `tuple_(Run.created_at, Run.id).in_(...), +# as this is too complex filter for Postgres to make an optimal query plan. +# Primary key starts with id already, and created_at filter is used to select specific partitions +get_list_by_id_query = ( + select(Run) + .where( + Run.created_at >= bindparam("since"), + Run.created_at <= bindparam("until"), + Run.id == any_(bindparam("run_ids")), + ) + .options(selectinload(Run.started_by_user)) +) + +get_list_by_job_ids_query = ( + select(Run) + .where( + Run.created_at >= bindparam("since"), + Run.job_id == any_(bindparam("job_ids")), + ) + .options(selectinload(Run.started_by_user)) +) + +fetch_bulk_query = select(Run).where( + Run.created_at >= bindparam("since"), + Run.id == any_(bindparam("run_ids")), +) + +insert_statement = insert(Run).values( + created_at=bindparam("created_at"), + id=bindparam("id"), + job_id=bindparam("job_id"), + status=bindparam("status"), + parent_run_id=bindparam("parent_run_id"), + started_at=bindparam("started_at"), + started_by_user_id=bindparam("started_by_user_id"), + start_reason=bindparam("start_reason"), + ended_at=bindparam("ended_at"), +) +inserted_row = insert_statement.excluded +insert_statement = insert_statement.on_conflict_do_update( + index_elements=[Run.created_at, Run.id], + set_={ + "job_id": inserted_row.job_id, + "status": func.greatest(inserted_row.status, Run.status), + "parent_run_id": func.coalesce(inserted_row.parent_run_id, Run.parent_run_id), + "started_at": func.coalesce(inserted_row.started_at, Run.started_at), + "started_by_user_id": func.coalesce(inserted_row.started_by_user_id, Run.started_by_user_id), + "start_reason": func.coalesce(inserted_row.start_reason, Run.start_reason), + "ended_at": func.coalesce(inserted_row.ended_at, Run.ended_at), + "external_id": func.coalesce(inserted_row.external_id, Run.external_id), + "attempt": func.coalesce(inserted_row.attempt, Run.attempt), + "persistent_log_url": func.coalesce(inserted_row.persistent_log_url, Run.persistent_log_url), + "running_log_url": func.coalesce(inserted_row.running_log_url, Run.running_log_url), + }, +) + class RunRepository(Repository[Run]): async def paginate( @@ -105,46 +161,41 @@ async def paginate( async def list_by_ids(self, run_ids: Collection[UUID]) -> list[Run]: if not run_ids: return [] - # do not use `tuple_(Run.created_at, Run.id).in_(...), - # as this is too complex filter for Postgres to make an optimal query plan - query = ( - select(Run) - .where( - Run.created_at >= extract_timestamp_from_uuid(min(run_ids)), - Run.created_at <= extract_timestamp_from_uuid(max(run_ids)), - Run.id == any_(list(run_ids)), # type: ignore[arg-type] - ) - .options(selectinload(Run.started_by_user)) + + result = await self._session.scalars( + get_list_by_id_query, + { + "since": extract_timestamp_from_uuid(min(run_ids)), + "until": extract_timestamp_from_uuid(max(run_ids)), + "run_ids": list(run_ids), + }, ) - result = await self._session.scalars(query) return list(result.all()) async def list_by_job_ids(self, job_ids: Collection[int], since: datetime, until: datetime | None) -> list[Run]: if not job_ids: return [] - query = ( - select(Run) - .where( - Run.created_at >= since, - Run.job_id == any_(list(job_ids)), # type: ignore[arg-type] - ) - .options(selectinload(Run.started_by_user)) - ) + + query = get_list_by_job_ids_query if until: + # until is rarely used, avoid making query too complicated query = query.where(Run.created_at <= until) - result = await self._session.scalars(query) + + result = await self._session.scalars(query, {"since": since, "job_ids": list(job_ids)}) return list(result.all()) async def fetch_bulk(self, runs_dto: list[RunDTO]) -> list[tuple[RunDTO, Run | None]]: if not runs_dto: return [] + ids = [run_dto.id for run_dto in runs_dto] - min_created_at = extract_timestamp_from_uuid(min(ids)) - statement = select(Run).where( - Run.created_at >= min_created_at, - Run.id == any_(ids), # type: ignore[arg-type] + scalars = await self._session.scalars( + fetch_bulk_query, + { + "since": extract_timestamp_from_uuid(min(ids)), + "run_ids": ids, + }, ) - scalars = await self._session.scalars(statement) existing = {run.id: run for run in scalars.all()} return [(run_dto, existing.get(run_dto.id)) for run_dto in runs_dto] @@ -209,54 +260,27 @@ async def update( return existing async def create_or_update_bulk(self, runs: list[RunDTO]) -> None: - # used only by db seed script if not runs: return - data = [ - { - "created_at": run.created_at, - "id": run.id, - "job_id": run.job.id, - "status": RunStatus(run.status), - "parent_run_id": run.parent_run.id if run.parent_run else None, - "started_at": run.started_at, - "started_by_user_id": run.user.id if run.user else None, - "start_reason": RunStartReason(run.start_reason) if run.start_reason else None, - "ended_at": run.ended_at, - "external_id": run.external_id, - "attempt": run.attempt, - "persistent_log_url": run.persistent_log_url, - "running_log_url": run.running_log_url, - } - for run in runs - ] - - statement = insert(Run).values( - created_at=bindparam("created_at"), - id=bindparam("id"), - job_id=bindparam("job_id"), - status=bindparam("status"), - parent_run_id=bindparam("parent_run_id"), - started_at=bindparam("started_at"), - started_by_user_id=bindparam("started_by_user_id"), - start_reason=bindparam("start_reason"), - ended_at=bindparam("ended_at"), - ) - statement = statement.on_conflict_do_update( - index_elements=[Run.created_at, Run.id], - set_={ - "job_id": statement.excluded.job_id, - "status": func.greatest(statement.excluded.status, Run.status), - "parent_run_id": func.coalesce(statement.excluded.parent_run_id, Run.parent_run_id), - "started_at": func.coalesce(statement.excluded.started_at, Run.started_at), - "started_by_user_id": func.coalesce(statement.excluded.started_by_user_id, Run.started_by_user_id), - "start_reason": func.coalesce(statement.excluded.start_reason, Run.start_reason), - "ended_at": func.coalesce(statement.excluded.ended_at, Run.ended_at), - "external_id": func.coalesce(statement.excluded.external_id, Run.external_id), - "attempt": func.coalesce(statement.excluded.attempt, Run.attempt), - "persistent_log_url": func.coalesce(statement.excluded.persistent_log_url, Run.persistent_log_url), - "running_log_url": func.coalesce(statement.excluded.running_log_url, Run.running_log_url), - }, + await self._session.execute( + insert_statement, + [ + { + "created_at": run.created_at, + "id": run.id, + "job_id": run.job.id, + "status": RunStatus(run.status), + "parent_run_id": run.parent_run.id if run.parent_run else None, + "started_at": run.started_at, + "started_by_user_id": run.user.id if run.user else None, + "start_reason": RunStartReason(run.start_reason) if run.start_reason else None, + "ended_at": run.ended_at, + "external_id": run.external_id, + "attempt": run.attempt, + "persistent_log_url": run.persistent_log_url, + "running_log_url": run.running_log_url, + } + for run in runs + ], ) - await self._session.execute(statement, data) diff --git a/data_rentgen/db/repositories/schema.py b/data_rentgen/db/repositories/schema.py index b45acfa8..46ea73bf 100644 --- a/data_rentgen/db/repositories/schema.py +++ b/data_rentgen/db/repositories/schema.py @@ -3,32 +3,39 @@ from collections.abc import Collection -from sqlalchemy import any_, select +from sqlalchemy import any_, bindparam, select from data_rentgen.db.models import Schema from data_rentgen.db.repositories.base import Repository from data_rentgen.dto import SchemaDTO +# schema JSON can be heavy, avoid loading it if not needed +fetch_bulk_query = select(Schema.digest, Schema.id).where( + Schema.digest == any_(bindparam("digests")), +) + +get_list_by_ids_query = select(Schema).where(Schema.id == any_(bindparam("schema_ids"))) +get_one_by_digest_query = select(Schema).where(Schema.digest == bindparam("digest")) + class SchemaRepository(Repository[Schema]): async def list_by_ids(self, schema_ids: Collection[int]) -> list[Schema]: if not schema_ids: return [] - query = select(Schema).where(Schema.id == any_(list(schema_ids))) # type: ignore[arg-type] - result = await self._session.scalars(query) + result = await self._session.scalars(get_list_by_ids_query, {"schema_ids": list(schema_ids)}) return list(result.all()) async def fetch_known_ids(self, schemas_dto: list[SchemaDTO]) -> list[tuple[SchemaDTO, int | None]]: if not schemas_dto: return [] - unique_digests = [schema_dto.digest for schema_dto in schemas_dto] - # schema JSON can be heavy, avoid loading it if not needed - statement = select(Schema.digest, Schema.id).where( - Schema.digest == any_(unique_digests), # type: ignore[arg-type] + scalars = await self._session.execute( + fetch_bulk_query, + { + "digests": [item.digest for item in schemas_dto], + }, ) - scalars = await self._session.execute(statement) known_ids = {item.digest: item.id for item in scalars.all()} return [ ( @@ -44,8 +51,7 @@ async def create(self, schema: SchemaDTO) -> Schema: return await self._get(schema) or await self._create(schema) async def _get(self, schema: SchemaDTO) -> Schema | None: - result = select(Schema).where(Schema.digest == schema.digest) - return await self._session.scalar(result) + return await self._session.scalar(get_one_by_digest_query, {"digest": schema.digest}) async def _create(self, schema: SchemaDTO) -> Schema: result = Schema(digest=schema.digest, fields=schema.fields) diff --git a/data_rentgen/db/repositories/sql_query.py b/data_rentgen/db/repositories/sql_query.py index dddf5a67..aedce821 100644 --- a/data_rentgen/db/repositories/sql_query.py +++ b/data_rentgen/db/repositories/sql_query.py @@ -2,24 +2,31 @@ # SPDX-License-Identifier: Apache-2.0 -from sqlalchemy import any_, select +from sqlalchemy import any_, bindparam, select from data_rentgen.db.models.sql_query import SQLQuery from data_rentgen.db.repositories.base import Repository from data_rentgen.dto import SQLQueryDTO +# SQLQuery text can be heavy, avoid loading it if not needed +fetch_bulk_query = select(SQLQuery.fingerprint, SQLQuery.id).where( + SQLQuery.fingerprint == any_(bindparam("fingerprints")), +) + +get_one_by_fingerprint_query = select(SQLQuery).where(SQLQuery.fingerprint == bindparam("fingerprint")) + class SQLQueryRepository(Repository[SQLQuery]): async def fetch_known_ids(self, sql_queries_dto: list[SQLQueryDTO]) -> list[tuple[SQLQueryDTO, int | None]]: if not sql_queries_dto: return [] - unique_fingerprints = [sql_query_dto.fingerprint for sql_query_dto in sql_queries_dto] - # query text can be heavy, avoid loading it if not needed - statement = select(SQLQuery.fingerprint, SQLQuery.id).where( - SQLQuery.fingerprint == any_(unique_fingerprints), # type: ignore[arg-type] + scalars = await self._session.execute( + fetch_bulk_query, + { + "fingerprints": [item.fingerprint for item in sql_queries_dto], + }, ) - scalars = await self._session.execute(statement) known_ids = {item.fingerprint: item.id for item in scalars.all()} return [ ( @@ -35,8 +42,10 @@ async def create(self, sql_query: SQLQueryDTO) -> SQLQuery: return await self._get(sql_query) or await self._create(sql_query) async def _get(self, sql_query: SQLQueryDTO) -> SQLQuery | None: - result = select(SQLQuery).where(SQLQuery.fingerprint == sql_query.fingerprint) - return await self._session.scalar(result) + return await self._session.scalar( + get_one_by_fingerprint_query, + {"fingerprint": sql_query.fingerprint}, + ) async def _create(self, sql_query: SQLQueryDTO) -> SQLQuery: result = SQLQuery(fingerprint=sql_query.fingerprint, query=sql_query.query) diff --git a/data_rentgen/db/repositories/user.py b/data_rentgen/db/repositories/user.py index 5064d973..a3e1d109 100644 --- a/data_rentgen/db/repositories/user.py +++ b/data_rentgen/db/repositories/user.py @@ -1,22 +1,36 @@ # SPDX-FileCopyrightText: 2024-2025 MTS PJSC # SPDX-License-Identifier: Apache-2.0 -from sqlalchemy import any_, func, select +from sqlalchemy import any_, bindparam, func, select from data_rentgen.db.models import User from data_rentgen.db.repositories.base import Repository from data_rentgen.dto import UserDTO +fetch_bulk_query = select(User).where( + func.lower(User.name) == any_(bindparam("names_lower")), +) + +get_one_by_name_query = select(User).where( + func.lower(User.name) == bindparam("name_lower"), +) + +get_one_by_id_query = select(User).where( + User.id == bindparam("id"), +) + class UserRepository(Repository[User]): async def fetch_bulk(self, users_dto: list[UserDTO]) -> list[tuple[UserDTO, User | None]]: if not users_dto: return [] - unique_keys = [user_dto.name.lower() for user_dto in users_dto] - statement = select(User).where( - func.lower(User.name) == any_(unique_keys), # type: ignore[arg-type] + + scalars = await self._session.scalars( + fetch_bulk_query, + { + "names_lower": [item.name.lower() for item in users_dto], + }, ) - scalars = await self._session.scalars(statement) existing = {user.name.lower(): user for user in scalars.all()} return [(user_dto, existing.get(user_dto.name.lower())) for user_dto in users_dto] @@ -29,8 +43,7 @@ async def get_or_create(self, user_dto: UserDTO) -> User: return await self._get(user_dto.name) or await self._create(user_dto) async def _get(self, name: str) -> User | None: - statement = select(User).where(func.lower(User.name) == name.lower()) - return await self._session.scalar(statement) + return await self._session.scalar(get_one_by_name_query, {"name_lower": name.lower()}) async def _create(self, user: UserDTO) -> User: result = User(name=user.name) @@ -39,5 +52,4 @@ async def _create(self, user: UserDTO) -> User: return result async def read_by_id(self, id_: int) -> User | None: - statement = select(User).where(User.id == id_) - return await self._session.scalar(statement) + return await self._session.scalar(get_one_by_id_query, {"id": id_})