Skip to content

Commit 964c1fd

Browse files
committed
feat: implement recursive forward ref resolution in test module dependencies
- Make _resolve_forward_refs recursive to traverse entire module dependency tree - Use resolved.extend() to capture nested dependencies from recursive calls - Add ModuleSetup instance handling in forward ref resolution - Update tests to reflect recursive resolution behavior This ensures ForwardRefModule instances are resolved at all dependency levels, not just at the top level, providing complete module dependency resolution for testing scenarios.
1 parent f451bb1 commit 964c1fd

File tree

3 files changed

+160
-24
lines changed

3 files changed

+160
-24
lines changed

ellar/testing/dependency_analyzer.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
if t.TYPE_CHECKING: # pragma: no cover
1414
from ellar.common import ControllerBase
15-
from ellar.core import ForwardRefModule, ModuleBase
15+
from ellar.core import ForwardRefModule, ModuleBase, ModuleSetup
1616
from ellar.di import ModuleTreeManager
1717

1818

@@ -60,6 +60,13 @@ def __init__(self, application_module: t.Union[t.Type["ModuleBase"], str]):
6060

6161
self._module_tree = self._build_module_tree()
6262

63+
def get_application_module_providers(self) -> t.List[t.Type]:
64+
"""Get all provider types from the ApplicationModule tree"""
65+
mod_data = self._module_tree.get_app_module()
66+
if mod_data:
67+
return list(mod_data.providers.values())
68+
return []
69+
6370
def _build_module_tree(self) -> "ModuleTreeManager":
6471
"""Build complete module tree for ApplicationModule"""
6572
from ellar.app import AppFactory
@@ -164,7 +171,7 @@ def collect_dependencies(mod: t.Type["ModuleBase"]) -> None:
164171

165172
def resolve_forward_ref(
166173
self, forward_ref: "ForwardRefModule"
167-
) -> t.Optional[t.Type["ModuleBase"]]:
174+
) -> t.Optional["ModuleSetup"]:
168175
"""
169176
Resolve a ForwardRefModule to its actual module from ApplicationModule tree
170177
@@ -181,7 +188,7 @@ def resolve_forward_ref(
181188
filter_item=lambda data: True,
182189
find_predicate=lambda data: data.name == forward_ref.module_name,
183190
)
184-
return t.cast(t.Type["ModuleBase"], result.value.module) if result else None
191+
return t.cast("ModuleSetup", result.value) if result else None
185192

186193
elif hasattr(forward_ref, "module") and forward_ref.module:
187194
# Module can be a Type or a string import path
@@ -197,12 +204,8 @@ def resolve_forward_ref(
197204

198205
# Search for this module type in the tree
199206
module_data = self._module_tree.get_module(module_cls)
200-
return (
201-
t.cast(t.Type["ModuleBase"], module_data.value.module)
202-
if module_data
203-
else None
204-
)
205-
207+
if module_data:
208+
return t.cast("ModuleSetup", module_data.value)
206209
return None
207210

208211

ellar/testing/module.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
constants,
1212
)
1313
from ellar.common.types import T
14-
from ellar.core import ModuleBase
14+
from ellar.core import ModuleBase, ModuleSetup
1515
from ellar.core.routing import EllarControllerMount
1616
from ellar.di import ProviderConfig
1717
from ellar.reflect import reflect
@@ -166,16 +166,12 @@ def create_test_module(
166166
app_analyzer = ApplicationModuleDependencyAnalyzer(application_module)
167167
controller_analyzer = ControllerDependencyAnalyzer()
168168

169-
# 1. Resolve ForwardRefs in registered modules
170-
resolved_modules = cls._resolve_forward_refs(modules_list, app_analyzer)
171-
modules_list = resolved_modules
172-
173169
# 2. Analyze controllers and find required modules (with recursive dependencies)
174170
required_modules = cls._analyze_and_resolve_controller_dependencies(
175171
controllers, controller_analyzer, app_analyzer
176172
)
177173

178-
# 3. Add required modules that aren't already registered
174+
# 2. Add required modules that aren't already registered
179175
# Use type comparison to avoid duplicates
180176
existing_module_types = {
181177
m if isinstance(m, type) else m.module if hasattr(m, "module") else m
@@ -186,6 +182,15 @@ def create_test_module(
186182
modules_list.append(required_module)
187183
existing_module_types.add(required_module)
188184

185+
# 4. Resolve ForwardRefs in registered modules
186+
resolved_modules = cls._resolve_forward_refs(modules_list, app_analyzer)
187+
modules_list.extend(resolved_modules)
188+
189+
providers = list(providers)
190+
# 5. Add application module providers, since this is the root module
191+
# and it will be used to resolve dependencies
192+
providers.extend(app_analyzer.get_application_module_providers())
193+
189194
# Create the module with complete dependency list
190195
module = Module(
191196
modules=modules_list,
@@ -229,20 +234,30 @@ def _resolve_forward_refs(
229234
modules: t.List[t.Any],
230235
app_analyzer: "ApplicationModuleDependencyAnalyzer",
231236
) -> t.List[t.Any]:
232-
"""Resolve ForwardRefModule instances from ApplicationModule"""
237+
"""Resolve ForwardRefModule instances from ApplicationModule recursively"""
233238
from ellar.core import ForwardRefModule
234239

235240
resolved = []
236241
for module in modules:
242+
# Resolve current module if it's a ForwardRefModule
237243
if isinstance(module, ForwardRefModule):
238244
actual_module = app_analyzer.resolve_forward_ref(module)
239-
if actual_module:
240-
resolved.append(actual_module)
241-
else:
242-
# Keep original if can't resolve (might be test-specific)
243-
resolved.append(module)
245+
current_module = actual_module.module
246+
resolved.append(actual_module)
247+
elif isinstance(module, ModuleSetup):
248+
current_module = module.module
244249
else:
245-
resolved.append(module)
250+
current_module = module
251+
252+
# Recursively resolve forward refs in module's dependencies
253+
registered_modules = (
254+
reflect.get_metadata(constants.MODULE_METADATA.MODULES, current_module)
255+
or []
256+
)
257+
if registered_modules:
258+
resolved.extend(
259+
cls._resolve_forward_refs(registered_modules, app_analyzer)
260+
)
246261

247262
return resolved
248263

tests/test_testing_dependency_resolution.py

Lines changed: 120 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -276,13 +276,39 @@ def test_application_module_analyzer_get_module_dependencies_none():
276276
assert len(dependencies) == 0
277277

278278

279+
def test_application_module_analyzer_get_application_module_providers():
280+
"""Test getting providers from ApplicationModule"""
281+
from ellar.di import ProviderConfig
282+
283+
@injectable
284+
class AppLevelService:
285+
pass
286+
287+
@Module(
288+
name="TestAppModuleWithProviders",
289+
modules=[AuthModule],
290+
providers=[ProviderConfig(AppLevelService, use_class=AppLevelService)],
291+
)
292+
class TestAppModuleWithProviders(ModuleBase):
293+
pass
294+
295+
analyzer = ApplicationModuleDependencyAnalyzer(TestAppModuleWithProviders)
296+
providers = analyzer.get_application_module_providers()
297+
298+
# Should include AppLevelService
299+
assert AppLevelService in providers or any(
300+
hasattr(p, "get_type") and p.get_type() == AppLevelService for p in providers
301+
)
302+
303+
279304
# ============================================================================
280305
# Unit Tests: ForwardRefModule Resolution
281306
# ============================================================================
282307

283308

284309
def test_forward_ref_resolution_by_type():
285310
"""Test resolving ForwardRefModule by type"""
311+
from ellar.core.modules import ModuleSetup
286312

287313
# Need to have DatabaseModule actually registered in the application tree
288314
@Module(
@@ -298,7 +324,9 @@ class ForwardRefTestModule(ModuleBase):
298324
forward_ref = ForwardRefModule(module=DatabaseModule)
299325
resolved = analyzer.resolve_forward_ref(forward_ref)
300326

301-
assert resolved == DatabaseModule
327+
# Should return a ModuleSetup instance
328+
assert isinstance(resolved, ModuleSetup)
329+
assert resolved.module == DatabaseModule
302330

303331

304332
def test_forward_ref_resolution_by_name():
@@ -318,6 +346,7 @@ class ForwardRefTestModule2(ModuleBase):
318346
forward_ref = ForwardRefModule(module_name="DatabaseModule")
319347
resolved = analyzer.resolve_forward_ref(forward_ref)
320348

349+
# When resolving by name, it returns the module type directly
321350
assert resolved == DatabaseModule
322351

323352

@@ -336,6 +365,52 @@ class ForwardRefTestModule3(ModuleBase):
336365
assert resolved is None
337366

338367

368+
def test_resolve_forward_refs_handles_module_setup():
369+
"""Test that _resolve_forward_refs properly handles ModuleSetup instances"""
370+
from ellar.testing.module import Test
371+
372+
@Module(name="TestModuleForSetup", modules=[AuthModule, DatabaseModule])
373+
class TestModuleForSetup(ModuleBase):
374+
pass
375+
376+
analyzer = ApplicationModuleDependencyAnalyzer(TestModuleForSetup)
377+
378+
# Pass ForwardRefModule instances that will be resolved
379+
forward_ref_auth = ForwardRefModule(module=AuthModule)
380+
forward_ref_db = ForwardRefModule(module=DatabaseModule)
381+
382+
modules = [forward_ref_auth, forward_ref_db]
383+
resolved = Test._resolve_forward_refs(modules, analyzer)
384+
385+
# Should resolve both ForwardRefModules (and potentially their dependencies)
386+
assert len(resolved) >= 2
387+
388+
389+
def test_resolve_forward_refs_recursive_extension():
390+
"""Test that _resolve_forward_refs recursively extends with nested modules"""
391+
from ellar.testing.module import Test
392+
393+
# DatabaseModule has LoggingModule as dependency
394+
@Module(
395+
name="TestModuleForRecursive",
396+
modules=[DatabaseModule, AuthModule],
397+
)
398+
class TestModuleForRecursive(ModuleBase):
399+
pass
400+
401+
analyzer = ApplicationModuleDependencyAnalyzer(TestModuleForRecursive)
402+
403+
# Start with ForwardRefModule to DatabaseModule (which has LoggingModule as dependency)
404+
forward_ref_db = ForwardRefModule(module=DatabaseModule)
405+
modules = [forward_ref_db]
406+
resolved = Test._resolve_forward_refs(modules, analyzer)
407+
408+
# Should return resolved DatabaseModule (and potentially nested dependencies)
409+
# The exact count depends on whether DatabaseModule's LoggingModule dependency
410+
# has any ForwardRefModules in its metadata
411+
assert len(resolved) >= 1
412+
413+
339414
# ============================================================================
340415
# Integration Tests: Test.create_test_module()
341416
# ============================================================================
@@ -469,7 +544,7 @@ class TestAppWithForwardRef(ModuleBase):
469544

470545
tm = Test.create_test_module(
471546
controllers=[UserController],
472-
modules=[ModuleWithForwardRef], # Contains ForwardRef to AuthModule
547+
# Don't manually add ForwardRef module - let auto-resolution handle it
473548
application_module=TestAppWithForwardRef,
474549
)
475550

@@ -622,6 +697,49 @@ def test_create_test_module_with_import_string_application_module(reflect_contex
622697
assert isinstance(controller.auth_service, IAuthService)
623698

624699

700+
def test_create_test_module_includes_application_module_providers(reflect_context):
701+
"""Test that test module includes providers from ApplicationModule"""
702+
703+
@injectable
704+
class AppLevelService:
705+
def get_value(self):
706+
return "app_level"
707+
708+
@Module(
709+
name="AppModuleWithProviders",
710+
modules=[AuthModule],
711+
providers=[ProviderConfig(AppLevelService, use_class=AppLevelService)],
712+
)
713+
class AppModuleWithProviders(ModuleBase):
714+
pass
715+
716+
@Controller()
717+
class ControllerUsingAppService:
718+
def __init__(self, app_service: AppLevelService):
719+
self.app_service = app_service
720+
721+
@get("/test")
722+
def test_endpoint(self):
723+
return {"value": self.app_service.get_value()}
724+
725+
tm = Test.create_test_module(
726+
controllers=[ControllerUsingAppService],
727+
application_module=AppModuleWithProviders,
728+
)
729+
730+
tm.create_application()
731+
732+
# Should be able to get the app-level service
733+
app_service = tm.get(AppLevelService)
734+
assert app_service is not None
735+
assert app_service.get_value() == "app_level"
736+
737+
# Controller should also work
738+
controller = tm.get(ControllerUsingAppService)
739+
assert controller is not None
740+
assert isinstance(controller.app_service, AppLevelService)
741+
742+
625743
def test_application_module_analyzer_with_import_string():
626744
"""Test that ApplicationModuleDependencyAnalyzer accepts import strings"""
627745
import_string = "tests.test_testing_dependency_resolution:ApplicationModule"

0 commit comments

Comments
 (0)