diff --git a/.flake8 b/.flake8 index 0f6f5a8..6703612 100644 --- a/.flake8 +++ b/.flake8 @@ -1,6 +1,6 @@ [flake8] exclude = __pycache__,built,build,venv -ignore = E203, E266, W503 +ignore = E203, E266, W503, E704, E701 max-line-length = 88 max-complexity = 18 select = B,C,E,F,W,T4,B9 diff --git a/rodi/__about__.py b/rodi/__about__.py index ff6ef86..1e0555b 100644 --- a/rodi/__about__.py +++ b/rodi/__about__.py @@ -1 +1 @@ -__version__ = "2.0.6" +__version__ = "2.0.7.dev1" diff --git a/rodi/__init__.py b/rodi/__init__.py index 7f27091..ee30708 100644 --- a/rodi/__init__.py +++ b/rodi/__init__.py @@ -209,7 +209,9 @@ def __init__(self, _type): super().__init__( f"The factory specified for type {class_name(_type)} is not " f"valid, it must be a function with either these signatures: " - f"def example_factory(context, type): " + f"def example_factory(context, activating_type, registered_type): " + f"or," + f"def example_factory(context, activating_type): " f"or," f"def example_factory(context): " f"or," @@ -342,9 +344,9 @@ def __init__(self, _type, factory): self._type = _type self.factory = factory - def __call__(self, context: ActivationScope, parent_type): + def __call__(self, context: ActivationScope, parent_type: type): assert isinstance(context, ActivationScope) - return self.factory(context, parent_type) + return self.factory(context, parent_type, self._type) class SingletonFactoryTypeProvider: @@ -355,9 +357,9 @@ def __init__(self, _type, factory): self.factory = factory self.instance = None - def __call__(self, context: ActivationScope, parent_type): + def __call__(self, context: ActivationScope, parent_type: Type): if self.instance is None: - self.instance = self.factory(context, parent_type) + self.instance = self.factory(context, parent_type, self._type) return self.instance @@ -368,11 +370,11 @@ def __init__(self, _type, factory): self._type = _type self.factory = factory - def __call__(self, context: ActivationScope, parent_type): + def __call__(self, context: ActivationScope, parent_type: Type): if self._type in context.scoped_services: return context.scoped_services[self._type] - instance = self.factory(context, parent_type) + instance = self.factory(context, parent_type, self._type) context.scoped_services[self._type] = instance return instance @@ -418,7 +420,7 @@ def get_annotations_type_provider( life_style: ServiceLifeStyle, resolver_context: ResolutionContext, ): - def factory(context, parent_type): + def factory(context, parent_type, registered_type): instance = concrete_type() for name, resolver in resolvers.items(): setattr(instance, name, resolver(context, parent_type)) @@ -509,7 +511,7 @@ def _get_resolvers_for_parameters( # but at least Optional could be supported in the future raise UnsupportedUnionTypeException(param_name, concrete_type) - if param_type is _empty: + if param_type is _empty or param_type not in services._map: if services.strict: raise CannotResolveParameterException(param_name, concrete_type) @@ -521,6 +523,14 @@ def _get_resolvers_for_parameters( else: aliases = services._aliases[param_name] + if not aliases: + cls_name = class_name(param_type) + aliases = ( + services._aliases[cls_name] + or services._aliases[cls_name.lower()] + or services._aliases[to_standard_param_name(cls_name)] + ) + if aliases: assert ( len(aliases) == 1 @@ -736,6 +746,13 @@ def get( scope = ActivationScope(self) resolver = self._map.get(desired_type) + if not resolver: + cls_name = class_name(desired_type) + resolver = ( + self._map.get(cls_name) + or self._map.get(cls_name.lower()) + or self._map.get(to_standard_param_name(cls_name)) + ) scoped_service = scope.scoped_services.get(desired_type) if scope else None if not resolver and not scoped_service: @@ -810,10 +827,12 @@ def exec( FactoryCallableNoArguments = Callable[[], Any] FactoryCallableSingleArgument = Callable[[ActivationScope], Any] FactoryCallableTwoArguments = Callable[[ActivationScope, Type], Any] +FactoryCallableThreeArguments = Callable[[ActivationScope, Type, Type], Any] FactoryCallableType = Union[ FactoryCallableNoArguments, FactoryCallableSingleArgument, FactoryCallableTwoArguments, + FactoryCallableThreeArguments, ] @@ -823,7 +842,7 @@ class FactoryWrapperNoArgs: def __init__(self, factory): self.factory = factory - def __call__(self, context, activating_type): + def __call__(self, context, activating_type, registered_type): return self.factory() @@ -833,10 +852,20 @@ class FactoryWrapperContextArg: def __init__(self, factory): self.factory = factory - def __call__(self, context, activating_type): + def __call__(self, context, activating_type, registered_type): return self.factory(context) +class FactoryWrapperPartentArg: + __slots__ = ("factory",) + + def __init__(self, factory): + self.factory = factory + + def __call__(self, context, activating_type, registered_type): + return self.factory(context, activating_type) + + class Container(ContainerProtocol): """ Configuration class for a collection of services. @@ -1123,6 +1152,9 @@ def _check_factory(factory, signature, handled_type) -> Callable: return FactoryWrapperContextArg(factory) if params_len == 2: + return FactoryWrapperPartentArg(factory) + + if params_len == 3: return factory raise InvalidFactory(handled_type) diff --git a/tests/test_fn_exec.py b/tests/test_fn_exec.py index 541e800..91ff62e 100644 --- a/tests/test_fn_exec.py +++ b/tests/test_fn_exec.py @@ -2,6 +2,7 @@ Functions exec tests. exec functions are designed to enable executing any function injecting parameters. """ + import pytest from rodi import Container, inject diff --git a/tests/test_services.py b/tests/test_services.py index 1015049..bdd3fe3 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -687,6 +687,36 @@ def __init__(self, cats_controller, service_settings): assert isinstance(u.cats_controller.cat_request_handler, GetCatRequestHandler) +def test_alias_dep_resolving(): + container = arrange_cats_example() + + class BaseClass: + pass + + class DerivedClass(BaseClass): + pass + + class UsingAliasByType: + def __init__(self, example: BaseClass): + self.example = example + + def resolve_derived_class(_) -> DerivedClass: + return DerivedClass() + + container.add_scoped_by_factory(resolve_derived_class, DerivedClass) + container.add_alias("BaseClass", DerivedClass) + container.add_scoped(UsingAliasByType) + + provider = container.build_provider() + u = provider.get(UsingAliasByType) + + assert isinstance(u, UsingAliasByType) + assert isinstance(u.example, DerivedClass) + + b = provider.get(BaseClass) + assert isinstance(b, DerivedClass) + + def test_get_service_by_name_or_alias(): container = arrange_cats_example() container.add_alias("k", CatsController) @@ -798,19 +828,21 @@ def test_invalid_factory_too_many_arguments_throws(method_name): container = Container() method = getattr(container, method_name) - def factory(context, activating_type, extra_argument_mistake): + def factory(context, activating_type, requested_type, extra_argument_mistake): return Cat("Celine") with raises(InvalidFactory): method(factory, Cat) - def factory(context, activating_type, extra_argument_mistake, two): + def factory(context, activating_type, requested_type, extra_argument_mistake, two): return Cat("Celine") with raises(InvalidFactory): method(factory, Cat) - def factory(context, activating_type, extra_argument_mistake, two, three): + def factory( + context, activating_type, requested_type, extra_argument_mistake, two, three + ): return Cat("Celine") with raises(InvalidFactory): @@ -993,6 +1025,15 @@ def cat_factory_with_context_and_activating_type(context, activating_type) -> Ca return Cat("Celine") +def cat_factory_with_context_activating_type_and_requested_type( + context, activating_type, requested_type +) -> Cat: + assert isinstance(context, ActivationScope) + assert activating_type is Cat + assert requested_type is Cat + return Cat("Celine") + + @pytest.mark.parametrize( "method_name,factory", [ @@ -1006,6 +1047,7 @@ def cat_factory_with_context_and_activating_type(context, activating_type) -> Ca cat_factory_no_args, cat_factory_with_context, cat_factory_with_context_and_activating_type, + cat_factory_with_context_activating_type_and_requested_type, ] ], ) @@ -1200,6 +1242,53 @@ def factory(_, activating_type) -> Logger: ) +@pytest.mark.parametrize( + "method_name", ["add_transient_by_factory", "add_scoped_by_factory"] +) +def test_factory_can_receive_requested_type_as_parameter(method_name): + class Db: + def __init__(self, activating, requested): + self.activating = activating + self.requested = requested + + class Fetcher: + def __init__(self, db: Db): + self.db = db + + container = Container() + container._add_exact_transient(Foo) + + def factory(self, activating_type, requested_type) -> Db: + return Db( + activating_type.__module__ + "." + activating_type.__name__, + requested_type.__module__ + "." + requested_type.__name__, + ) + + method = getattr(container, method_name) + method(factory, Db) + + container._add_exact_transient(Fetcher) + + provider = container.build_provider() + + db = provider.get(Db) + + assert db is not None + assert db.activating is not None + assert db.activating == "tests.test_services.Db" + assert db.requested is not None + assert db.requested == "tests.test_services.Db" + + fetcher = provider.get(Fetcher) + + assert fetcher is not None + assert fetcher.db is not None + assert fetcher.db.activating is not None + assert fetcher.db.activating == "tests.test_services.Fetcher" + assert fetcher.db.requested is not None + assert fetcher.db.requested == "tests.test_services.Db" + + def test_service_provider_supports_set_by_class(): provider = Services() @@ -2323,7 +2412,7 @@ def factory() -> annotation: def test_factory_without_locals_raises(): def factory_without_context() -> None: - ... + pass with pytest.raises(FactoryMissingContextException): _get_factory_annotations_or_throw(factory_without_context) @@ -2332,7 +2421,7 @@ def factory_without_context() -> None: def test_factory_with_locals_get_annotations(): @inject() def factory_without_context() -> "Cat": - ... + pass annotations = _get_factory_annotations_or_throw(factory_without_context) @@ -2350,20 +2439,20 @@ def test_deps_github_scenario(): """ class HTTPClient: - ... + pass class CommentsService: - ... + pass class ChecksService: - ... + pass class CLAHandler: comments_service: CommentsService checks_service: ChecksService class GitHubSettings: - ... + pass class GitHubAuthHandler: settings: GitHubSettings @@ -2478,7 +2567,7 @@ class B: def test_provide_protocol_with_attribute_dependency() -> None: class P(Protocol): def foo(self) -> Any: - ... + pass class Dependency: pass @@ -2506,7 +2595,7 @@ def foo(self) -> Any: def test_provide_protocol_with_init_dependency() -> None: class P(Protocol): def foo(self) -> Any: - ... + pass class Dependency: pass @@ -2536,10 +2625,10 @@ def test_provide_protocol_generic() -> None: class P(Protocol[T]): def foo(self, t: T) -> T: - ... + pass class A: - ... + pass class Impl(P[A]): def foo(self, t: A) -> A: @@ -2562,10 +2651,10 @@ def test_provide_protocol_generic_with_inner_dependency() -> None: class P(Protocol[T]): def foo(self, t: T) -> T: - ... + pass class A: - ... + pass class Dependency: pass