diff --git a/rodi/__init__.py b/rodi/__init__.py index 7f27091..3a8c103 100644 --- a/rodi/__init__.py +++ b/rodi/__init__.py @@ -84,6 +84,10 @@ def _get_obj_locals(obj) -> Optional[Dict[str, Any]]: return getattr(obj, "_locals", None) +def _get_obj_globals(obj) -> Dict[str, Any]: + return getattr(obj, "_globals", {}) + + def class_name(input_type): if input_type in {list, set} and str( # noqa: E721 type(input_type) == "" @@ -544,9 +548,11 @@ def _resolve_by_init_method(self, context: ResolutionContext): if sys.version_info >= (3, 10): # pragma: no cover # Python 3.10 + globalns = vars(sys.modules[self.concrete_type.__module__]) + globalns.update(_get_obj_globals(self.concrete_type)) annotations = get_type_hints( self.concrete_type.__init__, - vars(sys.modules[self.concrete_type.__module__]), + globalns, _get_obj_locals(self.concrete_type), ) for key, value in params.items(): @@ -623,9 +629,11 @@ def __call__(self, context: ResolutionContext): chain.append(concrete_type) if self._has_default_init(): + globalns = vars(sys.modules[concrete_type.__module__]) + globalns.update(_get_obj_locals(concrete_type)) annotations = get_type_hints( concrete_type, - vars(sys.modules[concrete_type.__module__]), + globalns, _get_obj_locals(concrete_type), )