Skip to content

Commit 1032e80

Browse files
committed
Get scopes working for list binder
1 parent bdbf66c commit 1032e80

File tree

1 file changed

+44
-30
lines changed

1 file changed

+44
-30
lines changed

injector/__init__.py

Lines changed: 44 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -22,24 +22,24 @@
2222
import threading
2323
import types
2424
from abc import ABCMeta, abstractmethod
25-
from collections import namedtuple
25+
from dataclasses import dataclass
2626
from typing import (
27+
TYPE_CHECKING,
2728
Any,
2829
Callable,
29-
cast,
3030
Dict,
3131
Generic,
32-
get_args,
3332
Iterable,
3433
List,
3534
Optional,
36-
overload,
3735
Set,
3836
Tuple,
3937
Type,
4038
TypeVar,
41-
TYPE_CHECKING,
4239
Union,
40+
cast,
41+
get_args,
42+
overload,
4343
)
4444

4545
try:
@@ -52,13 +52,13 @@
5252
# canonical. Since this typing_extensions import is only for mypy it'll work even without
5353
# typing_extensions actually installed so all's good.
5454
if TYPE_CHECKING:
55-
from typing_extensions import _AnnotatedAlias, Annotated, get_type_hints
55+
from typing_extensions import Annotated, _AnnotatedAlias, get_type_hints
5656
else:
5757
# Ignoring errors here as typing_extensions stub doesn't know about those things yet
5858
try:
59-
from typing import _AnnotatedAlias, Annotated, get_type_hints
59+
from typing import Annotated, _AnnotatedAlias, get_type_hints
6060
except ImportError:
61-
from typing_extensions import _AnnotatedAlias, Annotated, get_type_hints
61+
from typing_extensions import Annotated, _AnnotatedAlias, get_type_hints
6262

6363

6464
__author__ = 'Alec Thomas <[email protected]>'
@@ -343,16 +343,19 @@ def __repr__(self) -> str:
343343
class ListOfProviders(Provider, Generic[T]):
344344
"""Provide a list of instances via other Providers."""
345345

346-
_providers: List[Provider[T]]
346+
_multi_bindings: List['Binding']
347347

348-
def __init__(self) -> None:
349-
self._providers = []
348+
def __init__(self, parent: 'Binder') -> None:
349+
self._multi_bindings = []
350+
self._binder = Binder(parent.injector, auto_bind=False, parent=parent)
350351

351-
def append(self, provider: Provider[T]) -> None:
352-
self._providers.append(provider)
352+
def append(self, provider: Provider[T], scope: Type['Scope']) -> None:
353+
# HACK: generate a pseudo-type for this element in the list
354+
pseudo_type = type(f"pseudo-type-{id(provider)}", (provider.__class__,), {})
355+
self._multi_bindings.append(Binding(pseudo_type, provider, scope))
353356

354357
def __repr__(self) -> str:
355-
return '%s(%r)' % (type(self).__name__, self._providers)
358+
return '%s(%r)' % (type(self).__name__, self._multi_bindings)
356359

357360

358361
class MultiBindProvider(ListOfProviders[List[T]]):
@@ -361,8 +364,11 @@ class MultiBindProvider(ListOfProviders[List[T]]):
361364

362365
def get(self, injector: 'Injector') -> List[T]:
363366
result: List[T] = []
364-
for provider in self._providers:
365-
instances: List[T] = _ensure_iterable(provider.get(injector))
367+
for binding in self._multi_bindings:
368+
scope_binding, _ = self._binder.get_binding(binding.scope)
369+
scope_instance: Scope = scope_binding.provider.get(injector)
370+
provider_instance = scope_instance.get(binding.interface, binding.provider)
371+
instances: List[T] = _ensure_iterable(provider_instance.get(injector))
366372
result.extend(instances)
367373
return result
368374

@@ -372,8 +378,9 @@ class MapBindProvider(ListOfProviders[Dict[str, T]]):
372378

373379
def get(self, injector: 'Injector') -> Dict[str, T]:
374380
map: Dict[str, T] = {}
375-
for provider in self._providers:
376-
map.update(provider.get(injector))
381+
for binding in self._multi_bindings:
382+
# TODO: support scope
383+
map.update(binding.provider.get(injector))
377384
return map
378385

379386

@@ -387,7 +394,11 @@ def get(self, injector: 'Injector') -> Dict[str, T]:
387394
return {self._key: self._provider.get(injector)}
388395

389396

390-
_BindingBase = namedtuple('_BindingBase', 'interface provider scope')
397+
@dataclass
398+
class _BindingBase:
399+
interface: type
400+
provider: Provider
401+
scope: Type['Scope']
391402

392403

393404
@private
@@ -539,15 +550,15 @@ def multibind(
539550
and issubclass(interface, dict)
540551
or _get_origin(_punch_through_alias(interface)) is dict
541552
):
542-
provider = MapBindProvider()
553+
provider = MapBindProvider(self)
543554
else:
544-
provider = MultiBindProvider()
545-
binding = self.create_binding(interface, provider, scope)
555+
provider = MultiBindProvider(self)
556+
binding = self.create_binding(interface, provider)
546557
self._bindings[interface] = binding
547558
else:
548559
binding = self._bindings[interface]
560+
assert isinstance(binding.provider, ListOfProviders)
549561
provider = binding.provider
550-
assert isinstance(provider, ListOfProviders)
551562

552563
if isinstance(provider, MultiBindProvider) and isinstance(to, list):
553564
try:
@@ -557,7 +568,8 @@ def multibind(
557568
f"Use typing.List[T] or list[T] to specify the element type of the list"
558569
)
559570
for element in to:
560-
provider.append(self.provider_for(element_type, element))
571+
element_binding = self.create_binding(element_type, element, scope)
572+
provider.append(element_binding.provider, element_binding.scope)
561573
elif isinstance(provider, MapBindProvider) and isinstance(to, dict):
562574
try:
563575
value_type = get_args(_punch_through_alias(interface))[1]
@@ -566,9 +578,11 @@ def multibind(
566578
f"Use typing.Dict[K, V] or dict[K, V] to specify the value type of the dict"
567579
)
568580
for key, value in to.items():
569-
provider.append(KeyValueProvider(key, self.provider_for(value_type, value)))
581+
element_binding = self.create_binding(value_type, value, scope)
582+
provider.append(KeyValueProvider(key, element_binding.provider), element_binding.scope)
570583
else:
571-
provider.append(self.provider_for(interface, to))
584+
element_binding = self.create_binding(interface, to, scope)
585+
provider.append(element_binding.provider, element_binding.scope)
572586

573587
def install(self, module: _InstallableModuleType) -> None:
574588
"""Install a module into this binder.
@@ -611,10 +625,10 @@ def create_binding(
611625
self, interface: type, to: Any = None, scope: Union['ScopeDecorator', Type['Scope'], None] = None
612626
) -> Binding:
613627
provider = self.provider_for(interface, to)
614-
scope = scope or getattr(to or interface, '__scope__', NoScope)
628+
scope = scope or getattr(to or interface, '__scope__', None)
615629
if isinstance(scope, ScopeDecorator):
616630
scope = scope.scope
617-
return Binding(interface, provider, scope)
631+
return Binding(interface, provider, scope or NoScope)
618632

619633
def provider_for(self, interface: Any, to: Any = None) -> Provider:
620634
base_type = _punch_through_alias(interface)
@@ -696,7 +710,7 @@ def get_binding(self, interface: type) -> Tuple[Binding, 'Binder']:
696710
# The special interface is added here so that requesting a special
697711
# interface with auto_bind disabled works
698712
if self._auto_bind or self._is_special_interface(interface):
699-
binding = ImplicitBinding(*self.create_binding(interface))
713+
binding = ImplicitBinding(**self.create_binding(interface).__dict__)
700714
self._bindings[interface] = binding
701715
return binding, self
702716

@@ -817,7 +831,7 @@ def __repr__(self) -> str:
817831
class NoScope(Scope):
818832
"""An unscoped provider."""
819833

820-
def get(self, unused_key: Type[T], provider: Provider[T]) -> Provider[T]:
834+
def get(self, key: Type[T], provider: Provider[T]) -> Provider[T]:
821835
return provider
822836

823837

0 commit comments

Comments
 (0)