Skip to content

Commit 0bd2243

Browse files
committed
add docstring etc
1 parent dd4c28b commit 0bd2243

File tree

5 files changed

+110
-55
lines changed

5 files changed

+110
-55
lines changed

src/anyvar/storage/base.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Provide abstract base definitions shared by storage implementations."""
22

3+
import base64
4+
import json
35
from abc import ABC, abstractmethod
46
from collections.abc import Iterable
57
from dataclasses import dataclass
@@ -181,6 +183,30 @@ def delete_annotation(self, annotation: metadata.Annotation) -> None:
181183
depended upon by another object
182184
"""
183185

186+
@staticmethod
187+
def _encode_search_cursor(start: int, allele_id: str) -> str:
188+
"""Create cursor for search
189+
190+
:param start: start value for next row
191+
:param allele_id: ID for next row
192+
:return: cursor to use to fetch next page
193+
"""
194+
raw = json.dumps(
195+
{"start": start, "id": allele_id}, separators=(",", ":")
196+
).encode()
197+
return base64.urlsafe_b64encode(raw).decode()
198+
199+
@staticmethod
200+
def _decode_search_cursor(cursor: str) -> tuple[int, str]:
201+
"""Decode cursor for getting next page during search
202+
203+
:param cursor: opaque key included with previous result
204+
:return: start and ID values indicating the first row of the next page
205+
"""
206+
raw = base64.urlsafe_b64decode(cursor.encode())
207+
obj = json.loads(raw)
208+
return int(obj["start"]), str(obj["id"])
209+
184210
@abstractmethod
185211
def search_alleles(
186212
self,
@@ -196,6 +222,9 @@ def search_alleles(
196222
the RefGet SequenceReference accession (`SQ.*`). Both `start` and `stop` are
197223
inclusive and represent inter-residue positions.
198224
225+
Uses keyset pagination: each cursor marks the last row returned, so the next call
226+
resumes just after it. Keep the page size unchanged while looping through successive cursors.
227+
199228
Currently, any variation which overlaps the queried region is returned.
200229
201230
Todo (see Issue #338):
@@ -209,8 +238,8 @@ def search_alleles(
209238
:param refget_accession: refget accession (e.g. `"SQ.IW78mgV5Cqf6M24hy52hPjyyo5tCCd86"`)
210239
:param start: Inclusive, inter-residue start position of the interval
211240
:param stop: Inclusive, inter-residue end position of the interval
212-
:param page_size:
213-
:param cursor:
214-
:return:
241+
:param page_size: Max # of results to return
242+
:param cursor: Opaque key indicating start location for query in pagination
243+
:return: Results page including variants and a cursor for next result page, if available
215244
:raise InvalidSearchParamsError: if above search param requirements are violated
216245
"""

src/anyvar/storage/postgres.py

Lines changed: 7 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Provide PostgreSQL-based storage implementation."""
22

3-
import base64
43
import json
54
import logging
65
from collections import defaultdict
@@ -149,7 +148,6 @@ def get_objects(
149148
)
150149
)
151150
.where(orm.Allele.id.in_(object_ids_list))
152-
.limit(self.MAX_ROWS)
153151
)
154152
db_objects = session.scalars(stmt).all()
155153
elif object_type is vrs_models.SequenceLocation:
@@ -158,15 +156,12 @@ def get_objects(
158156
select(orm.Location)
159157
.options(joinedload(orm.Location.sequence_reference))
160158
.where(orm.Location.id.in_(object_ids_list))
161-
.limit(self.MAX_ROWS)
162159
)
163160
db_objects = session.scalars(stmt).all()
164161
elif object_type is vrs_models.SequenceReference:
165162
# Get sequence references
166-
stmt = (
167-
select(orm.SequenceReference)
168-
.where(orm.SequenceReference.id.in_(object_ids_list))
169-
.limit(self.MAX_ROWS)
163+
stmt = select(orm.SequenceReference).where(
164+
orm.SequenceReference.id.in_(object_ids_list)
170165
)
171166

172167
db_objects = session.scalars(stmt).all()
@@ -382,29 +377,17 @@ def search_alleles(
382377
:param refget_accession: refget accession (e.g. `"SQ.IW78mgV5Cqf6M24hy52hPjyyo5tCCd86"`)
383378
:param start: Inclusive, inter-residue start position of the interval
384379
:param stop: Inclusive, inter-residue end position of the interval
385-
:param page_size:
386-
:param cursor:
387-
:return:
380+
:param page_size: Max # of results to return
381+
:param cursor: Opaque key indicating start location for query in pagination
382+
:return: Results page including variants and a cursor for next result page, if available
388383
:raise InvalidSearchParamsError: if above search param requirements are violated
389384
"""
390-
391-
def _encode_cursor(start: int, allele_id: str) -> str:
392-
raw = json.dumps(
393-
{"start": start, "id": allele_id}, separators=(",", ":")
394-
).encode()
395-
return base64.urlsafe_b64encode(raw).decode()
396-
397-
def _decode_cursor(cursor: str) -> tuple[int, str]:
398-
raw = base64.urlsafe_b64decode(cursor.encode())
399-
obj = json.loads(raw)
400-
return int(obj["start"]), str(obj["id"])
401-
402385
if start < 0 or stop < 0 or start > stop:
403386
raise InvalidSearchParamsError
404387
seek_start: int | None = None
405388
seek_id: str | None = None
406389
if cursor:
407-
seek_start, seek_id = _decode_cursor(cursor)
390+
seek_start, seek_id = self._decode_search_cursor(cursor)
408391

409392
with self.session_factory() as session:
410393
stmt = (
@@ -440,5 +423,5 @@ def _decode_cursor(cursor: str) -> tuple[int, str]:
440423
if not page_db:
441424
return AlleleSearchPage(items=[], next_cursor=None)
442425
last = page_db[-1]
443-
next_cursor = _encode_cursor(last.location.start, last.id)
426+
next_cursor = self._encode_search_cursor(last.location.start, last.id)
444427
return AlleleSearchPage(items=items, next_cursor=next_cursor)

src/anyvar/storage/snowflake.py

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,17 @@
1414
from ga4gh.vrs import models as vrs_models
1515
from snowflake.sqlalchemy import MergeInto
1616
from snowflake.sqlalchemy.snowdialect import SnowflakeDialect
17-
from sqlalchemy import String, column, create_engine, delete, insert, select, text
17+
from sqlalchemy import (
18+
String,
19+
and_,
20+
column,
21+
create_engine,
22+
delete,
23+
insert,
24+
or_,
25+
select,
26+
text,
27+
)
1828
from sqlalchemy.engine.url import URL
1929
from sqlalchemy.exc import IntegrityError
2030
from sqlalchemy.ext.compiler import compiles
@@ -27,6 +37,7 @@
2737
from anyvar.core import objects as anyvar_objects
2838
from anyvar.storage import orm
2939
from anyvar.storage.base import (
40+
AlleleSearchPage,
3041
DataIntegrityError,
3142
IncompleteVrsObjectError,
3243
InvalidSearchParamsError,
@@ -616,13 +627,18 @@ def search_alleles(
616627
refget_accession: str,
617628
start: int,
618629
stop: int,
619-
) -> Iterable[vrs_models.Allele]:
630+
page_size: int = 1000,
631+
cursor: str | None = None,
632+
) -> AlleleSearchPage:
620633
"""Find all Alleles that are located within the specified interval.
621634
622635
The interval is the closed range [start, stop] on the sequence identified by
623636
the RefGet SequenceReference accession (`SQ.*`). Both `start` and `stop` are
624637
inclusive and represent inter-residue positions.
625638
639+
Uses keyset pagination: each cursor marks the last row returned, so the next call
640+
resumes just after it. Keep the page size unchanged while looping through successive cursors.
641+
626642
Currently, any variation which overlaps the queried region is returned.
627643
628644
Todo (see Issue #338):
@@ -636,12 +652,19 @@ def search_alleles(
636652
:param refget_accession: refget accession (e.g. `"SQ.IW78mgV5Cqf6M24hy52hPjyyo5tCCd86"`)
637653
:param start: Inclusive, inter-residue start position of the interval
638654
:param stop: Inclusive, inter-residue end position of the interval
639-
:return: an iterable of matching VRS alleles
655+
:param page_size: Max # of results to return
656+
:param cursor: Opaque key indicating start location for query in pagination
657+
:return: Results page including variants and a cursor for next result page, if available
640658
:raise InvalidSearchParamsError: if above search param requirements are violated
641659
"""
642660
if start < 0 or stop < 0 or start > stop:
643661
raise InvalidSearchParamsError
644662

663+
seek_start: int | None = None
664+
seek_id: str | None = None
665+
if cursor:
666+
seek_start, seek_id = self._decode_search_cursor(cursor)
667+
645668
with self.session_factory() as session:
646669
# Query alleles with overlapping locations
647670
# NOTE: this is any overlap, not containment.
@@ -659,13 +682,26 @@ def search_alleles(
659682
orm.Location.start <= stop,
660683
orm.Location.end >= start,
661684
)
662-
.limit(self.MAX_ROWS)
685+
.order_by(orm.Location.start, orm.Allele.id)
686+
.limit(page_size)
663687
)
664-
db_alleles = session.scalars(stmt).all()
665688

666-
return [
667-
mapper_registry.from_db_entity(db_allele) for db_allele in db_alleles
668-
]
689+
# seek predicate -- assumes ORDER BY location.start ASC, allele.id ASC
690+
if seek_start is not None and seek_id is not None:
691+
stmt = stmt.where(
692+
or_(
693+
orm.Location.start > seek_start,
694+
and_(orm.Location.start == seek_start, orm.Allele.id > seek_id),
695+
)
696+
)
697+
698+
page_db = session.scalars(stmt).all()
699+
items = [mapper_registry.from_db_entity(a) for a in page_db]
700+
if not page_db:
701+
return AlleleSearchPage(items=[], next_cursor=None)
702+
last = page_db[-1]
703+
next_cursor = self._encode_search_cursor(last.location.start, last.id)
704+
return AlleleSearchPage(items=items, next_cursor=next_cursor)
669705

670706

671707
@compiles(Insert, "snowflake")

tests/unit/storage/storage_test_funcs.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -425,30 +425,30 @@ def run_search_alleles(
425425
result = storage.search_alleles(
426426
rle.location.sequenceReference.refgetAccession, 36561660, 36561665
427427
)
428-
assert result == [rle]
428+
assert result.items == [rle]
429429
result = storage.search_alleles(
430430
egfr_variant.location.sequenceReference.refgetAccession, 55174010, 140753340
431431
)
432-
sorted(result, key=lambda a: a.id)
433-
assert result == [egfr_variant, braf_variant]
432+
sorted(result.items, key=lambda a: a.id)
433+
assert result.items == [egfr_variant, braf_variant]
434434

435435
# result partially overlaps with interval
436436
result = storage.search_alleles(
437437
rle.location.sequenceReference.refgetAccession, 36561662, 36561665
438438
)
439-
assert result == [rle]
439+
assert result.items == [rle]
440440
assert storage.search_alleles(
441441
rle.location.sequenceReference.refgetAccession, 36561662, 36561663
442-
) == [rle]
442+
).items == [rle]
443443

444444
# position ranges are inclusive
445445
result = storage.search_alleles(
446446
braf_variant.location.sequenceReference.refgetAccession, 140753335, 140753336
447447
)
448-
assert result == [braf_variant]
448+
assert result.items == [braf_variant]
449449

450450
# handle unrecognized accession
451-
assert storage.search_alleles("SQ.unknown-sequence", 1, 10) == []
451+
assert storage.search_alleles("SQ.unknown-sequence", 1, 10).items == []
452452

453453
# handle invalid params
454454
with pytest.raises(InvalidSearchParamsError):
@@ -469,15 +469,32 @@ def run_search_alleles(
469469
assert (
470470
storage.search_alleles(
471471
long_ins.location.sequenceReference.refgetAccession, 10599292, 10599295
472-
)
472+
).items
473473
== []
474474
)
475475
assert (
476476
storage.search_alleles(
477477
rle_del.location.sequenceReference.refgetAccession, 905, 910
478-
)
478+
).items
479479
== []
480480
)
481481
assert storage.search_alleles(
482482
rle_del.location.sequenceReference.refgetAccession, 904, 910
483-
) == [rle_del]
483+
).items == [rle_del]
484+
485+
# test pagination
486+
result = storage.search_alleles(
487+
egfr_variant.location.sequenceReference.refgetAccession,
488+
55174010,
489+
140753340,
490+
page_size=1,
491+
)
492+
assert result.items == [egfr_variant]
493+
result = storage.search_alleles(
494+
egfr_variant.location.sequenceReference.refgetAccession,
495+
55174010,
496+
140753340,
497+
page_size=1,
498+
cursor=result.next_cursor,
499+
)
500+
assert result.items == [braf_variant]

tests/unit/storage/test_postgres.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
run_incomplete_objects_error,
1515
run_mappings_crud,
1616
run_objects_raises_integrityerror,
17-
run_query_max_rows,
1817
run_search_alleles,
1918
run_sequencelocations_crud,
2019
run_sequencereferences_crud,
@@ -49,15 +48,6 @@ def test_db_lifecycle(
4948
run_db_lifecycle(storage, validated_vrs_alleles)
5049

5150

52-
@pytest.mark.ci_ok
53-
def test_query_max_rows(
54-
monkeypatch,
55-
postgres_storage: PostgresObjectStore,
56-
focus_alleles: tuple[models.Allele, models.Allele, models.Allele],
57-
):
58-
run_query_max_rows(monkeypatch, postgres_storage, focus_alleles)
59-
60-
6151
@pytest.mark.ci_ok
6252
def test_alleles_crud(
6353
postgres_storage: PostgresObjectStore,

0 commit comments

Comments
 (0)