Skip to content
304 changes: 289 additions & 15 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ tox = "^4.14.2"
pytest = "^8.3.3"
pytest-asyncio = "^0.24.0"
eval-type-backport = "^0.2.0"
pydantic-settings = { version = "^2.12.0", python = ">3.10" }

[tool.ruff]
target-version = "py38"
Expand Down
42 changes: 41 additions & 1 deletion test/unit/test_inject_from_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
import wireup
from typing_extensions import Annotated
from wireup import Inject, inject_from_container
from wireup import Inject, inject_from_container, service
from wireup._annotations import Injected
from wireup.errors import WireupError

Expand Down Expand Up @@ -123,6 +123,46 @@ def _(
) -> None: ...


async def test_unknown_service_without_default_value() -> None:
class UnknownClass: ...

@service
class BarWithoutDefaultValue:
def __init__(self, unknown_class: UnknownClass) -> None:
self.unknown_class = unknown_class

with pytest.raises(
WireupError,
match=re.escape(
"Parameter 'unknown_class' of Type test.unit.test_inject_from_container.BarWithoutDefaultValue "
"depends on an unknown service Type test.unit.test_inject_from_container.UnknownClass"
" with qualifier None."
),
):
container = wireup.create_async_container(services=[BarWithoutDefaultValue])

@inject_from_container(container)
def _(
_: Annotated[BarWithoutDefaultValue, Inject()],
) -> None: ...


async def test_unknown_service_with_default_value() -> None:
class UnknownClass: ...

@service
class BarWithDefaultValue:
def __init__(self, unknown_class: Optional[UnknownClass] = None) -> None:
self.unknown_class = unknown_class

container = wireup.create_async_container(services=[BarWithDefaultValue])

@inject_from_container(container)
def _(
_: Annotated[BarWithDefaultValue, Inject()],
) -> None: ...


async def test_injects_service_with_provided_async_scoped_container() -> None:
container = wireup.create_async_container(service_modules=[services], parameters={"env_name": "test"})

Expand Down
31 changes: 29 additions & 2 deletions test/unit/test_service_registry.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
import sys
from typing import NewType

import pytest

if sys.version_info >= (3, 10):
from pydantic_settings import BaseSettings
from wireup import service
from wireup._annotations import AbstractDeclaration, ServiceDeclaration
from wireup.errors import (
DuplicateServiceRegistrationError,
FactoryReturnTypeIsEmptyError,
InvalidRegistrationTypeError,
WireupError,
)
from wireup.ioc.service_registry import ServiceRegistry

Expand Down Expand Up @@ -126,8 +132,25 @@ def test_register_invalid_target() -> None:
ServiceRegistry(impls=[ServiceDeclaration(obj=1)])


class MyService:
pass
@pytest.mark.skipif(sys.version_info < (3, 10), reason="Pydantic settings only works on Python >= 3.10")
def test_register_factory_with_unknown_dependency_with_default() -> None:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be good to have some tests that capture the same behavior for configuration injection via Inject(param=)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll add them!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are the tests satisfactory?

@service
class Settings(BaseSettings): ...

registry = ServiceRegistry(impls=[ServiceDeclaration(obj=Settings, lifetime="singleton")])
assert Settings in registry.impls


def test_register_factory_with_unknown_dependency_no_default() -> None:
class UnknownService: ...

class MyLocalService: ...

def factory_no_default(_: UnknownService) -> MyLocalService:
return MyLocalService()

with pytest.raises(WireupError, match="depends on an unknown service"):
ServiceRegistry(impls=[ServiceDeclaration(obj=factory_no_default, lifetime="singleton")])


class MyInterface:
Expand All @@ -136,3 +159,7 @@ class MyInterface:

def random_service_factory() -> RandomService:
return RandomService()


class MyService:
pass
2 changes: 1 addition & 1 deletion test/unit/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def inner(
)
self.assertEqual(
_param_get_annotation(params.parameters["_e"]),
AnnotatedParameter(str),
AnnotatedParameter(str, has_default_value=True),
)
self.assertIsNone(
_param_get_annotation(params.parameters["_f"]),
Expand Down
15 changes: 14 additions & 1 deletion wireup/ioc/service_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,8 +315,18 @@ def interface_resolve_impl(self, klass: type, qualifier: Qualifier | None) -> ty
def assert_dependencies_valid(self) -> None:
"""Assert that all required dependencies exist for this registry instance."""
for (impl, impl_qualifier), service_factory in self.factories.items():
unknown_dependencies_with_default: list[str] = []

for name, dependency in self.dependencies[service_factory.factory].items():
self.assert_dependency_exists(parameter=dependency, target=impl, name=name)
try:
self.assert_dependency_exists(parameter=dependency, target=impl, name=name)
except WireupError:
if dependency.has_default_value:
unknown_dependencies_with_default.append(name)
continue

raise

self._assert_lifetime_valid(
impl=impl,
impl_qualifier=impl_qualifier,
Expand All @@ -326,6 +336,9 @@ def assert_dependencies_valid(self) -> None:
)
self._assert_valid_resolution_path(dependency=dependency, path=[])

for name in unknown_dependencies_with_default:
del self.dependencies[service_factory.factory][name]

def _assert_lifetime_valid(
self,
*,
Expand Down
10 changes: 5 additions & 5 deletions wireup/ioc/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,10 @@ class EmptyContainerInjectionRequest(InjectableType):
class AnnotatedParameter:
"""Represent an annotated dependency parameter."""

__slots__ = ("annotation", "is_parameter", "klass", "obj_id", "qualifier_value")
__slots__ = ("annotation", "has_default_value", "is_parameter", "klass", "obj_id", "qualifier_value")

def __init__(
self,
klass: type[Any],
annotation: InjectableType | None = None,
self, klass: type[Any], annotation: InjectableType | None = None, *, has_default_value: bool = False
) -> None:
"""Create a new AnnotatedParameter.

Expand All @@ -85,6 +83,7 @@ def __init__(
self.annotation = annotation
self.qualifier_value = self.annotation.qualifier if isinstance(self.annotation, ServiceQualifier) else None
self.is_parameter = isinstance(self.annotation, ParameterWrapper)
self.has_default_value = has_default_value
self.obj_id = self.klass, self.qualifier_value

def __eq__(self, other: object) -> bool:
Expand All @@ -95,11 +94,12 @@ def __eq__(self, other: object) -> bool:
and self.annotation == other.annotation
and self.qualifier_value == other.qualifier_value
and self.is_parameter == other.is_parameter
and self.has_default_value == other.has_default_value
)

def __hash__(self) -> int:
"""Hash things."""
return hash((self.klass, self.annotation, self.qualifier_value, self.is_parameter))
return hash((self.klass, self.annotation, self.qualifier_value, self.is_parameter, self.has_default_value))


@dataclass(frozen=True, eq=True)
Expand Down
1 change: 1 addition & 0 deletions wireup/ioc/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def param_get_annotation(
return AnnotatedParameter(
klass=type_analysis.normalized_type,
annotation=_get_wireup_annotation(type_analysis.annotations),
has_default_value=parameter.default is not Parameter.empty,
)


Expand Down
Loading