Skip to content

Commit e7e1413

Browse files
committed
Added support config providers with strings
1 parent 94cfda1 commit e7e1413

File tree

3 files changed

+82
-11
lines changed

3 files changed

+82
-11
lines changed

ellar/di/service_config.py

Lines changed: 72 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,54 @@
3434

3535

3636
class ProviderConfig(t.Generic[T]):
37+
"""
38+
ProviderConfig is a class that configures a provider for a service.
39+
40+
Example:
41+
42+
>>> class SomeClass:
43+
... def __init__(self, a: InjectByTag('A'), b: AnotherType):
44+
... self.a = a
45+
... self.b = b
46+
47+
>>> provider_config = ProviderConfig(SomeClass)
48+
>>> provider_config.register(container)
49+
50+
Example with string:
51+
52+
>>> provider_config = ProviderConfig('path.to:SomeClass')
53+
>>> provider_config.register(container)
54+
55+
Example with value:
56+
57+
>>> provider_config = ProviderConfig(SomeClass, use_value=SomeValue)
58+
>>> provider_config.register(container)
59+
60+
Example with class:
61+
62+
>>> provider_config = ProviderConfig(SomeClass, use_class=AnotherType)
63+
>>> provider_config.register(container)
64+
65+
Example with use_class as string:
66+
67+
>>> provider_config = ProviderConfig(SomeClass, use_class='path.to:AnotherType')
68+
>>> provider_config.register(container)
69+
70+
Example with scope:
71+
72+
>>> provider_config = ProviderConfig(SomeClass, scope=request_scope)
73+
>>> provider_config.register(container)
74+
75+
Example with tag:
76+
77+
>>> provider_config = ProviderConfig(SomeClass, tag='some_tag')
78+
>>> provider_config.register(container)
79+
>>> instance = container.get(InjectByTag('some_tag'))
80+
>>> assert isinstance(instance, SomeClass)
81+
82+
83+
"""
84+
3785
__slots__ = (
3886
"base_type",
3987
"use_value",
@@ -46,7 +94,7 @@ class ProviderConfig(t.Generic[T]):
4694

4795
def __init__(
4896
self,
49-
base_type: t.Union[t.Type[T], t.Type],
97+
base_type: t.Union[t.Type[T], t.Type, str],
5098
*,
5199
use_value: t.Optional[T] = None,
52100
use_class: t.Union[t.Type[T], t.Any] = None,
@@ -69,33 +117,46 @@ def __init__(
69117
self.core = core
70118

71119
def get_type(self) -> t.Type:
72-
return self.base_type
120+
return self._resolve_type(self.base_type)
121+
122+
def get_use_class(self) -> t.Type:
123+
return self._resolve_type(self.use_class)
124+
125+
def _resolve_type(self, type_or_str: t.Union[t.Type, str]) -> t.Type:
126+
from ellar.utils.importer import import_from_string
127+
128+
if isinstance(type_or_str, str):
129+
return t.cast(t.Type, import_from_string(type_or_str))
130+
return type_or_str
73131

74132
def register(self, container: "Container") -> None:
75-
scope = get_scope(self.base_type) or self.scope
76-
if self.use_class:
77-
scope = get_scope(self.use_class) or scope
133+
base_type = self.get_type()
134+
scope = get_scope(base_type) or self.scope
135+
use_class = self.get_use_class()
136+
137+
if use_class:
138+
scope = get_scope(use_class) or scope
78139
container.register(
79-
base_type=self.base_type,
80-
concrete_type=self.use_class,
140+
base_type=base_type,
141+
concrete_type=use_class,
81142
scope=scope,
82143
tag=self.tag,
83144
)
84145
elif self.use_value:
85146
container.register(
86-
base_type=self.base_type,
147+
base_type=base_type,
87148
concrete_type=self.use_value,
88149
scope=scope,
89150
tag=self.tag,
90151
)
91-
elif not isinstance(self.base_type, type):
152+
elif not isinstance(base_type, type):
92153
raise DIImproperConfiguration(
93-
f"couldn't determine provider setup for {self.base_type}. "
154+
f"couldn't determine provider setup for {base_type}. "
94155
f"Please use `ProviderConfig` or `register_services` function in a "
95156
f"Module to configure the provider"
96157
)
97158
else:
98-
container.register(base_type=self.base_type, scope=scope, tag=self.tag)
159+
container.register(base_type=base_type, scope=scope, tag=self.tag)
99160

100161

101162
@t.overload

tests/test_di/test_provider_scopes.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from ellar.di import EllarInjector, ProviderConfig, has_binding
55
from ellar.di.exceptions import DIImproperConfiguration
66
from ellar.di.scopes import RequestScope, SingletonScope, TransientScope
7+
from ellar.utils.importer import get_class_import
78
from injector import UnsatisfiedRequirement, inject
89

910
from .examples import AnyContext, Foo, IContext, TransientRequestContext
@@ -12,6 +13,7 @@
1213
@pytest.mark.parametrize(
1314
"provider, ref_type, expected_scope",
1415
[
16+
(ProviderConfig(get_class_import(Foo)), Foo, SingletonScope),
1517
(ProviderConfig(Foo, use_value=Foo()), Foo, SingletonScope),
1618
(ProviderConfig(Foo), Foo, SingletonScope),
1719
(ProviderConfig(Foo, scope=TransientScope), Foo, TransientScope),

tests/test_di/test_providers.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
)
99
from ellar.di.providers import ClassProvider, ModuleProvider
1010
from ellar.di.scopes import SingletonScope, TransientScope
11+
from ellar.utils.importer import get_class_import
1112
from injector import (
1213
CircularDependency,
1314
UnsatisfiedRequirement,
@@ -20,10 +21,12 @@
2021
Foo1,
2122
Foo2,
2223
FooDBCatsRepository,
24+
IContext,
2325
IDBContext,
2426
InjectType,
2527
InjectType2,
2628
IRepository,
29+
TransientRequestContext,
2730
)
2831

2932

@@ -83,6 +86,9 @@ def test_provider_advance_use_case():
8386
ProviderConfig(
8487
IRepository, use_class=FooDBCatsRepository
8588
), # register base type against a concrete_type
89+
ProviderConfig(
90+
IContext, use_class=get_class_import(TransientRequestContext)
91+
), # register base type against a concrete_type
8692
ProviderConfig(
8793
IDBContext, use_class=AnyDBContext
8894
), # register base type against a concrete_type
@@ -93,8 +99,10 @@ def test_provider_advance_use_case():
9399
provider.register(injector.container)
94100

95101
repository = injector.get(IRepository)
102+
ictx = injector.get(IContext)
96103
db_context = injector.get(IDBContext)
97104
assert isinstance(repository, FooDBCatsRepository)
105+
assert isinstance(ictx, TransientRequestContext)
98106
assert isinstance(db_context, AnyDBContext)
99107
assert repository.context == db_context # service registered as singleton
100108

0 commit comments

Comments
 (0)