diff --git a/data_rentgen/db/repositories/operation.py b/data_rentgen/db/repositories/operation.py
index ec88a67c..3fc22df5 100644
--- a/data_rentgen/db/repositories/operation.py
+++ b/data_rentgen/db/repositories/operation.py
@@ -11,12 +11,13 @@
 from data_rentgen.db.models import Operation, OperationStatus, OperationType
 from data_rentgen.db.repositories.base import Repository
 from data_rentgen.dto import OperationDTO, PaginationDTO
-from data_rentgen.utils.uuid import extract_timestamp_from_uuid
+from data_rentgen.utils.uuid import extract_timestamp_from_uuid, get_max_uuid, get_min_uuid
 
 insert_statement = insert(Operation).on_conflict_do_nothing()
 update_statement = update(Operation)
 
 get_list_by_run_ids_query = select(Operation).where(
+    Operation.id >= bindparam("min_id"),
     Operation.created_at >= bindparam("since"),
     Operation.run_id == any_(bindparam("run_ids")),
 )
@@ -36,6 +37,7 @@
         func.count(Operation.id.distinct()).label("total_operations"),
     )
     .where(
+        Operation.id >= bindparam("min_id"),
         Operation.created_at >= bindparam("since"),
         Operation.run_id == any_(bindparam("run_ids")),
     )
@@ -101,25 +103,43 @@ async def paginate(
         # do not use `tuple_(Operation.created_at, Operation.id).in_(...),
         # as this is too complex filter for Postgres to make an optimal query plan
         where = []
+
+        # created_at and id are always correlated,
+        # and primary key starts with id, so we need to apply filter on both
+        # to get the most optimal query plan
         if operation_ids:
             min_operation_created_at = extract_timestamp_from_uuid(min(operation_ids))
             max_operation_created_at = extract_timestamp_from_uuid(max(operation_ids))
-            min_created_at = max(since, min_operation_created_at) if since else min_operation_created_at
-            max_created_at = min(until, max_operation_created_at) if until else max_operation_created_at
+            # narrow created_at range
+            min_created_at = max(filter(None, [since, min_operation_created_at]))
+            max_created_at = min(filter(None, [until, max_operation_created_at]))
             where = [
                 Operation.created_at >= min_created_at,
                 Operation.created_at <= max_created_at,
+                Operation.id == any_(list(operation_ids)),  # type: ignore[arg-type]
+            ]
+
+        elif run_id:
+            run_created_at = extract_timestamp_from_uuid(run_id)
+            # narrow created_at range
+            min_created_at = max(filter(None, [since, run_created_at]))
+            where = [
+                Operation.run_id == run_id,
+                Operation.created_at >= min_created_at,
+                Operation.id >= get_min_uuid(min_created_at),
+            ]
+
+        elif since:
+            where = [
+                Operation.created_at >= since,
+                Operation.id >= get_min_uuid(since),
+            ]
+
+        if until and not operation_ids:
+            where += [
+                Operation.created_at <= until,
+                Operation.id <= get_max_uuid(until),
             ]
-        else:
-            if since:
-                where.append(Operation.created_at >= since)
-            if until:
-                where.append(Operation.created_at <= until)
-
-        if run_id:
-            where.append(Operation.run_id == run_id)
-        if operation_ids:
-            where.append(Operation.id == any_(list(operation_ids)))  # type: ignore[arg-type]
 
         query = select(Operation).where(*where)
         order_by: list[UnaryExpression] = [Operation.created_at.desc(), Operation.id.desc()]
@@ -151,6 +171,7 @@ async def list_by_run_ids(
         result = await self._session.scalars(
             query,
             {
+                "min_id": get_min_uuid(min_operation_created_at),
                 "since": min_operation_created_at,
                 "run_ids": list(run_ids),
             },
@@ -175,11 +196,13 @@ async def get_stats_by_run_ids(self, run_ids: Collection[UUID]) -> dict[UUID, Ro
         if not run_ids:
             return {}
 
+        # All operations are created after run
+        since = extract_timestamp_from_uuid(min(run_ids))
         query_result = await self._session.execute(
             get_stats_by_run_ids,
             {
-                # All operations are created after run
-                "since": extract_timestamp_from_uuid(min(run_ids)),
+                "since": since,
+                "min_id": get_min_uuid(since),
                 "run_ids": list(run_ids),
             },
         )
diff --git a/data_rentgen/db/repositories/run.py b/data_rentgen/db/repositories/run.py
index 60ec8014..60376021 100644
--- a/data_rentgen/db/repositories/run.py
+++ b/data_rentgen/db/repositories/run.py
@@ -23,7 +23,7 @@
 from data_rentgen.db.repositories.base import Repository
 from data_rentgen.db.utils.search import make_tsquery, ts_match, ts_rank
 from data_rentgen.dto import PaginationDTO, RunDTO
-from data_rentgen.utils.uuid import extract_timestamp_from_uuid
+from data_rentgen.utils.uuid import extract_timestamp_from_uuid, get_max_uuid, get_min_uuid
 
 # Do not use `tuple_(Run.created_at, Run.id).in_(...),
 # as this is too complex filter for Postgres to make an optimal query plan.
@@ -41,6 +41,7 @@
 get_list_by_job_ids_query = (
     select(Run)
     .where(
+        Run.id >= bindparam("min_id"),
         Run.created_at >= bindparam("since"),
         Run.job_id == any_(bindparam("job_ids")),
     )
@@ -97,11 +98,16 @@ async def paginate(
         # do not use `tuple_(Run.created_at, Run.id).in_(...),
         # as this is too complex filter for Postgres to make an optimal query plan
         where = []
+
+        # created_at and id are always correlated,
+        # and primary key starts with id, so we need to apply filter on both
+        # to get the most optimal query plan
         if run_ids:
             min_run_created_at = extract_timestamp_from_uuid(min(run_ids))
             max_run_created_at = extract_timestamp_from_uuid(max(run_ids))
-            min_created_at = max(since, min_run_created_at) if since else min_run_created_at
-            max_created_at = min(until, max_run_created_at) if until else max_run_created_at
+            # narrow created_at range
+            min_created_at = max(filter(None, [since, min_run_created_at]))
+            max_created_at = min(filter(None, [until, max_run_created_at]))
             where = [
                 Run.created_at >= min_created_at,
                 Run.created_at <= max_created_at,
@@ -109,12 +115,17 @@ async def paginate(
             ]
         else:
             if since:
-                where.append(Run.created_at >= since)
+                where = [
+                    Run.created_at >= since,
+                    Run.id >= get_min_uuid(since),
+                ]
+
             if until:
-                where.append(Run.created_at <= until)
+                where += [
+                    Run.created_at <= until,
+                    Run.id <= get_max_uuid(until),
+                ]
 
-        if run_ids:
-            where.append(Run.id == any_(list(run_ids)))  # type: ignore[arg-type]
         if job_id:
             where.append(Run.job_id == job_id)
         if parent_run_id:
@@ -181,7 +192,14 @@ async def list_by_job_ids(self, job_ids: Collection[int], since: datetime, until
             # until is rarely used, avoid making query too complicated
             query = query.where(Run.created_at <= until)
 
-        result = await self._session.scalars(query, {"since": since, "job_ids": list(job_ids)})
+        result = await self._session.scalars(
+            query,
+            {
+                "min_id": get_min_uuid(since),
+                "since": since,
+                "job_ids": list(job_ids),
+            },
+        )
         return list(result.all())
 
     async def fetch_bulk(self, runs_dto: list[RunDTO]) -> list[tuple[RunDTO, Run | None]]:
diff --git a/data_rentgen/utils/uuid.py b/data_rentgen/utils/uuid.py
index 36b68d51..18773a8c 100644
--- a/data_rentgen/utils/uuid.py
+++ b/data_rentgen/utils/uuid.py
@@ -80,6 +80,21 @@ def _build_uuidv7(timestamp: int, node: int) -> NewUUID:
     return NewUUID(int=uuid_int)
 
 
+def get_min_uuid(timestamp: datetime) -> NewUUID:
+    """Get minimal possible UUID for timestamp"""
+    timestamp_int = int(timestamp.timestamp() * 1000)
+    uuid_int = (timestamp_int & 0xFFFFFFFFFFFF) << 80
+    return NewUUID(int=uuid_int)
+
+
+def get_max_uuid(timestamp: datetime) -> NewUUID:
+    """Get maximal possible UUID for timestamp"""
+    timestamp_int = int(timestamp.timestamp() * 1000)
+    uuid_int = (timestamp_int & 0xFFFFFFFFFFFF) << 80
+    uuid_int |= 0xFFFFFFFFFFFFFFFFFFFF
+    return NewUUID(int=uuid_int)
+
+
 def generate_static_uuid(data: str) -> BaseUUID:
     """Generate static UUID for data.
     Each function call returns the same UUID value.
diff --git a/tests/test_consumer/test_utils/test_uuid.py b/tests/test_consumer/test_utils/test_uuid.py
index 332a47e8..8095e0cb 100644
--- a/tests/test_consumer/test_utils/test_uuid.py
+++ b/tests/test_consumer/test_utils/test_uuid.py
@@ -1,9 +1,12 @@
 from datetime import UTC, datetime, timedelta
+from uuid import UUID
 
 from data_rentgen.utils.uuid import (
     generate_incremental_uuid,
     generate_new_uuid,
     generate_static_uuid,
+    get_max_uuid,
+    get_min_uuid,
 )
 
 
@@ -70,3 +73,15 @@ def test_generate_incremental_uuid_sorted_like_timestamp():
     uuid1 = generate_incremental_uuid(current, "test")
     uuid2 = generate_incremental_uuid(following, "test")
     assert uuid1 < uuid2
+
+
+def test_get_min_uuid():
+    timestamp = datetime(2025, 9, 21, 23, 35, 49, 123456, tzinfo=UTC)
+    uuid = get_min_uuid(timestamp)
+    assert uuid == UUID("01996ea2-3883-0000-0000-000000000000")
+
+
+def test_get_max_uuid():
+    timestamp = datetime(2025, 9, 21, 23, 35, 49, 123456, tzinfo=UTC)
+    uuid = get_max_uuid(timestamp)
+    assert uuid == UUID("01996ea2-3883-ffff-ffff-ffffffffffff")
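A minimal standalone sketch of the idea behind get_min_uuid/get_max_uuid, not part of the patch: a UUIDv7 primary key carries its creation time in the top 48 bits (milliseconds since the epoch), so id and created_at are correlated. Placing a timestamp's millisecond value in those top bits and zeroing or filling the remaining 80 bits yields the smallest and largest 128-bit values for that millisecond, which can be used as id bounds next to the created_at bounds so Postgres can also prune the primary-key index. The sketch uses the stdlib uuid.UUID instead of the project's NewUUID wrapper, and the helper names min_uuid_for/max_uuid_for are illustrative, not names from the codebase.

from datetime import UTC, datetime
from uuid import UUID


def min_uuid_for(ts: datetime) -> UUID:
    # Smallest 128-bit value for this millisecond: 48-bit timestamp in
    # bits 127..80, lower 80 bits zeroed.
    millis = int(ts.timestamp() * 1000)
    return UUID(int=(millis & 0xFFFFFFFFFFFF) << 80)


def max_uuid_for(ts: datetime) -> UUID:
    # Largest 128-bit value for the same millisecond: lower 80 bits all set.
    millis = int(ts.timestamp() * 1000)
    return UUID(int=((millis & 0xFFFFFFFFFFFF) << 80) | ((1 << 80) - 1))


since = datetime(2025, 9, 21, 23, 35, 49, 123456, tzinfo=UTC)
print(min_uuid_for(since))  # 01996ea2-3883-0000-0000-000000000000
print(max_uuid_for(since))  # 01996ea2-3883-ffff-ffff-ffffffffffff

Any UUIDv7 generated at or after `since` compares greater than or equal to the first value, and any generated at or before `until` compares less than or equal to the second, which is what makes filters such as Operation.id >= get_min_uuid(since) and Operation.id <= get_max_uuid(until) safe companions to the created_at conditions. The bounds themselves are not valid version-7 UUIDs (the version and variant bits are zeroed or overwritten), which is fine here because they are only comparison sentinels and are never stored.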