Skip to content

Commit da683a7

Browse files
committed
Add get_many
1 parent d57c79a commit da683a7

File tree

5 files changed

+104
-1
lines changed

5 files changed

+104
-1
lines changed

sqlalchemy_bind_manager/_repository/async_.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
Union,
1313
)
1414

15+
from sqlalchemy import select
1516
from sqlalchemy.ext.asyncio import AsyncSession
1617

1718
from .._bind_manager import SQLAlchemyAsyncBind
@@ -84,6 +85,14 @@ async def get(self, identifier: PRIMARY_KEY) -> MODEL:
8485
raise ModelNotFound("No rows found for provided primary key.")
8586
return model
8687

88+
async def get_many(self, identifiers: Iterable[PRIMARY_KEY]) -> List[MODEL]:
89+
stmt = select(self._model).where(
90+
getattr(self._model, self._model_pk()).in_(identifiers)
91+
)
92+
93+
async with self._get_session(commit=False) as session:
94+
return [x for x in (await session.execute(stmt)).scalars()]
95+
8796
async def delete(
8897
self,
8998
entity: Union[MODEL, PRIMARY_KEY],

sqlalchemy_bind_manager/_repository/sync.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
Union,
1313
)
1414

15+
from sqlalchemy import select
1516
from sqlalchemy.orm import Session
1617

1718
from .._bind_manager import SQLAlchemyBind
@@ -73,6 +74,14 @@ def get(self, identifier: PRIMARY_KEY) -> MODEL:
7374
raise ModelNotFound("No rows found for provided primary key.")
7475
return model
7576

77+
def get_many(self, identifiers: Iterable[PRIMARY_KEY]) -> List[MODEL]:
78+
stmt = select(self._model).where(
79+
getattr(self._model, self._model_pk()).in_(identifiers)
80+
)
81+
82+
with self._get_session(commit=False) as session:
83+
return [x for x in session.execute(stmt).scalars()]
84+
7685
def delete(self, entity: Union[MODEL, PRIMARY_KEY]) -> None:
7786
# TODO: delete without loading the model
7887
if isinstance(entity, self._model):

sqlalchemy_bind_manager/protocols.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,16 @@ async def get(self, identifier: PRIMARY_KEY) -> MODEL:
4747
:return: A model instance
4848
:raises ModelNotFound: No model has been found using the primary key
4949
"""
50-
# TODO: implement get_many()
50+
...
51+
52+
async def get_many(self, identifiers: Iterable[PRIMARY_KEY]) -> List[MODEL]:
53+
"""Get a list of models by primary keys.
54+
55+
:param identifiers: A list of primary keys
56+
:type identifiers: List
57+
:return: A list of models
58+
:rtype: List
59+
"""
5160
...
5261

5362
async def delete(self, entity: Union[MODEL, PRIMARY_KEY]) -> None:
@@ -191,6 +200,16 @@ def get(self, identifier: PRIMARY_KEY) -> MODEL:
191200
# TODO: implement get_many()
192201
...
193202

203+
def get_many(self, identifiers: Iterable[PRIMARY_KEY]) -> List[MODEL]:
204+
"""Get a list of models by primary keys.
205+
206+
:param identifiers: A list of primary keys
207+
:type identifiers: List
208+
:return: A list of models
209+
:rtype: List
210+
"""
211+
...
212+
194213
def delete(self, entity: Union[MODEL, PRIMARY_KEY]) -> None:
195214
"""Deletes a model.
196215

tests/repository/async_/test_get.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,39 @@ async def test_get_returns_model(repository_class, model_class, sa_manager):
1919
assert isinstance(result, model_class)
2020

2121

22+
async def test_get_many_returns_models(repository_class, model_class, sa_manager):
23+
model = model_class(
24+
model_id=1,
25+
name="Someone",
26+
)
27+
model2 = model_class(
28+
model_id=2,
29+
name="SomeoneElse",
30+
)
31+
model3 = model_class(
32+
model_id=3,
33+
name="StillSomeoneElse",
34+
)
35+
repo = repository_class(sa_manager.get_bind())
36+
await repo.save_many({model, model2, model3})
37+
38+
result = await repo.get_many([1, 2])
39+
assert isinstance(result, list)
40+
assert len(result) == 2
41+
assert result[0].model_id == 1
42+
assert result[1].model_id == 2
43+
44+
45+
async def test_get_many_returns_empty_list_if_nothing_found(
46+
repository_class, model_class, sa_manager
47+
):
48+
repo = repository_class(sa_manager.get_bind())
49+
50+
result = await repo.get_many([1, 2])
51+
assert isinstance(result, list)
52+
assert len(result) == 0
53+
54+
2255
async def test_get_raises_exception_if_not_found(repository_class, sa_manager):
2356
repo = repository_class(sa_manager.get_bind())
2457

tests/repository/sync/test_get.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,39 @@ def test_get_returns_model(repository_class, model_class, sa_manager):
1919
assert isinstance(result, model_class)
2020

2121

22+
def test_get_many_returns_models(repository_class, model_class, sa_manager):
23+
model = model_class(
24+
model_id=1,
25+
name="Someone",
26+
)
27+
model2 = model_class(
28+
model_id=2,
29+
name="SomeoneElse",
30+
)
31+
model3 = model_class(
32+
model_id=3,
33+
name="StillSomeoneElse",
34+
)
35+
repo = repository_class(sa_manager.get_bind())
36+
repo.save_many({model, model2, model3})
37+
38+
result = repo.get_many([1, 2])
39+
assert isinstance(result, list)
40+
assert len(result) == 2
41+
assert result[0].model_id == 1
42+
assert result[1].model_id == 2
43+
44+
45+
def test_get_many_returns_empty_list_if_nothing_found(
46+
repository_class, model_class, sa_manager
47+
):
48+
repo = repository_class(sa_manager.get_bind())
49+
50+
result = repo.get_many([1, 2])
51+
assert isinstance(result, list)
52+
assert len(result) == 0
53+
54+
2255
def test_get_raises_exception_if_not_found(repository_class, sa_manager):
2356
repo = repository_class(sa_manager.get_bind())
2457

0 commit comments

Comments
 (0)