Skip to content

Commit 5755b5b

Browse files
committed
stash
1 parent b6f121a commit 5755b5b

File tree

5 files changed

+83
-39
lines changed

5 files changed

+83
-39
lines changed

src/anyvar/restapi/schema.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,7 @@ class SearchResponse(BaseModel):
296296
"""Describe response for the GET /search endpoint"""
297297

298298
variations: list[models.Variation]
299+
next_cursor: str | None
299300

300301

301302
class RunStatusResponse(BaseModel):

src/anyvar/restapi/search_router.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ def search_variations(
3636
int,
3737
Query(..., description="End position for genomic region", examples=[2781758]),
3838
],
39+
page_size: int = Query(1000, ge=1, le=10000),
40+
cursor: str | None = Query(None, description="Opaque pagination cursor"),
3941
) -> SearchResponse:
4042
"""Perform genomic coordinate-based search over all registered variations."""
4143
av: AnyVar = request.app.state.anyvar
@@ -50,15 +52,18 @@ def search_variations(
5052
detail="Unable to dereference provided accession ID",
5153
) from e
5254

53-
alleles = []
54-
if ga4gh_id:
55-
try:
56-
refget_accession = ga4gh_id.split("ga4gh:")[-1]
57-
alleles = av.object_store.search_alleles(refget_accession, start, end)
58-
except NotImplementedError as e:
59-
raise HTTPException(
60-
status_code=HTTPStatus.NOT_IMPLEMENTED,
61-
detail="Search not implemented for current storage backend",
62-
) from e
55+
if not ga4gh_id:
56+
return SearchResponse(variations=[], next_cursor=None)
6357

64-
return SearchResponse(variations=alleles)
58+
try:
59+
refget_accession = ga4gh_id.split("ga4gh:")[-1]
60+
page = av.object_store.search_alleles(
61+
refget_accession, start, end, page_size=page_size, cursor=cursor
62+
)
63+
except NotImplementedError as e:
64+
raise HTTPException(
65+
status_code=HTTPStatus.NOT_IMPLEMENTED,
66+
detail="Search not implemented for current storage backend",
67+
) from e
68+
69+
return SearchResponse(variations=page.items, next_cursor=page.next_cursor)

src/anyvar/storage/base.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from abc import ABC, abstractmethod
44
from collections.abc import Iterable
5+
from dataclasses import dataclass
56

67
from ga4gh.vrs import models as vrs_models
78

@@ -29,6 +30,14 @@ class InvalidSearchParamsError(StorageError):
2930
"""Raise if search params violate specified logical constraints"""
3031

3132

33+
@dataclass(frozen=True)
34+
class AlleleSearchPage:
35+
"""Return object for implementing keyset pagination in allele search

Immutable (frozen) container pairing one page of search results with the
cursor a caller passes back to request the following page.
"""
36+
37+
# alleles contained in this page of results
items: list[vrs_models.Allele]
38+
# opaque cursor for requesting the next page, or None if there is none
next_cursor: str | None
39+
40+
3241
class Storage(ABC):
3342
"""Abstract base class for interacting with storage backends."""
3443

@@ -174,7 +183,9 @@ def search_alleles(
174183
refget_accession: str,
175184
start: int,
176185
stop: int,
177-
) -> list[vrs_models.Allele]:
186+
page_size: int = 1000,
187+
cursor: str | None = None,
188+
) -> AlleleSearchPage:
178189
"""Find all Alleles that are located within the specified interval.
179190
180191
The interval is the closed range [start, stop] on the sequence identified by
@@ -194,7 +205,8 @@ def search_alleles(
194205
:param refget_accession: refget accession (e.g. `"SQ.IW78mgV5Cqf6M24hy52hPjyyo5tCCd86"`)
195206
:param start: Inclusive, inter-residue start position of the interval
196207
:param stop: Inclusive, inter-residue end position of the interval
197-
:return: a list of matching VRS alleles
208+
:param page_size: maximum number of alleles to return in a single page
209+
:param cursor: opaque cursor taken from a previous page's ``next_cursor``; ``None`` starts from the first page
210+
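:return: an ``AlleleSearchPage`` holding the matching VRS alleles for this page and the cursor for the next page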
198211
:raise InvalidSearchParamsError: if above search param requirements are violated
199-
200212
"""

src/anyvar/storage/postgres.py

Lines changed: 49 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
"""Provide PostgreSQL-based storage implementation."""
22

3+
import base64
34
import json
45
import logging
56
from collections import defaultdict
67
from collections.abc import Iterable
78

89
from ga4gh.vrs import models as vrs_models
9-
from sqlalchemy import create_engine, delete, func, select
10+
from sqlalchemy import and_, create_engine, delete, func, or_, select
1011
from sqlalchemy.dialects.postgresql import insert
1112
from sqlalchemy.exc import IntegrityError
1213
from sqlalchemy.orm import joinedload, sessionmaker
@@ -15,6 +16,7 @@
1516
from anyvar.core import objects as anyvar_objects
1617
from anyvar.storage import orm
1718
from anyvar.storage.base import (
19+
AlleleSearchPage,
1820
DataIntegrityError,
1921
IncompleteVrsObjectError,
2022
InvalidSearchParamsError,
@@ -33,10 +35,6 @@ class PostgresObjectStore(Storage):
3335
with object mapping to convert between VRS models and database entities.
3436
"""
3537

36-
# temporary cap on max # of rows that can be returned by a single SQL query
37-
# issue 295 should convert this to a batch size parameter
38-
MAX_ROWS = 100
39-
4038
_VRS_OBJECT_INSERT_ORDER: list[str] = [ # noqa: RUF012
4139
orm.SequenceReference.__name__,
4240
orm.Location.__name__,
@@ -51,6 +49,7 @@ def __init__(self, db_url: str, *args, **kwargs) -> None:
5149
self.db_url = db_url
5250
self.engine = create_engine(db_url)
5351
self.session_factory = sessionmaker(bind=self.engine)
52+
self.batch_size = kwargs.get("batch_size", 1000)
5453
create_tables(self.db_url)
5554

5655
def close(self) -> None:
@@ -133,8 +132,10 @@ def get_objects(
133132
:return: iterable collection of VRS objects matching given IDs
134133
"""
135134
object_ids_list = list(object_ids)
136-
results = []
135+
if not object_ids_list:
136+
return []
137137

138+
results = []
138139
with self.session_factory() as session:
139140
if object_type is vrs_models.Allele:
140141
# Get alleles with eager loading
@@ -146,7 +147,6 @@ def get_objects(
146147
)
147148
)
148149
.where(orm.Allele.id.in_(object_ids_list))
149-
.limit(self.MAX_ROWS)
150150
)
151151
db_objects = session.scalars(stmt).all()
152152
elif object_type is vrs_models.SequenceLocation:
@@ -155,15 +155,12 @@ def get_objects(
155155
select(orm.Location)
156156
.options(joinedload(orm.Location.sequence_reference))
157157
.where(orm.Location.id.in_(object_ids_list))
158-
.limit(self.MAX_ROWS)
159158
)
160159
db_objects = session.scalars(stmt).all()
161160
elif object_type is vrs_models.SequenceReference:
162161
# Get sequence references
163-
stmt = (
164-
select(orm.SequenceReference)
165-
.where(orm.SequenceReference.id.in_(object_ids_list))
166-
.limit(self.MAX_ROWS)
162+
stmt = select(orm.SequenceReference).where(
163+
orm.SequenceReference.id.in_(object_ids_list)
167164
)
168165
db_objects = session.scalars(stmt).all()
169166
else:
@@ -353,7 +350,9 @@ def search_alleles(
353350
refget_accession: str,
354351
start: int,
355352
stop: int,
356-
) -> list[vrs_models.Allele]:
353+
page_size: int = 1000,
354+
cursor: str | None = None,
355+
) -> AlleleSearchPage:
357356
"""Find all Alleles that are located within the specified interval.
358357
359358
The interval is the closed range [start, stop] on the sequence identified by
@@ -373,16 +372,29 @@ def search_alleles(
373372
:param refget_accession: refget accession (e.g. `"SQ.IW78mgV5Cqf6M24hy52hPjyyo5tCCd86"`)
374373
:param start: Inclusive, inter-residue start position of the interval
375374
:param stop: Inclusive, inter-residue end position of the interval
376-
:return: a list of matching VRS alleles
375+
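:return: an ``AlleleSearchPage`` holding the matching VRS alleles for this page and the cursor for the next page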
377376
:raise InvalidSearchParamsError: if above search param requirements are violated
378-
379377
"""
378+
379+
# Pack a keyset position (location start + allele ID) into an opaque cursor:
# compact JSON (no whitespace after separators), then URL-safe base64 text.
def _encode_cursor(start: int, allele_id: str) -> str:
380+
raw = json.dumps(
381+
{"start": start, "id": allele_id}, separators=(",", ":")
382+
).encode()
383+
return base64.urlsafe_b64encode(raw).decode()
384+
385+
def _decode_cursor(cursor: str) -> tuple[int, str]:
    """Decode an opaque pagination cursor into its keyset position.

    The cursor is expected to be the URL-safe base64 encoding of a JSON
    object with ``start`` and ``id`` keys, as produced by ``_encode_cursor``.

    :param cursor: cursor string supplied by the client
    :return: tuple of (location start, allele ID) to seek past
    :raise InvalidSearchParamsError: if the cursor is not valid base64, does
        not contain a JSON object, or lacks usable ``start``/``id`` values
    """
    # The cursor is client-controlled input. "=" padding is frequently
    # stripped from URL parameters, so restore it before decoding.
    padded = cursor + "=" * (-len(cursor) % 4)
    try:
        obj = json.loads(base64.urlsafe_b64decode(padded.encode()))
        return int(obj["start"]), str(obj["id"])
    except (ValueError, KeyError, TypeError) as e:
        # ValueError covers binascii.Error, UnicodeError and JSONDecodeError;
        # KeyError/TypeError cover JSON that is not an object with both keys.
        raise InvalidSearchParamsError from e
389+
380390
if start < 0 or stop < 0 or start > stop:
381391
raise InvalidSearchParamsError
392+
seek_start: int | None = None
393+
seek_id: str | None = None
394+
if cursor:
395+
seek_start, seek_id = _decode_cursor(cursor)
382396

383397
with self.session_factory() as session:
384-
# Query alleles with overlapping locations
385-
# NOTE: this is any overlap, not containment.
386398
stmt = (
387399
select(orm.Allele)
388400
.options(
@@ -398,10 +410,25 @@ def search_alleles(
398410
func.int8range(start, stop, "[]")
399411
),
400412
)
401-
.limit(self.MAX_ROWS)
413+
.order_by(orm.Location.start, orm.Allele.id)
414+
.limit(page_size)
402415
)
403-
db_alleles = session.scalars(stmt).all()
404416

405-
return [
406-
mapper_registry.from_db_entity(db_allele) for db_allele in db_alleles
407-
]
417+
if seek_start is not None and seek_id is not None:
418+
stmt = stmt.where(
419+
or_(
420+
orm.Location.start > seek_start,
421+
and_(orm.Location.start == seek_start, orm.Allele.id > seek_id),
422+
)
423+
)
424+
425+
page_db = session.scalars(stmt).all()
426+
427+
items = [mapper_registry.from_db_entity(a) for a in page_db]
428+
429+
if not page_db:
430+
return AlleleSearchPage(items=[], next_cursor=None)
431+
432+
last = page_db[-1]
433+
next_cursor = _encode_cursor(last.location.start, last.id)
434+
return AlleleSearchPage(items=items, next_cursor=next_cursor)

src/anyvar/storage/snowflake.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -616,7 +616,7 @@ def search_alleles(
616616
refget_accession: str,
617617
start: int,
618618
stop: int,
619-
) -> list[vrs_models.Allele]:
619+
) -> Iterable[vrs_models.Allele]:
620620
"""Find all Alleles that are located within the specified interval.
621621
622622
The interval is the closed range [start, stop] on the sequence identified by
@@ -636,9 +636,8 @@ def search_alleles(
636636
:param refget_accession: refget accession (e.g. `"SQ.IW78mgV5Cqf6M24hy52hPjyyo5tCCd86"`)
637637
:param start: Inclusive, inter-residue start position of the interval
638638
:param stop: Inclusive, inter-residue end position of the interval
639-
:return: a list of matching VRS alleles
639+
:return: an iterable of matching VRS alleles
640640
:raise InvalidSearchParamsError: if above search param requirements are violated
641-
642641
"""
643642
if start < 0 or stop < 0 or start > stop:
644643
raise InvalidSearchParamsError

0 commit comments

Comments
 (0)