Skip to content

Commit c680b38

Browse files
author
remimd
committed
feat: No longer instantiate a dependency if it has been explicitly passed
1 parent e6737cf commit c680b38

File tree

1 file changed

+32
-22
lines changed

1 file changed

+32
-22
lines changed

injection/_core/module.py

Lines changed: 32 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
AsyncIterator,
99
Awaitable,
1010
Callable,
11+
Container,
1112
Generator,
1213
Iterable,
1314
Iterator,
@@ -18,6 +19,7 @@
1819
from enum import StrEnum
1920
from functools import partial, partialmethod, singledispatchmethod, update_wrapper
2021
from inspect import (
22+
BoundArguments,
2123
Signature,
2224
isasyncgenfunction,
2325
isclass,
@@ -739,28 +741,32 @@ def mod(name: str | None = None, /) -> Module:
739741
class Dependencies:
740742
lazy_mapping: Lazy[Mapping[str, Injectable[Any]]]
741743

742-
def __iter__(self) -> Iterator[tuple[str, Any]]:
743-
for name, injectable in self.items():
744+
def iter(self, exclude: Container[str]) -> Iterator[tuple[str, Any]]:
745+
for name, injectable in self.items(exclude):
744746
with suppress(SkipInjectable):
745747
yield name, injectable.get_instance()
746748

747-
async def __aiter__(self) -> AsyncIterator[tuple[str, Any]]:
748-
for name, injectable in self.items():
749+
async def aiter(self, exclude: Container[str]) -> AsyncIterator[tuple[str, Any]]:
750+
for name, injectable in self.items(exclude):
749751
with suppress(SkipInjectable):
750752
yield name, await injectable.aget_instance()
751753

752754
@property
753755
def are_resolved(self) -> bool:
754756
return self.lazy_mapping.is_set
755757

756-
async def aget_arguments(self) -> dict[str, Any]:
757-
return {key: value async for key, value in self}
758+
async def aget_arguments(self, *, exclude: Container[str]) -> dict[str, Any]:
759+
return {key: value async for key, value in self.aiter(exclude)}
758760

759-
def get_arguments(self) -> dict[str, Any]:
760-
return dict(self)
761+
def get_arguments(self, *, exclude: Container[str]) -> dict[str, Any]:
762+
return dict(self.iter(exclude))
761763

762-
def items(self) -> Iterator[tuple[str, Injectable[Any]]]:
763-
return iter((~self.lazy_mapping).items())
764+
def items(self, exclude: Container[str]) -> Iterator[tuple[str, Injectable[Any]]]:
765+
return (
766+
(name, injectable)
767+
for name, injectable in (~self.lazy_mapping).items()
768+
if name not in exclude
769+
)
764770

765771
@classmethod
766772
def from_iterable(cls, iterable: Iterable[tuple[str, Injectable[Any]]]) -> Self:
@@ -863,16 +869,18 @@ async def abind(
863869
args: Iterable[Any] = (),
864870
kwargs: Mapping[str, Any] | None = None,
865871
) -> Arguments:
866-
additional_arguments = await self.__dependencies.aget_arguments()
867-
return self.__bind(args, kwargs, additional_arguments)
872+
bound = self.__bind(args, kwargs)
873+
dependencies = await self.__dependencies.aget_arguments(exclude=bound.arguments)
874+
return self.__build_arguments(bound, dependencies)
868875

869876
def bind(
870877
self,
871878
args: Iterable[Any] = (),
872879
kwargs: Mapping[str, Any] | None = None,
873880
) -> Arguments:
874-
additional_arguments = self.__dependencies.get_arguments()
875-
return self.__bind(args, kwargs, additional_arguments)
881+
bound = self.__bind(args, kwargs)
882+
dependencies = self.__dependencies.get_arguments(exclude=bound.arguments)
883+
return self.__build_arguments(bound, dependencies)
876884

877885
async def acall(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
878886
with self.__lock:
@@ -925,23 +933,25 @@ def __bind(
925933
self,
926934
args: Iterable[Any],
927935
kwargs: Mapping[str, Any] | None,
928-
additional_arguments: dict[str, Any] | None,
929-
) -> Arguments:
936+
) -> BoundArguments:
930937
if kwargs is None:
931938
kwargs = {}
932939

933-
if not additional_arguments:
934-
return Arguments(args, kwargs)
935-
936-
bound = self.signature.bind_partial(*args, **kwargs)
937-
bound.arguments = bound.arguments | additional_arguments | bound.arguments
938-
return Arguments(bound.args, bound.kwargs)
940+
return self.signature.bind_partial(*args, **kwargs)
939941

940942
def __run_tasks(self) -> None:
941943
while tasks := self.__tasks:
942944
task = tasks.popleft()
943945
task()
944946

947+
@staticmethod
948+
def __build_arguments(
949+
bound: BoundArguments,
950+
additional_arguments: dict[str, Any],
951+
) -> Arguments:
952+
bound.arguments = bound.arguments | additional_arguments
953+
return Arguments(bound.args, bound.kwargs)
954+
945955

946956
class InjectedFunction[**P, T](HiddenCaller[P, T], ABC):
947957
__slots__ = ("__dict__", "__injection_metadata__")

0 commit comments

Comments
 (0)