2222import threading
2323import types
2424from abc import ABCMeta , abstractmethod
25- from collections import namedtuple
25+ from dataclasses import dataclass
2626from typing import (
27+ TYPE_CHECKING ,
2728 Any ,
2829 Callable ,
29- cast ,
3030 Dict ,
3131 Generic ,
32- get_args ,
3332 Iterable ,
3433 List ,
3534 Optional ,
36- overload ,
3735 Set ,
3836 Tuple ,
3937 Type ,
4038 TypeVar ,
41- TYPE_CHECKING ,
4239 Union ,
40+ cast ,
41+ get_args ,
42+ overload ,
4343)
4444
4545try :
5252# canonical. Since this typing_extensions import is only for mypy it'll work even without
5353# typing_extensions actually installed so all's good.
5454if TYPE_CHECKING :
55- from typing_extensions import _AnnotatedAlias , Annotated , get_type_hints
55+ from typing_extensions import Annotated , _AnnotatedAlias , get_type_hints
5656else :
5757 # Ignoring errors here as typing_extensions stub doesn't know about those things yet
5858 try :
59- from typing import _AnnotatedAlias , Annotated , get_type_hints
59+ from typing import Annotated , _AnnotatedAlias , get_type_hints
6060 except ImportError :
61- from typing_extensions import _AnnotatedAlias , Annotated , get_type_hints
61+ from typing_extensions import Annotated , _AnnotatedAlias , get_type_hints
6262
6363
6464__author__ = 'Alec Thomas <[email protected] >' @@ -343,16 +343,19 @@ def __repr__(self) -> str:
343343class ListOfProviders (Provider , Generic [T ]):
344344 """Provide a list of instances via other Providers."""
345345
346- _providers : List [Provider [ T ] ]
346+ _multi_bindings : List ['Binding' ]
347347
348- def __init__ (self ) -> None :
349- self ._providers = []
348+ def __init__ (self , parent : 'Binder' ) -> None :
349+ self ._multi_bindings = []
350+ self ._binder = Binder (parent .injector , auto_bind = False , parent = parent )
350351
351- def append (self , provider : Provider [T ]) -> None :
352- self ._providers .append (provider )
352+ def append (self , provider : Provider [T ], scope : Type ['Scope' ]) -> None :
353+ # HACK: generate a pseudo-type for this element in the list
354+ pseudo_type = type (f"pseudo-type-{ id (provider )} " , (provider .__class__ ,), {})
355+ self ._multi_bindings .append (Binding (pseudo_type , provider , scope ))
353356
354357 def __repr__ (self ) -> str :
355- return '%s(%r)' % (type (self ).__name__ , self ._providers )
358+ return '%s(%r)' % (type (self ).__name__ , self ._multi_bindings )
356359
357360
358361class MultiBindProvider (ListOfProviders [List [T ]]):
@@ -361,8 +364,11 @@ class MultiBindProvider(ListOfProviders[List[T]]):
361364
362365 def get (self , injector : 'Injector' ) -> List [T ]:
363366 result : List [T ] = []
364- for provider in self ._providers :
365- instances : List [T ] = _ensure_iterable (provider .get (injector ))
367+ for binding in self ._multi_bindings :
368+ scope_binding , _ = self ._binder .get_binding (binding .scope )
369+ scope_instance : Scope = scope_binding .provider .get (injector )
370+ provider_instance = scope_instance .get (binding .interface , binding .provider )
371+ instances : List [T ] = _ensure_iterable (provider_instance .get (injector ))
366372 result .extend (instances )
367373 return result
368374
@@ -372,8 +378,9 @@ class MapBindProvider(ListOfProviders[Dict[str, T]]):
372378
373379 def get (self , injector : 'Injector' ) -> Dict [str , T ]:
374380 map : Dict [str , T ] = {}
375- for provider in self ._providers :
376- map .update (provider .get (injector ))
381+ for binding in self ._multi_bindings :
382+ # TODO: support scope
383+ map .update (binding .provider .get (injector ))
377384 return map
378385
379386
@@ -387,7 +394,11 @@ def get(self, injector: 'Injector') -> Dict[str, T]:
387394 return {self ._key : self ._provider .get (injector )}
388395
389396
390- _BindingBase = namedtuple ('_BindingBase' , 'interface provider scope' )
397+ @dataclass
398+ class _BindingBase :
399+ interface : type
400+ provider : Provider
401+ scope : Type ['Scope' ]
391402
392403
393404@private
@@ -539,15 +550,15 @@ def multibind(
539550 and issubclass (interface , dict )
540551 or _get_origin (_punch_through_alias (interface )) is dict
541552 ):
542- provider = MapBindProvider ()
553+ provider = MapBindProvider (self )
543554 else :
544- provider = MultiBindProvider ()
545- binding = self .create_binding (interface , provider , scope )
555+ provider = MultiBindProvider (self )
556+ binding = self .create_binding (interface , provider )
546557 self ._bindings [interface ] = binding
547558 else :
548559 binding = self ._bindings [interface ]
560+ assert isinstance (binding .provider , ListOfProviders )
549561 provider = binding .provider
550- assert isinstance (provider , ListOfProviders )
551562
552563 if isinstance (provider , MultiBindProvider ) and isinstance (to , list ):
553564 try :
@@ -557,7 +568,8 @@ def multibind(
557568 f"Use typing.List[T] or list[T] to specify the element type of the list"
558569 )
559570 for element in to :
560- provider .append (self .provider_for (element_type , element ))
571+ element_binding = self .create_binding (element_type , element , scope )
572+ provider .append (element_binding .provider , element_binding .scope )
561573 elif isinstance (provider , MapBindProvider ) and isinstance (to , dict ):
562574 try :
563575 value_type = get_args (_punch_through_alias (interface ))[1 ]
@@ -566,9 +578,11 @@ def multibind(
566578 f"Use typing.Dict[K, V] or dict[K, V] to specify the value type of the dict"
567579 )
568580 for key , value in to .items ():
569- provider .append (KeyValueProvider (key , self .provider_for (value_type , value )))
581+ element_binding = self .create_binding (value_type , value , scope )
582+ provider .append (KeyValueProvider (key , element_binding .provider ), element_binding .scope )
570583 else :
571- provider .append (self .provider_for (interface , to ))
584+ element_binding = self .create_binding (interface , to , scope )
585+ provider .append (element_binding .provider , element_binding .scope )
572586
573587 def install (self , module : _InstallableModuleType ) -> None :
574588 """Install a module into this binder.
@@ -611,10 +625,10 @@ def create_binding(
611625 self , interface : type , to : Any = None , scope : Union ['ScopeDecorator' , Type ['Scope' ], None ] = None
612626 ) -> Binding :
613627 provider = self .provider_for (interface , to )
614- scope = scope or getattr (to or interface , '__scope__' , NoScope )
628+ scope = scope or getattr (to or interface , '__scope__' , None )
615629 if isinstance (scope , ScopeDecorator ):
616630 scope = scope .scope
617- return Binding (interface , provider , scope )
631+ return Binding (interface , provider , scope or NoScope )
618632
619633 def provider_for (self , interface : Any , to : Any = None ) -> Provider :
620634 base_type = _punch_through_alias (interface )
@@ -696,7 +710,7 @@ def get_binding(self, interface: type) -> Tuple[Binding, 'Binder']:
696710 # The special interface is added here so that requesting a special
697711 # interface with auto_bind disabled works
698712 if self ._auto_bind or self ._is_special_interface (interface ):
699- binding = ImplicitBinding (* self .create_binding (interface ))
713+ binding = ImplicitBinding (** self .create_binding (interface ). __dict__ )
700714 self ._bindings [interface ] = binding
701715 return binding , self
702716
@@ -817,7 +831,7 @@ def __repr__(self) -> str:
817831class NoScope (Scope ):
818832 """An unscoped provider."""
819833
820- def get (self , unused_key : Type [T ], provider : Provider [T ]) -> Provider [T ]:
834+ def get (self , key : Type [T ], provider : Provider [T ]) -> Provider [T ]:
821835 return provider
822836
823837
0 commit comments