11"""Provide PostgreSQL-based storage implementation."""
22
3+ import base64
34import json
45import logging
56from collections import defaultdict
67from collections .abc import Iterable
78
89from ga4gh .vrs import models as vrs_models
9- from sqlalchemy import create_engine , delete , func , select
10+ from sqlalchemy import and_ , create_engine , delete , func , or_ , select
1011from sqlalchemy .dialects .postgresql import insert
1112from sqlalchemy .exc import IntegrityError
1213from sqlalchemy .orm import joinedload , sessionmaker
1516from anyvar .core import objects as anyvar_objects
1617from anyvar .storage import orm
1718from anyvar .storage .base import (
19+ AlleleSearchPage ,
1820 DataIntegrityError ,
1921 IncompleteVrsObjectError ,
2022 InvalidSearchParamsError ,
@@ -33,10 +35,6 @@ class PostgresObjectStore(Storage):
3335 with object mapping to convert between VRS models and database entities.
3436 """
3537
36- # temporary cap on max # of rows that can be returned by a single SQL query
37- # issue 295 should convert this to a batch size parameter
38- MAX_ROWS = 100
39-
4038 _VRS_OBJECT_INSERT_ORDER : list [str ] = [ # noqa: RUF012
4139 orm .SequenceReference .__name__ ,
4240 orm .Location .__name__ ,
@@ -51,6 +49,7 @@ def __init__(self, db_url: str, *args, **kwargs) -> None:
5149 self .db_url = db_url
5250 self .engine = create_engine (db_url )
5351 self .session_factory = sessionmaker (bind = self .engine )
52+ self .batch_size = kwargs .get ("batch_size" , 1000 )
5453 create_tables (self .db_url )
5554
5655 def close (self ) -> None :
@@ -133,8 +132,10 @@ def get_objects(
133132 :return: iterable collection of VRS objects matching given IDs
134133 """
135134 object_ids_list = list (object_ids )
136- results = []
135+ if not object_ids_list :
136+ return []
137137
138+ results = []
138139 with self .session_factory () as session :
139140 if object_type is vrs_models .Allele :
140141 # Get alleles with eager loading
@@ -146,7 +147,6 @@ def get_objects(
146147 )
147148 )
148149 .where (orm .Allele .id .in_ (object_ids_list ))
149- .limit (self .MAX_ROWS )
150150 )
151151 db_objects = session .scalars (stmt ).all ()
152152 elif object_type is vrs_models .SequenceLocation :
@@ -155,15 +155,12 @@ def get_objects(
155155 select (orm .Location )
156156 .options (joinedload (orm .Location .sequence_reference ))
157157 .where (orm .Location .id .in_ (object_ids_list ))
158- .limit (self .MAX_ROWS )
159158 )
160159 db_objects = session .scalars (stmt ).all ()
161160 elif object_type is vrs_models .SequenceReference :
162161 # Get sequence references
163- stmt = (
164- select (orm .SequenceReference )
165- .where (orm .SequenceReference .id .in_ (object_ids_list ))
166- .limit (self .MAX_ROWS )
162+ stmt = select (orm .SequenceReference ).where (
163+ orm .SequenceReference .id .in_ (object_ids_list )
167164 )
168165 db_objects = session .scalars (stmt ).all ()
169166 else :
@@ -353,7 +350,9 @@ def search_alleles(
353350 refget_accession : str ,
354351 start : int ,
355352 stop : int ,
356- ) -> list [vrs_models .Allele ]:
353+ page_size : int = 1000 ,
354+ cursor : str | None = None ,
355+ ) -> AlleleSearchPage :
357356 """Find all Alleles that are located within the specified interval.
358357
359358 The interval is the closed range [start, stop] on the sequence identified by
@@ -373,16 +372,29 @@ def search_alleles(
373372 :param refget_accession: refget accession (e.g. `"SQ.IW78mgV5Cqf6M24hy52hPjyyo5tCCd86"`)
374373 :param start: Inclusive, inter-residue start position of the interval
375374 :param stop: Inclusive, inter-residue end position of the interval
376- :return: a list of matching VRS alleles
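375+ :return: an `AlleleSearchPage` holding the matching VRS alleles (`items`) and a `next_cursor` for requesting the following page (`None` when no rows match)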
377376 :raise InvalidSearchParamsError: if above search param requirements are violated
378-
379377 """
378+
379+ def _encode_cursor (start : int , allele_id : str ) -> str :
380+ raw = json .dumps (
381+ {"start" : start , "id" : allele_id }, separators = ("," , ":" )
382+ ).encode ()
383+ return base64 .urlsafe_b64encode (raw ).decode ()
384+
385+ def _decode_cursor (cursor : str ) -> tuple [int , str ]:
386+ raw = base64 .urlsafe_b64decode (cursor .encode ())
387+ obj = json .loads (raw )
388+ return int (obj ["start" ]), str (obj ["id" ])
389+
380390 if start < 0 or stop < 0 or start > stop :
381391 raise InvalidSearchParamsError
392+ seek_start : int | None = None
393+ seek_id : str | None = None
394+ if cursor :
395+ seek_start , seek_id = _decode_cursor (cursor )
382396
383397 with self .session_factory () as session :
384- # Query alleles with overlapping locations
385- # NOTE: this is any overlap, not containment.
386398 stmt = (
387399 select (orm .Allele )
388400 .options (
@@ -398,10 +410,25 @@ def search_alleles(
398410 func .int8range (start , stop , "[]" )
399411 ),
400412 )
401- .limit (self .MAX_ROWS )
413+ .order_by (orm .Location .start , orm .Allele .id )
414+ .limit (page_size )
402415 )
403- db_alleles = session .scalars (stmt ).all ()
404416
405- return [
406- mapper_registry .from_db_entity (db_allele ) for db_allele in db_alleles
407- ]
417+ if seek_start is not None and seek_id is not None :
418+ stmt = stmt .where (
419+ or_ (
420+ orm .Location .start > seek_start ,
421+ and_ (orm .Location .start == seek_start , orm .Allele .id > seek_id ),
422+ )
423+ )
424+
425+ page_db = session .scalars (stmt ).all ()
426+
427+ items = [mapper_registry .from_db_entity (a ) for a in page_db ]
428+
429+ if not page_db :
430+ return AlleleSearchPage (items = [], next_cursor = None )
431+
432+ last = page_db [- 1 ]
433+ next_cursor = _encode_cursor (last .location .start , last .id )
434+ return AlleleSearchPage (items = items , next_cursor = next_cursor )
0 commit comments