Skip to content
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[flake8]
exclude = __pycache__,built,build,venv
ignore = E203, E266, W503
ignore = E203, E266, W503, E704, E701
max-line-length = 88
max-complexity = 18
select = B,C,E,F,W,T4,B9
2 changes: 1 addition & 1 deletion rodi/__about__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "2.0.6"
__version__ = "2.0.7.dev1"
54 changes: 43 additions & 11 deletions rodi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,9 @@ def __init__(self, _type):
super().__init__(
f"The factory specified for type {class_name(_type)} is not "
f"valid, it must be a function with either these signatures: "
f"def example_factory(context, type): "
f"def example_factory(context, activating_type, registered_type): "
f"or,"
f"def example_factory(context, activating_type): "
f"or,"
f"def example_factory(context): "
f"or,"
Expand Down Expand Up @@ -342,9 +344,9 @@ def __init__(self, _type, factory):
self._type = _type
self.factory = factory

def __call__(self, context: ActivationScope, parent_type):
def __call__(self, context: ActivationScope, parent_type: type):
assert isinstance(context, ActivationScope)
return self.factory(context, parent_type)
return self.factory(context, parent_type, self._type)


class SingletonFactoryTypeProvider:
Expand All @@ -355,9 +357,9 @@ def __init__(self, _type, factory):
self.factory = factory
self.instance = None

def __call__(self, context: ActivationScope, parent_type):
def __call__(self, context: ActivationScope, parent_type: Type):
if self.instance is None:
self.instance = self.factory(context, parent_type)
self.instance = self.factory(context, parent_type, self._type)
return self.instance


Expand All @@ -368,11 +370,11 @@ def __init__(self, _type, factory):
self._type = _type
self.factory = factory

def __call__(self, context: ActivationScope, parent_type):
def __call__(self, context: ActivationScope, parent_type: Type):
if self._type in context.scoped_services:
return context.scoped_services[self._type]

instance = self.factory(context, parent_type)
instance = self.factory(context, parent_type, self._type)
context.scoped_services[self._type] = instance
return instance

Expand Down Expand Up @@ -418,7 +420,7 @@ def get_annotations_type_provider(
life_style: ServiceLifeStyle,
resolver_context: ResolutionContext,
):
def factory(context, parent_type):
def factory(context, parent_type, registered_type):
instance = concrete_type()
for name, resolver in resolvers.items():
setattr(instance, name, resolver(context, parent_type))
Expand Down Expand Up @@ -509,7 +511,7 @@ def _get_resolvers_for_parameters(
# but at least Optional could be supported in the future
raise UnsupportedUnionTypeException(param_name, concrete_type)

if param_type is _empty:
if param_type is _empty or param_type not in services._map:
if services.strict:
raise CannotResolveParameterException(param_name, concrete_type)

Expand All @@ -521,6 +523,14 @@ def _get_resolvers_for_parameters(
else:
aliases = services._aliases[param_name]

if not aliases:
cls_name = class_name(param_type)
aliases = (
services._aliases[cls_name]
or services._aliases[cls_name.lower()]
or services._aliases[to_standard_param_name(cls_name)]
)

if aliases:
assert (
len(aliases) == 1
Expand Down Expand Up @@ -736,6 +746,13 @@ def get(
scope = ActivationScope(self)

resolver = self._map.get(desired_type)
if not resolver:
cls_name = class_name(desired_type)
resolver = (
self._map.get(cls_name)
or self._map.get(cls_name.lower())
or self._map.get(to_standard_param_name(cls_name))
)
scoped_service = scope.scoped_services.get(desired_type) if scope else None

if not resolver and not scoped_service:
Expand Down Expand Up @@ -810,10 +827,12 @@ def exec(
FactoryCallableNoArguments = Callable[[], Any]
FactoryCallableSingleArgument = Callable[[ActivationScope], Any]
FactoryCallableTwoArguments = Callable[[ActivationScope, Type], Any]
FactoryCallableThreeArguments = Callable[[ActivationScope, Type, Type], Any]
FactoryCallableType = Union[
FactoryCallableNoArguments,
FactoryCallableSingleArgument,
FactoryCallableTwoArguments,
FactoryCallableThreeArguments,
]


Expand All @@ -823,7 +842,7 @@ class FactoryWrapperNoArgs:
def __init__(self, factory):
self.factory = factory

def __call__(self, context, activating_type):
def __call__(self, context, activating_type, registered_type):
return self.factory()


Expand All @@ -833,10 +852,20 @@ class FactoryWrapperContextArg:
def __init__(self, factory):
self.factory = factory

def __call__(self, context, activating_type):
def __call__(self, context, activating_type, registered_type):
return self.factory(context)


class FactoryWrapperPartentArg:
__slots__ = ("factory",)

def __init__(self, factory):
self.factory = factory

def __call__(self, context, activating_type, registered_type):
return self.factory(context, activating_type)


class Container(ContainerProtocol):
"""
Configuration class for a collection of services.
Expand Down Expand Up @@ -1123,6 +1152,9 @@ def _check_factory(factory, signature, handled_type) -> Callable:
return FactoryWrapperContextArg(factory)

if params_len == 2:
return FactoryWrapperPartentArg(factory)

if params_len == 3:
return factory

raise InvalidFactory(handled_type)
Expand Down
1 change: 1 addition & 0 deletions tests/test_fn_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Functions exec tests.
exec functions are designed to enable executing any function injecting parameters.
"""

import pytest

from rodi import Container, inject
Expand Down
119 changes: 104 additions & 15 deletions tests/test_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,36 @@ def __init__(self, cats_controller, service_settings):
assert isinstance(u.cats_controller.cat_request_handler, GetCatRequestHandler)


def test_alias_dep_resolving():
container = arrange_cats_example()

class BaseClass:
pass

class DerivedClass(BaseClass):
pass

class UsingAliasByType:
def __init__(self, example: BaseClass):
self.example = example

def resolve_derived_class(_) -> DerivedClass:
return DerivedClass()

container.add_scoped_by_factory(resolve_derived_class, DerivedClass)
container.add_alias("BaseClass", DerivedClass)
container.add_scoped(UsingAliasByType)

provider = container.build_provider()
u = provider.get(UsingAliasByType)

assert isinstance(u, UsingAliasByType)
assert isinstance(u.example, DerivedClass)

b = provider.get(BaseClass)
assert isinstance(b, DerivedClass)


def test_get_service_by_name_or_alias():
container = arrange_cats_example()
container.add_alias("k", CatsController)
Expand Down Expand Up @@ -798,19 +828,21 @@ def test_invalid_factory_too_many_arguments_throws(method_name):
container = Container()
method = getattr(container, method_name)

def factory(context, activating_type, extra_argument_mistake):
def factory(context, activating_type, requested_type, extra_argument_mistake):
return Cat("Celine")

with raises(InvalidFactory):
method(factory, Cat)

def factory(context, activating_type, extra_argument_mistake, two):
def factory(context, activating_type, requested_type, extra_argument_mistake, two):
return Cat("Celine")

with raises(InvalidFactory):
method(factory, Cat)

def factory(context, activating_type, extra_argument_mistake, two, three):
def factory(
context, activating_type, requested_type, extra_argument_mistake, two, three
):
return Cat("Celine")

with raises(InvalidFactory):
Expand Down Expand Up @@ -993,6 +1025,15 @@ def cat_factory_with_context_and_activating_type(context, activating_type) -> Ca
return Cat("Celine")


def cat_factory_with_context_activating_type_and_requested_type(
context, activating_type, requested_type
) -> Cat:
assert isinstance(context, ActivationScope)
assert activating_type is Cat
assert requested_type is Cat
return Cat("Celine")


@pytest.mark.parametrize(
"method_name,factory",
[
Expand All @@ -1006,6 +1047,7 @@ def cat_factory_with_context_and_activating_type(context, activating_type) -> Ca
cat_factory_no_args,
cat_factory_with_context,
cat_factory_with_context_and_activating_type,
cat_factory_with_context_activating_type_and_requested_type,
]
],
)
Expand Down Expand Up @@ -1200,6 +1242,53 @@ def factory(_, activating_type) -> Logger:
)


@pytest.mark.parametrize(
"method_name", ["add_transient_by_factory", "add_scoped_by_factory"]
)
def test_factory_can_receive_requested_type_as_parameter(method_name):
class Db:
def __init__(self, activating, requested):
self.activating = activating
self.requested = requested

class Fetcher:
def __init__(self, db: Db):
self.db = db

container = Container()
container._add_exact_transient(Foo)

def factory(self, activating_type, requested_type) -> Db:
return Db(
activating_type.__module__ + "." + activating_type.__name__,
requested_type.__module__ + "." + requested_type.__name__,
)

method = getattr(container, method_name)
method(factory, Db)

container._add_exact_transient(Fetcher)

provider = container.build_provider()

db = provider.get(Db)

assert db is not None
assert db.activating is not None
assert db.activating == "tests.test_services.Db"
assert db.requested is not None
assert db.requested == "tests.test_services.Db"

fetcher = provider.get(Fetcher)

assert fetcher is not None
assert fetcher.db is not None
assert fetcher.db.activating is not None
assert fetcher.db.activating == "tests.test_services.Fetcher"
assert fetcher.db.requested is not None
assert fetcher.db.requested == "tests.test_services.Db"


def test_service_provider_supports_set_by_class():
provider = Services()

Expand Down Expand Up @@ -2323,7 +2412,7 @@ def factory() -> annotation:

def test_factory_without_locals_raises():
def factory_without_context() -> None:
...
pass

with pytest.raises(FactoryMissingContextException):
_get_factory_annotations_or_throw(factory_without_context)
Expand All @@ -2332,7 +2421,7 @@ def factory_without_context() -> None:
def test_factory_with_locals_get_annotations():
@inject()
def factory_without_context() -> "Cat":
...
pass

annotations = _get_factory_annotations_or_throw(factory_without_context)

Expand All @@ -2350,20 +2439,20 @@ def test_deps_github_scenario():
"""

class HTTPClient:
...
pass

class CommentsService:
...
pass

class ChecksService:
...
pass

class CLAHandler:
comments_service: CommentsService
checks_service: ChecksService

class GitHubSettings:
...
pass

class GitHubAuthHandler:
settings: GitHubSettings
Expand Down Expand Up @@ -2478,7 +2567,7 @@ class B:
def test_provide_protocol_with_attribute_dependency() -> None:
class P(Protocol):
def foo(self) -> Any:
...
pass

class Dependency:
pass
Expand Down Expand Up @@ -2506,7 +2595,7 @@ def foo(self) -> Any:
def test_provide_protocol_with_init_dependency() -> None:
class P(Protocol):
def foo(self) -> Any:
...
pass

class Dependency:
pass
Expand Down Expand Up @@ -2536,10 +2625,10 @@ def test_provide_protocol_generic() -> None:

class P(Protocol[T]):
def foo(self, t: T) -> T:
...
pass

class A:
...
pass

class Impl(P[A]):
def foo(self, t: A) -> A:
Expand All @@ -2562,10 +2651,10 @@ def test_provide_protocol_generic_with_inner_dependency() -> None:

class P(Protocol[T]):
def foo(self, t: T) -> T:
...
pass

class A:
...
pass

class Dependency:
pass
Expand Down