//! Shared helpers for deriving Python module names from filenames and module metadata.
22
33use std:: borrow:: Cow ;
4+ use std:: cmp:: Ordering ;
5+ use std:: env;
46use std:: path:: { Component , Path } ;
57use std:: sync:: Arc ;
68
@@ -27,8 +29,9 @@ impl ModuleIdentityResolver {
2729
2830 /// Construct a resolver from an explicit list of module roots. Visible for tests.
2931 pub fn from_roots ( roots : Vec < String > ) -> Self {
32+ let module_roots = canonicalise_module_roots ( roots) ;
3033 Self {
31- module_roots : Arc :: from ( roots ) ,
34+ module_roots,
3235 cache : DashMap :: new ( ) ,
3336 }
3437 }
@@ -38,9 +41,24 @@ impl ModuleIdentityResolver {
3841 if let Some ( entry) = self . cache . get ( absolute) {
3942 return entry. clone ( ) ;
4043 }
41- let resolved = module_name_from_roots ( self . module_roots ( ) , absolute)
42- . or_else ( || lookup_module_name ( py, absolute) ) ;
43- self . cache . insert ( absolute. to_string ( ) , resolved. clone ( ) ) ;
44+ let mut path_candidate = module_name_from_roots ( self . module_roots ( ) , absolute) ;
45+ if path_candidate
46+ . as_ref ( )
47+ . map ( |name| is_filesystem_shaped_name ( name, absolute) )
48+ . unwrap_or ( true )
49+ {
50+ if let Some ( heuristic) = module_name_from_heuristics ( absolute) {
51+ path_candidate = Some ( heuristic) ;
52+ }
53+ }
54+ let sys_candidate = lookup_module_name ( py, absolute) ;
55+ let ( resolved, cacheable) = match ( sys_candidate, path_candidate) {
56+ ( Some ( preferred) , _) => ( Some ( preferred) , true ) ,
57+ ( None , candidate) => ( candidate, false ) ,
58+ } ;
59+ if cacheable {
60+ self . cache . insert ( absolute. to_string ( ) , resolved. clone ( ) ) ;
61+ }
4462 resolved
4563 }
4664
@@ -84,6 +102,7 @@ impl ModuleIdentityCache {
84102 return entry. clone ( ) ;
85103 }
86104
105+ let globals_name = hints. globals_name . and_then ( sanitise_module_name) ;
87106 let mut resolved = hints
88107 . preferred
89108 . and_then ( sanitise_module_name)
@@ -97,7 +116,18 @@ impl ModuleIdentityCache {
97116 . absolute_path
98117 . and_then ( |absolute| self . resolver . resolve_absolute ( py, absolute) )
99118 } )
100- . or_else ( || hints. globals_name . and_then ( sanitise_module_name) ) ;
119+ . or_else ( || globals_name. clone ( ) ) ;
120+
121+ if let Some ( globals) = globals_name. as_ref ( ) {
122+ if globals == "__main__"
123+ && resolved
124+ . as_deref ( )
125+ . map ( |candidate| candidate != "__main__" )
126+ . unwrap_or ( true )
127+ {
128+ resolved = Some ( globals. clone ( ) ) ;
129+ }
130+ }
101131
102132 if resolved. is_none ( ) && hints. absolute_path . is_none ( ) {
103133 if let Ok ( filename) = code. filename ( py) {
@@ -153,11 +183,20 @@ impl<'a> ModuleNameHints<'a> {
153183
154184fn collect_module_roots ( py : Python < ' _ > ) -> Vec < String > {
155185 let mut roots = Vec :: new ( ) ;
186+ let cwd = env:: current_dir ( )
187+ . ok ( )
188+ . and_then ( |dir| normalise_to_posix ( dir. as_path ( ) ) ) ;
156189 if let Ok ( sys) = py. import ( "sys" ) {
157190 if let Ok ( path_obj) = sys. getattr ( "path" ) {
158191 if let Ok ( path_list) = path_obj. downcast_into :: < PyList > ( ) {
159192 for entry in path_list. iter ( ) {
160193 if let Ok ( raw) = entry. extract :: < String > ( ) {
194+ if raw. is_empty ( ) {
195+ if let Some ( dir) = cwd. as_ref ( ) {
196+ roots. push ( dir. clone ( ) ) ;
197+ }
198+ continue ;
199+ }
161200 if let Some ( normalized) = normalise_to_posix ( Path :: new ( & raw ) ) {
162201 roots. push ( normalized) ;
163202 }
@@ -169,6 +208,38 @@ fn collect_module_roots(py: Python<'_>) -> Vec<String> {
169208 roots
170209}
171210
211+ fn canonicalise_module_roots ( roots : Vec < String > ) -> Arc < [ String ] > {
212+ let mut canonical: Vec < String > = roots. into_iter ( ) . map ( canonicalise_root) . collect ( ) ;
213+ canonical. sort_by ( |a, b| compare_roots ( a, b) ) ;
214+ canonical. dedup ( ) ;
215+ Arc :: from ( canonical)
216+ }
217+
218+ fn canonicalise_root ( mut root : String ) -> String {
219+ if root. is_empty ( ) {
220+ if let Ok ( cwd) = std:: env:: current_dir ( ) {
221+ if let Some ( normalized) = normalise_to_posix ( cwd. as_path ( ) ) {
222+ return normalized;
223+ }
224+ }
225+ return "/" . to_string ( ) ;
226+ }
227+ while root. len ( ) > 1 && root. ends_with ( '/' ) {
228+ root. pop ( ) ;
229+ }
230+ root
231+ }
232+
/// Ordering used for module roots: longer roots sort first, and roots of the
/// same length fall back to lexicographic order.
///
/// Returns `Ordering::Equal` only when `a` and `b` are identical strings.
fn compare_roots(a: &str, b: &str) -> Ordering {
    // Reversed length comparison: the longer (more specific) root comes first.
    b.len().cmp(&a.len()).then_with(|| a.cmp(b))
}
242+
172243pub ( crate ) fn module_name_from_roots ( roots : & [ String ] , absolute : & str ) -> Option < String > {
173244 for base in roots {
174245 if let Some ( relative) = strip_posix_prefix ( absolute, base) {
@@ -180,6 +251,14 @@ pub(crate) fn module_name_from_roots(roots: &[String], absolute: &str) -> Option
180251 None
181252}
182253
254+ fn module_name_from_heuristics ( absolute : & str ) -> Option < String > {
255+ let roots = heuristic_roots_for_absolute ( absolute) ;
256+ if roots. is_empty ( ) {
257+ return None ;
258+ }
259+ module_name_from_roots ( & roots, absolute)
260+ }
261+
183262fn lookup_module_name ( py : Python < ' _ > , absolute : & str ) -> Option < String > {
184263 let sys = py. import ( "sys" ) . ok ( ) ?;
185264 let modules_obj = sys. getattr ( "modules" ) . ok ( ) ?;
@@ -364,6 +443,69 @@ fn strip_posix_prefix<'a>(path: &'a str, base: &str) -> Option<&'a str> {
364443 }
365444}
366445
446+ fn module_from_absolute ( absolute : & str ) -> Option < String > {
447+ let without_root = absolute. trim_start_matches ( '/' ) ;
448+ let trimmed = trim_drive_prefix ( without_root) ;
449+ if trimmed. is_empty ( ) {
450+ return None ;
451+ }
452+ module_from_relative ( trimmed)
453+ }
454+
/// Strip a leading drive component such as `C:/` from a POSIX-style path.
///
/// Only a first component made of exactly one ASCII letter followed by `:` is
/// treated as a drive. Any other first component is left in place, including
/// a directory whose name merely ends with `:`. A path without a `/`
/// separator is returned unchanged.
fn trim_drive_prefix(path: &str) -> &str {
    match path.split_once('/') {
        Some((prefix, remainder))
            if matches!(prefix.as_bytes(), [letter, b':'] if letter.is_ascii_alphabetic()) =>
        {
            remainder
        }
        _ => path,
    }
}
463+
464+ fn is_filesystem_shaped_name ( candidate : & str , absolute : & str ) -> bool {
465+ module_from_absolute ( absolute)
466+ . as_deref ( )
467+ . map ( |path_like| path_like == candidate)
468+ . unwrap_or ( false )
469+ }
470+
471+ fn heuristic_roots_for_absolute ( absolute : & str ) -> Vec < String > {
472+ if let Some ( project_root) = find_nearest_project_root ( absolute) {
473+ vec ! [ project_root]
474+ } else if let Some ( parent) = Path :: new ( absolute)
475+ . parent ( )
476+ . and_then ( |dir| normalise_to_posix ( dir) )
477+ {
478+ vec ! [ parent]
479+ } else {
480+ Vec :: new ( )
481+ }
482+ }
483+
484+ fn find_nearest_project_root ( absolute : & str ) -> Option < String > {
485+ let mut current = Path :: new ( absolute) . parent ( ) ;
486+ while let Some ( dir) = current {
487+ if has_project_marker ( dir) {
488+ return normalise_to_posix ( dir) ;
489+ }
490+ current = dir. parent ( ) ;
491+ }
492+ None
493+ }
494+
/// Report whether `dir` directly contains a well-known project marker.
///
/// Markers are VCS metadata entries (`.git`, `.hg`, `.svn`) and Python
/// packaging files (`pyproject.toml`, `setup.cfg`, `setup.py`). Every marker
/// is tested with `Path::exists`, so the entry may be a file or a directory.
fn has_project_marker(dir: &Path) -> bool {
    // VCS markers are listed first, matching the original probe order.
    const PROJECT_MARKERS: &[&str] = &[
        ".git",
        ".hg",
        ".svn",
        "pyproject.toml",
        "setup.cfg",
        "setup.py",
    ];

    PROJECT_MARKERS
        .iter()
        .any(|marker| dir.join(marker).exists())
}
508+
367509/// Normalise a filesystem path to a POSIX-style string used by trace filters.
368510pub fn normalise_to_posix ( path : & Path ) -> Option < String > {
369511 if path. as_os_str ( ) . is_empty ( ) {
@@ -486,4 +628,138 @@ mod tests {
486628 Some ( "pkg.module.sub" )
487629 ) ;
488630 }
631+
/// With both `/` and the project directory configured as roots, the name must
/// be derived from the more specific project root (`pkg.mod`), not from the
/// catch-all filesystem root.
#[test]
fn module_name_from_roots_prefers_specific_root_over_catch_all() {
    // The interpreter handle itself is not needed by this test.
    Python::with_gil(|_py| {
        let tmp = tempfile::tempdir().expect("tempdir");
        let module_dir = tmp.path().join("pkg");
        std::fs::create_dir_all(&module_dir).expect("mkdir");
        let module_path = module_dir.join("mod.py");
        std::fs::write(&module_path, "def foo():\n return 1\n ").expect("write");

        let project_root = normalise_to_posix(tmp.path()).expect("normalize root");
        // `/` is listed first on purpose: canonicalisation must reorder the
        // roots so the longer project root is tried before it.
        let resolver = ModuleIdentityResolver::from_roots(vec!["/".to_string(), project_root]);
        let absolute_norm =
            normalise_to_posix(module_path.as_path()).expect("normalize absolute");
        let derived = module_name_from_roots(resolver.module_roots(), absolute_norm.as_str());

        assert_eq!(derived.as_deref(), Some("pkg.mod"));
    });
}
654+
/// A module that is loaded under the name `pkg.mod` must resolve to that name
/// even when the only configured root is the catch-all `/`.
#[test]
fn resolve_absolute_prefers_sys_modules_name_over_path_fallback() {
    Python::with_gil(|py| {
        let tmp = tempfile::tempdir().expect("tempdir");
        let module_dir = tmp.path().join("pkg");
        std::fs::create_dir_all(&module_dir).expect("mkdir");
        let module_path = module_dir.join("mod.py");
        std::fs::write(&module_path, "def foo():\n return 42\n ").expect("write");

        let module_path_str = module_path.to_string_lossy().to_string();
        let module = load_module(
            py,
            "pkg.mod",
            module_path_str.as_str(),
            "def foo():\n return 42\n ",
        )
        .expect("load module");
        let _code = get_code(&module, "foo");

        let resolver = ModuleIdentityResolver::from_roots(vec!["/".to_string()]);
        let absolute_norm =
            normalise_to_posix(module_path.as_path()).expect("normalize absolute");
        let resolved = resolver.resolve_absolute(py, absolute_norm.as_str());

        // Clean up sys.modules BEFORE asserting, so a failing assertion cannot
        // leave "pkg.mod" behind and contaminate other tests.
        if let Ok(sys) = py.import("sys") {
            if let Ok(modules) = sys.getattr("modules") {
                let _ = modules.del_item("pkg.mod");
            }
        }

        assert_eq!(resolved.as_deref(), Some("pkg.mod"));
    });
}
688+
/// When only the catch-all `/` root is configured, a file below a directory
/// containing `.git` must be named relative to that project directory
/// (`tests.test_mod`).
#[test]
fn resolve_absolute_uses_project_marker_root() {
    Python::with_gil(|py| {
        let tmp = tempfile::tempdir().expect("tempdir");
        let project_dir = tmp.path().join("project");
        let tests_dir = project_dir.join("tests");
        std::fs::create_dir_all(&tests_dir).expect("mkdir tests");
        // The `.git` directory is the project marker the heuristic looks for.
        std::fs::create_dir(project_dir.join(".git")).expect("mkdir git");
        let module_path = tests_dir.join("test_mod.py");
        std::fs::write(&module_path, "def sample():\n return 7\n ").expect("write");

        let resolver = ModuleIdentityResolver::from_roots(vec!["/".to_string()]);
        let absolute_norm =
            normalise_to_posix(module_path.as_path()).expect("normalize absolute");
        let resolved = resolver.resolve_absolute(py, absolute_norm.as_str());

        assert_eq!(resolved.as_deref(), Some("tests.test_mod"));
    });
}
708+
/// A script registered in `sys.modules` as `__main__` must resolve to
/// `__main__`, not to a name derived from its path.
#[test]
fn resolve_absolute_returns_main_for_runpy_module() {
    Python::with_gil(|py| {
        let tmp = tempfile::tempdir().expect("tempdir");
        let script_path = tmp.path().join("cli.py");
        // NOTE(review): the indentation inside this literal looks collapsed;
        // the test only writes the file and does not appear to execute it.
        std::fs::write(
            &script_path,
            "def entrypoint():\n return 0\n \n if __name__ == '__main__':\n entrypoint()\n ",
        )
        .expect("write script");

        let script_norm =
            normalise_to_posix(script_path.as_path()).expect("normalize absolute path");

        let sys = py.import("sys").expect("import sys");
        let modules = sys.getattr("modules").expect("sys.modules");
        // Remember the real `__main__` so it can be put back afterwards.
        let original_main = modules
            .get_item("__main__")
            .ok()
            .map(|obj| obj.clone().unbind());

        // Install a stand-in `__main__` whose `__file__` points at the script.
        let module = PyModule::new(py, "__main__").expect("create module");
        module
            .setattr("__file__", script_path.to_string_lossy().as_ref())
            .expect("set __file__");
        module
            .setattr("__name__", "__main__")
            .expect("set __name__");
        module
            .setattr("__package__", py.None())
            .expect("set __package__");
        module
            .setattr("__spec__", py.None())
            .expect("set __spec__");
        modules
            .set_item("__main__", module)
            .expect("register __main__");

        let resolver = ModuleIdentityResolver::from_roots(vec!["/".to_string()]);
        let resolved = resolver.resolve_absolute(py, script_norm.as_str());

        // Restore sys.modules BEFORE asserting, so a failing assertion cannot
        // leave the stand-in `__main__` installed for other tests.
        match original_main {
            Some(previous) => {
                modules
                    .set_item("__main__", previous)
                    .expect("restore __main__");
            }
            None => {
                modules
                    .del_item("__main__")
                    .expect("remove temporary __main__");
            }
        }

        assert_eq!(resolved.as_deref(), Some("__main__"));
    });
}
489765}
0 commit comments