Skip to content

Commit b16073a

Browse files
committed
🐛 fix: not resolving aliases
1 parent 1a13acc commit b16073a

File tree

2 files changed

+60
-13
lines changed

2 files changed

+60
-13
lines changed

rodi/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -511,7 +511,7 @@ def _get_resolvers_for_parameters(
511511
# but at least Optional could be supported in the future
512512
raise UnsupportedUnionTypeException(param_name, concrete_type)
513513

514-
if param_type is _empty:
514+
if param_type is _empty or param_type not in services._map:
515515
if services.strict:
516516
raise CannotResolveParameterException(param_name, concrete_type)
517517

@@ -523,6 +523,14 @@ def _get_resolvers_for_parameters(
523523
else:
524524
aliases = services._aliases[param_name]
525525

526+
if not aliases:
527+
cls_name = class_name(param_type)
528+
aliases = (
529+
services._aliases[cls_name]
530+
or services._aliases[cls_name.lower()]
531+
or services._aliases[to_standard_param_name(cls_name)]
532+
)
533+
526534
if aliases:
527535
assert (
528536
len(aliases) == 1

tests/test_services.py

Lines changed: 51 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -687,6 +687,33 @@ def __init__(self, cats_controller, service_settings):
687687
assert isinstance(u.cats_controller.cat_request_handler, GetCatRequestHandler)
688688

689689

690+
def test_alias_dep_resolving():
691+
container = arrange_cats_example()
692+
693+
class BaseClass:
694+
pass
695+
696+
class DerivedClass(BaseClass):
697+
pass
698+
699+
class UsingAliasByType:
700+
def __init__(self, example: BaseClass):
701+
self.example = example
702+
703+
def resolve_derived_class(_) -> DerivedClass:
704+
return DerivedClass()
705+
706+
container.add_scoped_by_factory(resolve_derived_class, DerivedClass)
707+
container.add_alias("BaseClass", DerivedClass)
708+
container.add_scoped(UsingAliasByType)
709+
710+
provider = container.build_provider()
711+
u = provider.get(UsingAliasByType)
712+
713+
assert isinstance(u, UsingAliasByType)
714+
assert isinstance(u.example, DerivedClass)
715+
716+
690717
def test_get_service_by_name_or_alias():
691718
container = arrange_cats_example()
692719
container.add_alias("k", CatsController)
@@ -2381,15 +2408,17 @@ def factory() -> annotation:
23812408

23822409

23832410
def test_factory_without_locals_raises():
2384-
def factory_without_context() -> None: ...
2411+
def factory_without_context() -> None:
2412+
pass
23852413

23862414
with pytest.raises(FactoryMissingContextException):
23872415
_get_factory_annotations_or_throw(factory_without_context)
23882416

23892417

23902418
def test_factory_with_locals_get_annotations():
23912419
@inject()
2392-
def factory_without_context() -> "Cat": ...
2420+
def factory_without_context() -> "Cat":
2421+
pass
23932422

23942423
annotations = _get_factory_annotations_or_throw(factory_without_context)
23952424

@@ -2406,17 +2435,21 @@ def test_deps_github_scenario():
24062435
└── HTTPClient
24072436
"""
24082437

2409-
class HTTPClient: ...
2438+
class HTTPClient:
2439+
pass
24102440

2411-
class CommentsService: ...
2441+
class CommentsService:
2442+
pass
24122443

2413-
class ChecksService: ...
2444+
class ChecksService:
2445+
pass
24142446

24152447
class CLAHandler:
24162448
comments_service: CommentsService
24172449
checks_service: ChecksService
24182450

2419-
class GitHubSettings: ...
2451+
class GitHubSettings:
2452+
pass
24202453

24212454
class GitHubAuthHandler:
24222455
settings: GitHubSettings
@@ -2530,7 +2563,8 @@ class B:
25302563

25312564
def test_provide_protocol_with_attribute_dependency() -> None:
25322565
class P(Protocol):
2533-
def foo(self) -> Any: ...
2566+
def foo(self) -> Any:
2567+
pass
25342568

25352569
class Dependency:
25362570
pass
@@ -2557,7 +2591,8 @@ def foo(self) -> Any:
25572591

25582592
def test_provide_protocol_with_init_dependency() -> None:
25592593
class P(Protocol):
2560-
def foo(self) -> Any: ...
2594+
def foo(self) -> Any:
2595+
pass
25612596

25622597
class Dependency:
25632598
pass
@@ -2586,9 +2621,11 @@ def test_provide_protocol_generic() -> None:
25862621
T = TypeVar("T")
25872622

25882623
class P(Protocol[T]):
2589-
def foo(self, t: T) -> T: ...
2624+
def foo(self, t: T) -> T:
2625+
pass
25902626

2591-
class A: ...
2627+
class A:
2628+
pass
25922629

25932630
class Impl(P[A]):
25942631
def foo(self, t: A) -> A:
@@ -2610,9 +2647,11 @@ def test_provide_protocol_generic_with_inner_dependency() -> None:
26102647
T = TypeVar("T")
26112648

26122649
class P(Protocol[T]):
2613-
def foo(self, t: T) -> T: ...
2650+
def foo(self, t: T) -> T:
2651+
pass
26142652

2615-
class A: ...
2653+
class A:
2654+
pass
26162655

26172656
class Dependency:
26182657
pass

0 commit comments

Comments
 (0)