Skip to content

Commit 72c6e08

Browse files
committed
fixed service look up and fixed forwardRef duplicating
1 parent f5e0666 commit 72c6e08

File tree

2 files changed

+18
-9
lines changed

2 files changed

+18
-9
lines changed

ellar/common/decorators/modules.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def Module(
5353
*,
5454
name: t.Optional[str] = None,
5555
controllers: t.Sequence[t.Union[t.Type[ControllerBase], t.Type]] = (),
56-
routers: t.Sequence[t.Union[ModuleRouter, Mount, Host]] = (),
56+
routers: t.Sequence[t.Union[ModuleRouter, Mount, Host, t.Callable]] = (),
5757
providers: t.Sequence[t.Union[t.Type, "ProviderConfig"]] = (),
5858
exports: t.Sequence[t.Union[t.Type]] = (),
5959
template_folder: t.Optional[str] = "templates",

ellar/di/injector/tree_manager.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -126,15 +126,20 @@ def add_forward_ref(
126126
forward_ref: "ModuleForwardRef",
127127
) -> "ModuleTreeManager":
128128
module_node = self.get_module(module_type)
129-
module_node.dependencies.append(forward_ref)
129+
_forward_data = self._forward_refs.get(forward_ref)
130+
131+
if not _forward_data:
132+
# Get referenced module node
133+
_forward_module_node = self.get_module(forward_ref.module)
134+
_forward_data = TreeData(
135+
value=forward_ref,
136+
parent=_forward_module_node.parent,
137+
dependencies=_forward_module_node.dependencies,
138+
)
130139

131-
data = TreeData(
132-
value=forward_ref,
133-
parent=module_node.parent,
134-
dependencies=module_node.dependencies,
135-
)
140+
self._forward_refs[forward_ref] = _forward_data
136141

137-
self._forward_refs[forward_ref] = data
142+
module_node.dependencies.append(_forward_data.value)
138143

139144
return self
140145

@@ -245,14 +250,18 @@ def search_module_tree(
245250
:param find_predicate:
246251
:return: The node with the given ID, or None if not found.
247252
"""
253+
_stack_cycle: t.Tuple[t.Any] = () # type:ignore[assignment]
248254

249255
def dfs(current_node: TreeData) -> t.Optional[TreeData]:
256+
nonlocal _stack_cycle
257+
_stack_cycle += (current_node.value.module,) # type:ignore[assignment]
258+
250259
if find_predicate(current_node):
251260
return current_node
252261

253262
for child_id in current_node.dependencies:
254263
child_node = self.get_module(child_id)
255-
if child_node:
264+
if child_node and child_node.value.module not in _stack_cycle:
256265
res = dfs(child_node)
257266
if res:
258267
return res

0 commit comments

Comments
 (0)