3131from librt .internal import cache_version
3232
3333import mypy .semanal_main
34- from mypy .cache import CACHE_VERSION , Buffer , CacheMeta
34+ from mypy .cache import CACHE_VERSION , CacheMeta , ReadBuffer , WriteBuffer
3535from mypy .checker import TypeChecker
3636from mypy .error_formatter import OUTPUT_CHOICES , ErrorFormatter
3737from mypy .errors import CompileError , ErrorInfo , Errors , report_internal_error
@@ -603,6 +603,7 @@ def __init__(
603603 self .options = options
604604 self .version_id = version_id
605605 self .modules : dict [str , MypyFile ] = {}
606+ self .import_map : dict [str , set [str ]] = {}
606607 self .missing_modules : set [str ] = set ()
607608 self .fg_deps_meta : dict [str , FgDepMeta ] = {}
608609 # fg_deps holds the dependencies of every module that has been
@@ -623,6 +624,7 @@ def __init__(
623624 self .incomplete_namespaces ,
624625 self .errors ,
625626 self .plugin ,
627+ self .import_map ,
626628 )
627629 self .all_types : dict [Expression , Type ] = {} # Enabled by export_types
628630 self .indirection_detector = TypeIndirectionVisitor ()
@@ -742,6 +744,26 @@ def getmtime(self, path: str) -> int:
742744 else :
743745 return int (self .metastore .getmtime (path ))
744746
747+ def correct_rel_imp (self , file : MypyFile , imp : ImportFrom | ImportAll ) -> str :
748+ """Function to correct for relative imports."""
749+ file_id = file .fullname
750+ rel = imp .relative
751+ if rel == 0 :
752+ return imp .id
753+ if os .path .basename (file .path ).startswith ("__init__." ):
754+ rel -= 1
755+ if rel != 0 :
756+ file_id = "." .join (file_id .split ("." )[:- rel ])
757+ new_id = file_id + "." + imp .id if imp .id else file_id
758+
759+ if not new_id :
760+ self .errors .set_file (file .path , file .name , self .options )
761+ self .errors .report (
762+ imp .line , 0 , "No parent module -- cannot perform relative import" , blocker = True
763+ )
764+
765+ return new_id
766+
745767 def all_imported_modules_in_file (self , file : MypyFile ) -> list [tuple [int , str , int ]]:
746768 """Find all reachable import statements in a file.
747769
@@ -750,27 +772,6 @@ def all_imported_modules_in_file(self, file: MypyFile) -> list[tuple[int, str, i
750772
751773 Can generate blocking errors on bogus relative imports.
752774 """
753-
754- def correct_rel_imp (imp : ImportFrom | ImportAll ) -> str :
755- """Function to correct for relative imports."""
756- file_id = file .fullname
757- rel = imp .relative
758- if rel == 0 :
759- return imp .id
760- if os .path .basename (file .path ).startswith ("__init__." ):
761- rel -= 1
762- if rel != 0 :
763- file_id = "." .join (file_id .split ("." )[:- rel ])
764- new_id = file_id + "." + imp .id if imp .id else file_id
765-
766- if not new_id :
767- self .errors .set_file (file .path , file .name , self .options )
768- self .errors .report (
769- imp .line , 0 , "No parent module -- cannot perform relative import" , blocker = True
770- )
771-
772- return new_id
773-
774775 res : list [tuple [int , str , int ]] = []
775776 for imp in file .imports :
776777 if not imp .is_unreachable :
@@ -785,7 +786,7 @@ def correct_rel_imp(imp: ImportFrom | ImportAll) -> str:
785786 ancestors .append (part )
786787 res .append ((ancestor_pri , "." .join (ancestors ), imp .line ))
787788 elif isinstance (imp , ImportFrom ):
788- cur_id = correct_rel_imp (imp )
789+ cur_id = self . correct_rel_imp (file , imp )
789790 all_are_submodules = True
790791 # Also add any imported names that are submodules.
791792 pri = import_priority (imp , PRI_MED )
@@ -805,7 +806,7 @@ def correct_rel_imp(imp: ImportFrom | ImportAll) -> str:
805806 res .append ((pri , cur_id , imp .line ))
806807 elif isinstance (imp , ImportAll ):
807808 pri = import_priority (imp , PRI_HIGH )
808- res .append ((pri , correct_rel_imp (imp ), imp .line ))
809+ res .append ((pri , self . correct_rel_imp (file , imp ), imp .line ))
809810
810811 # Sort such that module (e.g. foo.bar.baz) comes before its ancestors (e.g. foo
811812 # and foo.bar) so that, if FindModuleCache finds the target module in a
@@ -1342,7 +1343,7 @@ def find_cache_meta(id: str, path: str, manager: BuildManager) -> CacheMeta | No
13421343 if meta [0 ] != cache_version () or meta [1 ] != CACHE_VERSION :
13431344 manager .log (f"Metadata abandoned for { id } : incompatible cache format" )
13441345 return None
1345- data_io = Buffer (meta [2 :])
1346+ data_io = ReadBuffer (meta [2 :])
13461347 m = CacheMeta .read (data_io , data_file )
13471348 else :
13481349 m = CacheMeta .deserialize (meta , data_file )
@@ -1593,7 +1594,7 @@ def write_cache(
15931594
15941595 # Serialize data and analyze interface
15951596 if manager .options .fixed_format_cache :
1596- data_io = Buffer ()
1597+ data_io = WriteBuffer ()
15971598 tree .write (data_io )
15981599 data_bytes = data_io .getvalue ()
15991600 else :
@@ -1677,7 +1678,7 @@ def write_cache_meta(meta: CacheMeta, manager: BuildManager, meta_file: str) ->
16771678 # Write meta cache file
16781679 metastore = manager .metastore
16791680 if manager .options .fixed_format_cache :
1680- data_io = Buffer ()
1681+ data_io = WriteBuffer ()
16811682 meta .write (data_io )
16821683 # Prefix with both low- and high-level cache format versions for future validation.
16831684 # TODO: switch to something like librt.internal.write_byte() if this is slow.
@@ -2110,7 +2111,7 @@ def load_tree(self, temporary: bool = False) -> None:
21102111 t0 = time .time ()
21112112 # TODO: Assert data file wasn't changed.
21122113 if isinstance (data , bytes ):
2113- data_io = Buffer (data )
2114+ data_io = ReadBuffer (data )
21142115 self .tree = MypyFile .read (data_io )
21152116 else :
21162117 self .tree = MypyFile .deserialize (data )
@@ -2483,7 +2484,7 @@ def write_cache(self) -> tuple[CacheMeta, str] | None:
24832484 if self .options .debug_serialize :
24842485 try :
24852486 if self .manager .options .fixed_format_cache :
2486- data = Buffer ()
2487+ data = WriteBuffer ()
24872488 self .tree .write (data )
24882489 else :
24892490 self .tree .serialize ()
@@ -2898,6 +2899,9 @@ def dispatch(sources: list[BuildSource], manager: BuildManager, stdout: TextIO)
28982899 manager .cache_enabled = False
28992900 graph = load_graph (sources , manager )
29002901
2902+ for id in graph :
2903+ manager .import_map [id ] = set (graph [id ].dependencies + graph [id ].suppressed )
2904+
29012905 t1 = time .time ()
29022906 manager .add_stats (
29032907 graph_size = len (graph ),
0 commit comments