Skip to content

Commit 2482138

Browse files
committed
Add support for implicit models when using a single storage
1 parent c8ec36d commit 2482138

File tree

6 files changed

+130
-50
lines changed

6 files changed

+130
-50
lines changed

.zed/settings.json

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,5 @@
1818
}
1919
]
2020
}
21-
},
22-
"lsp": {
23-
"pyright": {
24-
"settings": {
25-
"python": {
26-
"pythonPath": ".hatch/fief/bin/python"
27-
}
28-
}
29-
}
3021
}
3122
}

fief/_core.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,9 @@ def __init__(self, fief: "Fief[U]", container: dishka.Container) -> None:
8383
self.container = container
8484

8585
def get_storage(self, model: type[M]) -> StorageProtocol[M]:
86-
component_key = self.fief._model_component_map.get(model)
86+
component_key = self.fief._model_component_map.get(
87+
model, dishka.DEFAULT_COMPONENT
88+
)
8789
return self.container.get(StorageProtocol[model], component_key) # type: ignore[valid-type]
8890

8991
def get_method(self, type: type[MP], name: str | None = None) -> MP:
@@ -211,13 +213,19 @@ def close(self) -> None:
211213
def _init_storage(
212214
self, storage: StorageProvider | Sequence[StorageProvider]
213215
) -> list["BaseProvider"]:
216+
self._model_component_map = {}
217+
214218
if isinstance(storage, StorageProvider):
215-
storage = [storage]
219+
return [storage]
216220

217221
providers: list[BaseProvider] = []
218222
model_component_map: dict[type, str] = {}
219223
for i, s in enumerate(storage):
220224
component_key = f"storage_{i}"
225+
if s.models is None:
226+
raise ValueError( # noqa: TRY003
227+
"You must provide models on the storage providers if you have multiple ones"
228+
)
221229
for model in s.models:
222230
if model in model_component_map:
223231
raise ValueError( # noqa: TRY003
@@ -240,7 +248,9 @@ def _init_methods(
240248
raise ValueError( # noqa: TRY003
241249
f"Method with name {method.name} is already registered"
242250
)
243-
model_component_key = self._model_component_map[method.model]
251+
model_component_key = self._model_component_map.get(
252+
method.model, dishka.DEFAULT_COMPONENT
253+
)
244254
provider = dishka.Provider(component=method.name)
245255
provider.provide(
246256
method.get_provider(model_component_key), scope=dishka.Scope.REQUEST
@@ -259,7 +269,9 @@ def __init__(self, fief: "FiefAsync[U]", container: dishka.AsyncContainer) -> No
259269
self.container = container
260270

261271
async def get_storage(self, model: type[M]) -> AsyncStorageProtocol[M]:
262-
component_key = self.fief._model_component_map.get(model)
272+
component_key = self.fief._model_component_map.get(
273+
model, dishka.DEFAULT_COMPONENT
274+
)
263275
return await self.container.get(AsyncStorageProtocol[model], component_key) # type: ignore[valid-type]
264276

265277
async def get_user_storage(self) -> AsyncStorageProtocol[U]:
@@ -389,13 +401,19 @@ async def close(self) -> None:
389401
def _init_storage(
390402
self, storage: AsyncStorageProvider | Sequence[AsyncStorageProvider]
391403
) -> list["BaseProvider"]:
404+
self._model_component_map = {}
405+
392406
if isinstance(storage, AsyncStorageProvider):
393-
storage = [storage]
407+
return [storage]
394408

395409
providers: list[BaseProvider] = []
396410
model_component_map: dict[type, str] = {}
397411
for i, s in enumerate(storage):
398412
component_key = f"storage_{i}"
413+
if s.models is None:
414+
raise ValueError( # noqa: TRY003
415+
"You must provide models on the storage providers if you have multiple ones"
416+
)
399417
for model in s.models:
400418
if model in model_component_map:
401419
raise ValueError( # noqa: TRY003
@@ -420,7 +438,9 @@ def _init_methods(
420438
raise ValueError( # noqa: TRY003
421439
f"Method with name {method.name} is already registered"
422440
)
423-
model_component_key = self._model_component_map[method.model]
441+
model_component_key = self._model_component_map.get(
442+
method.model, dishka.DEFAULT_COMPONENT
443+
)
424444
provider = dishka.Provider(component=method.name)
425445
provider.provide(
426446
method.get_provider(model_component_key), scope=dishka.Scope.REQUEST

fief/storage/_protocol.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,14 @@
77

88

99
class StorageProvider(dishka.Provider):
10-
models: list[type]
10+
models: list[type] | None
1111

1212
def __init__(self, models: Sequence[type] | None = None):
1313
super().__init__()
14-
self.models = list(models or [])
15-
for model in self.models:
16-
self.provide(self.get_model_provider(model), scope=dishka.Scope.REQUEST)
14+
self.models = list(models) if models is not None else None
15+
self.provide(self.get_provider(), scope=dishka.Scope.REQUEST)
1716

18-
def get_model_provider(self, model: type[M]) -> Callable[..., "StorageProtocol[M]"]:
17+
def get_provider(self) -> Callable[..., "StorageProtocol[typing.Any]"]:
1918
raise NotImplementedError()
2019

2120

@@ -73,17 +72,14 @@ def delete(self, id: typing.Any) -> M | None: # pragma: no cover
7372

7473

7574
class AsyncStorageProvider(dishka.Provider):
76-
models: list[type]
75+
models: list[type] | None
7776

7877
def __init__(self, models: Sequence[type] | None = None):
7978
super().__init__()
80-
self.models = list(models or [])
81-
for model in self.models:
82-
self.provide(self.get_model_provider(model), scope=dishka.Scope.REQUEST)
79+
self.models = list(models) if models is not None else None
80+
self.provide(self.get_provider(), scope=dishka.Scope.REQUEST)
8381

84-
def get_model_provider(
85-
self, model: type[M]
86-
) -> Callable[..., "AsyncStorageProtocol[M]"]:
82+
def get_provider(self) -> Callable[..., "AsyncStorageProtocol[typing.Any]"]:
8783
raise NotImplementedError()
8884

8985

fief/storage/sqlalchemy.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,11 @@ def session(self, engine: Engine) -> Iterable[Session]:
7676
raise
7777
session.commit()
7878

79-
def get_model_provider(self, model: type[M]) -> Callable[..., StorageProtocol[M]]:
79+
def get_provider(self) -> Callable[..., StorageProtocol[typing.Any]]:
8080
def _provide(
81+
model: type[M],
8182
session: Session,
82-
) -> StorageProtocol[model]: # type: ignore[valid-type]
83+
) -> StorageProtocol[M]:
8384
return self.storage_class(model, session)
8485

8586
return _provide
@@ -147,12 +148,11 @@ async def session(self, engine: AsyncEngine) -> AsyncIterable[AsyncSession]:
147148
raise
148149
await session.commit()
149150

150-
def get_model_provider(
151-
self, model: type[M]
152-
) -> Callable[..., AsyncStorageProtocol[M]]:
151+
def get_provider(self) -> Callable[..., AsyncStorageProtocol[typing.Any]]:
153152
def _provide(
153+
model: type[M],
154154
session: AsyncSession,
155-
) -> AsyncStorageProtocol[model]: # type: ignore[valid-type]
155+
) -> AsyncStorageProtocol[M]:
156156
return self.storage_class(model, session)
157157

158158
return _provide

tests/fixtures.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ async def delete(self, id: typing.Any) -> M | None:
8181
class MockProvider(StorageProvider):
8282
storage_class = MockStorage
8383

84-
def get_model_provider(self, model: type[M]) -> Callable[..., "StorageProtocol[M]"]:
85-
def _provide() -> StorageProtocol[model]: # type: ignore[valid-type]
84+
def get_provider(self) -> Callable[..., StorageProtocol[typing.Any]]:
85+
def _provide(model: type[M]) -> StorageProtocol[M]:
8686
return self.storage_class(model)
8787

8888
return _provide
@@ -91,10 +91,8 @@ def _provide() -> StorageProtocol[model]: # type: ignore[valid-type]
9191
class MockAsyncProvider(AsyncStorageProvider):
9292
storage_class = MockAsyncStorage
9393

94-
def get_model_provider(
95-
self, model: type[M]
96-
) -> Callable[..., "AsyncStorageProtocol[M]"]:
97-
def _provide() -> AsyncStorageProtocol[model]: # type: ignore[valid-type]
94+
def get_provider(self) -> Callable[..., AsyncStorageProtocol[typing.Any]]:
95+
def _provide(model: type[M]) -> AsyncStorageProtocol[M]:
9896
return self.storage_class(model)
9997

10098
return _provide

tests/test_core.py

Lines changed: 86 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,87 @@ class ModelB(Base):
5353
sa.Integer, primary_key=True, autoincrement=True
5454
)
5555

56+
fief = Fief(
57+
storage=SQLAlchemyProvider("sqlite:///db.db"),
58+
methods=(),
59+
user_model=UserModel,
60+
)
61+
62+
with fief as fief_request:
63+
storage_a = fief_request.get_storage(ModelA)
64+
assert isinstance(storage_a, SQLAlchemyStorage)
65+
assert storage_a.model is ModelA
66+
storage_a_bind = storage_a.session.bind
67+
assert isinstance(storage_a_bind, sa.Engine)
68+
assert str(storage_a_bind.url) == "sqlite:///db.db"
69+
70+
storage_b = fief_request.get_storage(ModelB)
71+
assert isinstance(storage_b, SQLAlchemyStorage)
72+
assert storage_b.model is ModelB
73+
storage_b_bind = storage_b.session.bind
74+
assert isinstance(storage_b_bind, sa.Engine)
75+
assert str(storage_b_bind.url) == "sqlite:///db.db"
76+
77+
fief.close()
78+
79+
80+
@pytest.mark.parametrize("anyio_backend", ["asyncio"])
81+
async def test_async_storage_provider(anyio_backend: str) -> None:
82+
class Base(DeclarativeBase):
83+
pass
84+
85+
class ModelA(Base):
86+
__tablename__ = "model_a"
87+
id: Mapped[int] = mapped_column(
88+
sa.Integer, primary_key=True, autoincrement=True
89+
)
90+
91+
class ModelB(Base):
92+
__tablename__ = "model_b"
93+
id: Mapped[int] = mapped_column(
94+
sa.Integer, primary_key=True, autoincrement=True
95+
)
96+
97+
fief = FiefAsync(
98+
storage=SQLAlchemyAsyncProvider("sqlite+aiosqlite:///db.db"),
99+
methods=(),
100+
user_model=UserModel,
101+
)
102+
103+
async with fief as fief_request:
104+
storage_a = await fief_request.get_storage(ModelA)
105+
assert isinstance(storage_a, SQLAlchemyAsyncStorage)
106+
assert storage_a.model is ModelA
107+
storage_a_bind = storage_a.session.bind
108+
assert isinstance(storage_a_bind, sa_asyncio.AsyncEngine)
109+
assert str(storage_a_bind.url) == "sqlite+aiosqlite:///db.db"
110+
111+
storage_b = await fief_request.get_storage(ModelB)
112+
assert isinstance(storage_b, SQLAlchemyAsyncStorage)
113+
assert storage_b.model is ModelB
114+
storage_b_bind = storage_b.session.bind
115+
assert isinstance(storage_b_bind, sa_asyncio.AsyncEngine)
116+
assert str(storage_b_bind.url) == "sqlite+aiosqlite:///db.db"
117+
118+
await fief.close()
119+
120+
121+
def test_multiple_storage_providers() -> None:
122+
class Base(DeclarativeBase):
123+
pass
124+
125+
class ModelA(Base):
126+
__tablename__ = "model_a"
127+
id: Mapped[int] = mapped_column(
128+
sa.Integer, primary_key=True, autoincrement=True
129+
)
130+
131+
class ModelB(Base):
132+
__tablename__ = "model_b"
133+
id: Mapped[int] = mapped_column(
134+
sa.Integer, primary_key=True, autoincrement=True
135+
)
136+
56137
fief = Fief(
57138
storage=(
58139
SQLAlchemyProvider("sqlite:///db_a.db", models=[ModelA]),
@@ -86,7 +167,7 @@ class ModelB(Base):
86167

87168

88169
@pytest.mark.parametrize("anyio_backend", ["asyncio"])
89-
async def test_async_storage_provider(anyio_backend: str) -> None:
170+
async def test_async_multiple_storage_providers(anyio_backend: str) -> None:
90171
class Base(DeclarativeBase):
91172
pass
92173

@@ -136,7 +217,7 @@ class ModelB(Base):
136217

137218
def test_method_provider() -> None:
138219
fief = Fief(
139-
storage=MockProvider(models=[UserModel, MethodModel]),
220+
storage=MockProvider(),
140221
methods=(
141222
PasswordMethodProvider(MethodModel),
142223
PasswordMethodProvider(MethodModel, name="password2"),
@@ -156,7 +237,7 @@ def test_method_provider() -> None:
156237
@pytest.mark.anyio
157238
async def test_async_method_provider() -> None:
158239
fief = FiefAsync(
159-
storage=MockAsyncProvider(models=[UserModel, MethodModel]),
240+
storage=MockAsyncProvider(),
160241
methods=(
161242
PasswordAsyncMethodProvider(MethodModel),
162243
PasswordAsyncMethodProvider(MethodModel, name="password2"),
@@ -186,10 +267,7 @@ class PasswordMethodModel:
186267
@pytest.fixture
187268
def fief() -> Fief[UserModel]:
188269
return Fief(
189-
storage=(
190-
MockProvider(models=[UserModel]),
191-
MockProvider(models=[PasswordMethodModel]),
192-
),
270+
storage=MockProvider(),
193271
methods=(PasswordMethodProvider(PasswordMethodModel),),
194272
user_model=UserModel,
195273
)
@@ -204,10 +282,7 @@ def fief_request(fief: Fief[UserModel]) -> Generator[FiefRequest[UserModel]]:
204282
@pytest.fixture
205283
def fief_async() -> FiefAsync[UserModel]:
206284
return FiefAsync(
207-
storage=(
208-
MockAsyncProvider(models=[UserModel]),
209-
MockAsyncProvider(models=[PasswordMethodModel]),
210-
),
285+
storage=MockAsyncProvider(),
211286
methods=(PasswordAsyncMethodProvider(PasswordMethodModel),),
212287
user_model=UserModel,
213288
)

0 commit comments

Comments
 (0)