Skip to content

Commit df2da73

Browse files
authored
Make interfaces protocols (#66)
1 parent 855a640 commit df2da73

File tree

5 files changed

+39
-30
lines changed

5 files changed

+39
-30
lines changed

.idea/misc.xml

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.idea/sqlalchemy-bind-manager.iml

Lines changed: 2 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/repository/usage.md

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,25 @@ async def some_async_function(repository: SQLAlchemyAsyncRepositoryInterface[MyM
5757
...
5858
```
5959

60-
Both repository and related interface are Generic, accepting the model class as a typing argument.
60+
Both repository and related interface are Protocols, accepting the model class as a typing argument. You can also
61+
extend the protocols with your custom methods.
62+
63+
```python
64+
from typing import Protocol
65+
from sqlalchemy_bind_manager.repository import SQLAlchemyRepositoryInterface, SQLAlchemyRepository
66+
67+
# SQLAlchemy model
68+
class MyModel:
69+
...
70+
71+
class MyCustomRepositoryInterface(SQLAlchemyRepositoryInterface[MyModel], Protocol):
72+
def some_custom_method(self, model: MyModel) -> MyModel:
73+
...
74+
75+
class MyCustomRepository(SQLAlchemyRepository[MyModel]):
76+
def some_custom_method(self, model: MyModel) -> MyModel:
77+
return model
78+
```
6179
///
6280

6381
### Maximum query limit

sqlalchemy_bind_manager/_repository/abstract.py

Lines changed: 3 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,13 @@
2727
# Software is furnished to do so, subject to the following conditions:
2828
#
2929
#
30-
from abc import ABC, abstractmethod
3130
from typing import (
3231
Any,
33-
Generic,
3432
Iterable,
3533
List,
3634
Literal,
3735
Mapping,
36+
Protocol,
3837
Tuple,
3938
Union,
4039
)
@@ -48,8 +47,7 @@
4847
)
4948

5049

51-
class SQLAlchemyAsyncRepositoryInterface(Generic[MODEL], ABC):
52-
@abstractmethod
50+
class SQLAlchemyAsyncRepositoryInterface(Protocol[MODEL]):
5351
async def get(self, identifier: PRIMARY_KEY) -> MODEL:
5452
"""Get a model by primary key.
5553
@@ -59,7 +57,6 @@ async def get(self, identifier: PRIMARY_KEY) -> MODEL:
5957
"""
6058
...
6159

62-
@abstractmethod
6360
async def get_many(self, identifiers: Iterable[PRIMARY_KEY]) -> List[MODEL]:
6461
"""Get a list of models by primary keys.
6562
@@ -68,7 +65,6 @@ async def get_many(self, identifiers: Iterable[PRIMARY_KEY]) -> List[MODEL]:
6865
"""
6966
...
7067

71-
@abstractmethod
7268
async def save(self, instance: MODEL) -> MODEL:
7369
"""Persist a model.
7470
@@ -77,7 +73,6 @@ async def save(self, instance: MODEL) -> MODEL:
7773
"""
7874
...
7975

80-
@abstractmethod
8176
async def save_many(self, instances: Iterable[MODEL]) -> Iterable[MODEL]:
8277
"""Persist many models in a single database get_session.
8378
@@ -86,23 +81,20 @@ async def save_many(self, instances: Iterable[MODEL]) -> Iterable[MODEL]:
8681
"""
8782
...
8883

89-
@abstractmethod
9084
async def delete(self, instance: MODEL) -> None:
9185
"""Deletes a model.
9286
9387
:param instance: The model instance
9488
"""
9589
...
9690

97-
@abstractmethod
9891
async def delete_many(self, instances: Iterable[MODEL]) -> None:
9992
"""Deletes a collection of models in a single transaction.
10093
10194
:param instances: The model instances
10295
"""
10396
...
10497

105-
@abstractmethod
10698
async def find(
10799
self,
108100
search_params: Union[None, Mapping[str, Any]] = None,
@@ -130,7 +122,6 @@ async def find(
130122
"""
131123
...
132124

133-
@abstractmethod
134125
async def paginated_find(
135126
self,
136127
items_per_page: int,
@@ -169,7 +160,6 @@ async def paginated_find(
169160
"""
170161
...
171162

172-
@abstractmethod
173163
async def cursor_paginated_find(
174164
self,
175165
items_per_page: int,
@@ -205,8 +195,7 @@ async def cursor_paginated_find(
205195
...
206196

207197

208-
class SQLAlchemyRepositoryInterface(Generic[MODEL], ABC):
209-
@abstractmethod
198+
class SQLAlchemyRepositoryInterface(Protocol[MODEL]):
210199
def get(self, identifier: PRIMARY_KEY) -> MODEL:
211200
"""Get a model by primary key.
212201
@@ -216,7 +205,6 @@ def get(self, identifier: PRIMARY_KEY) -> MODEL:
216205
"""
217206
...
218207

219-
@abstractmethod
220208
def get_many(self, identifiers: Iterable[PRIMARY_KEY]) -> List[MODEL]:
221209
"""Get a list of models by primary keys.
222210
@@ -225,7 +213,6 @@ def get_many(self, identifiers: Iterable[PRIMARY_KEY]) -> List[MODEL]:
225213
"""
226214
...
227215

228-
@abstractmethod
229216
def save(self, instance: MODEL) -> MODEL:
230217
"""Persist a model.
231218
@@ -234,7 +221,6 @@ def save(self, instance: MODEL) -> MODEL:
234221
"""
235222
...
236223

237-
@abstractmethod
238224
def save_many(self, instances: Iterable[MODEL]) -> Iterable[MODEL]:
239225
"""Persist many models in a single database get_session.
240226
@@ -243,23 +229,20 @@ def save_many(self, instances: Iterable[MODEL]) -> Iterable[MODEL]:
243229
"""
244230
...
245231

246-
@abstractmethod
247232
def delete(self, instance: MODEL) -> None:
248233
"""Deletes a model.
249234
250235
:param instance: The model instance
251236
"""
252237
...
253238

254-
@abstractmethod
255239
def delete_many(self, instances: Iterable[MODEL]) -> None:
256240
"""Deletes a collection of models in a single transaction.
257241
258242
:param instances: The model instances
259243
"""
260244
...
261245

262-
@abstractmethod
263246
def find(
264247
self,
265248
search_params: Union[None, Mapping[str, Any]] = None,
@@ -287,7 +270,6 @@ def find(
287270
"""
288271
...
289272

290-
@abstractmethod
291273
def paginated_find(
292274
self,
293275
items_per_page: int,
@@ -326,7 +308,6 @@ def paginated_find(
326308
"""
327309
...
328310

329-
@abstractmethod
330311
def cursor_paginated_find(
331312
self,
332313
items_per_page: int,

tests/test_interfaces.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from inspect import signature
2+
from typing import Protocol, runtime_checkable
23

34
from sqlalchemy_bind_manager.repository import (
45
SQLAlchemyAsyncRepository,
@@ -8,9 +9,17 @@
89
)
910

1011

12+
@runtime_checkable
13+
class RuntimeRepoProtocol(SQLAlchemyRepositoryInterface, Protocol): ...
14+
15+
16+
@runtime_checkable
17+
class RuntimeAsyncRepoProtocol(SQLAlchemyAsyncRepositoryInterface, Protocol): ...
18+
19+
1120
def test_interfaces():
12-
assert issubclass(SQLAlchemyRepository, SQLAlchemyRepositoryInterface)
13-
assert issubclass(SQLAlchemyAsyncRepository, SQLAlchemyAsyncRepositoryInterface)
21+
assert issubclass(SQLAlchemyRepository, RuntimeRepoProtocol)
22+
assert issubclass(SQLAlchemyAsyncRepository, RuntimeAsyncRepoProtocol)
1423

1524
sync_methods = [
1625
method
@@ -26,15 +35,15 @@ def test_interfaces():
2635
assert sync_methods == async_methods
2736

2837
for method in sync_methods:
29-
# Sync signature is the same as sync protocol
38+
# Concrete sync signature is the same as sync protocol signature
3039
assert signature(getattr(SQLAlchemyRepository, method)) == signature(
3140
getattr(SQLAlchemyRepositoryInterface, method)
3241
)
33-
# Async signature is the same as async protocol
42+
# Concrete async signature is the same as async protocol signature
3443
assert signature(getattr(SQLAlchemyAsyncRepository, method)) == signature(
3544
getattr(SQLAlchemyAsyncRepositoryInterface, method)
3645
)
37-
# Sync signature is the same as async signature
46+
# Sync protocol signature is the same as async protocol signature
3847
assert signature(
3948
getattr(SQLAlchemyAsyncRepositoryInterface, method)
4049
) == signature(getattr(SQLAlchemyRepositoryInterface, method))

0 commit comments

Comments
 (0)