|
| 1 | +/* |
| 2 | + * Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | + * |
| 4 | + * This source code is licensed under the MIT license found in the |
| 5 | + * LICENSE file in the root directory of this source tree. |
| 6 | + */ |
| 7 | + |
| 8 | +//! Helpers for harvesting imports and formatting type strings for inlay hints. |
| 9 | +
|
| 10 | +use std::cmp::Reverse; |
| 11 | + |
| 12 | +use dupe::Dupe; |
| 13 | +use pyrefly_python::module_name::ModuleName; |
| 14 | +use ruff_python_ast::ModModule; |
| 15 | +use ruff_python_ast::Stmt; |
| 16 | +use ruff_python_ast::StmtImport; |
| 17 | +use starlark_map::small_set::SmallSet; |
| 18 | + |
| 19 | +use crate::types::display::TypeDisplayContext; |
| 20 | +use crate::types::types::Type; |
| 21 | + |
| 22 | +/// Tracks imports already present in a module and can determine which modules are still missing |
| 23 | +/// for a given set of referenced modules. Also supports alias-aware replacement when displaying |
| 24 | +/// type strings. |
| 25 | +#[derive(Default)] |
| 26 | +pub struct ImportTracker { |
| 27 | + canonical_modules: SmallSet<ModuleName>, |
| 28 | + alias_modules: Vec<(ModuleName, String)>, |
| 29 | +} |
| 30 | + |
| 31 | +impl ImportTracker { |
| 32 | + /// Build an import tracker from the top-level `import ...` statements in a module. |
| 33 | + pub fn from_ast(ast: &ModModule) -> Self { |
| 34 | + let mut tracker = Self::default(); |
| 35 | + for stmt in &ast.body { |
| 36 | + if let Stmt::Import(stmt_import) = stmt { |
| 37 | + tracker.record_import(stmt_import); |
| 38 | + } |
| 39 | + } |
| 40 | + tracker |
| 41 | + .alias_modules |
| 42 | + .sort_by_key(|(module, _)| Reverse(module.as_str().len())); |
| 43 | + tracker |
| 44 | + } |
| 45 | + |
| 46 | + /// Record an `import ...` statement into the tracker. |
| 47 | + pub fn record_import(&mut self, stmt_import: &StmtImport) { |
| 48 | + for alias in &stmt_import.names { |
| 49 | + let module_name = ModuleName::from_str(alias.name.as_str()); |
| 50 | + if let Some(asname) = &alias.asname { |
| 51 | + self.alias_modules |
| 52 | + .push((module_name, asname.id.to_string())); |
| 53 | + } else { |
| 54 | + self.canonical_modules.insert(module_name); |
| 55 | + } |
| 56 | + } |
| 57 | + } |
| 58 | + |
| 59 | + /// Replace any module prefixes that have been imported under an alias (e.g. `import typing as t`). |
| 60 | + pub fn apply_aliases(&self, text: &str) -> String { |
| 61 | + if self.alias_modules.is_empty() { |
| 62 | + return text.to_owned(); |
| 63 | + } |
| 64 | + let bytes = text.as_bytes(); |
| 65 | + let mut result = String::with_capacity(text.len()); |
| 66 | + let mut i = 0; |
| 67 | + while i < bytes.len() { |
| 68 | + let mut replaced = false; |
| 69 | + for (module, alias) in &self.alias_modules { |
| 70 | + let module_str = module.as_str(); |
| 71 | + if module_str.is_empty() { |
| 72 | + continue; |
| 73 | + } |
| 74 | + let module_bytes = module_str.as_bytes(); |
| 75 | + if i + module_bytes.len() <= bytes.len() |
| 76 | + && &bytes[i..i + module_bytes.len()] == module_bytes |
| 77 | + && Self::is_boundary(bytes, i, i + module_bytes.len()) |
| 78 | + { |
| 79 | + result.push_str(alias); |
| 80 | + i += module_bytes.len(); |
| 81 | + replaced = true; |
| 82 | + break; |
| 83 | + } |
| 84 | + } |
| 85 | + if !replaced { |
| 86 | + result.push(bytes[i] as char); |
| 87 | + i += 1; |
| 88 | + } |
| 89 | + } |
| 90 | + result |
| 91 | + } |
| 92 | + |
| 93 | + /// Modules that are referenced in the type string but not yet imported (excluding builtins/current). |
| 94 | + pub fn missing_modules( |
| 95 | + &self, |
| 96 | + modules: &SmallSet<ModuleName>, |
| 97 | + current_module: ModuleName, |
| 98 | + ) -> SmallSet<ModuleName> { |
| 99 | + let mut missing = SmallSet::new(); |
| 100 | + for module in modules.iter() { |
| 101 | + let module = module.dupe(); |
| 102 | + if module.as_str().is_empty() |
| 103 | + || module == current_module |
| 104 | + || module == ModuleName::builtins() |
| 105 | + || module == ModuleName::extra_builtins() |
| 106 | + { |
| 107 | + continue; |
| 108 | + } |
| 109 | + if self.module_is_imported(module) { |
| 110 | + continue; |
| 111 | + } |
| 112 | + missing.insert(module); |
| 113 | + } |
| 114 | + missing |
| 115 | + } |
| 116 | + |
| 117 | + fn module_is_imported(&self, module: ModuleName) -> bool { |
| 118 | + self.alias_for(module).is_some() || self.has_canonical(module) |
| 119 | + } |
| 120 | + |
| 121 | + fn alias_for(&self, module: ModuleName) -> Option<String> { |
| 122 | + let target = module.as_str(); |
| 123 | + for (alias_module, alias_name) in &self.alias_modules { |
| 124 | + let alias_module_str = alias_module.as_str(); |
| 125 | + if alias_module_str.is_empty() { |
| 126 | + continue; |
| 127 | + } |
| 128 | + if target == alias_module_str { |
| 129 | + return Some(alias_name.clone()); |
| 130 | + } |
| 131 | + if target.len() > alias_module_str.len() |
| 132 | + && target.starts_with(alias_module_str) |
| 133 | + && target.as_bytes()[alias_module_str.len()] == b'.' |
| 134 | + { |
| 135 | + let remainder = &target[alias_module_str.len()..]; |
| 136 | + return Some(format!("{alias_name}{remainder}")); |
| 137 | + } |
| 138 | + } |
| 139 | + None |
| 140 | + } |
| 141 | + |
| 142 | + fn has_canonical(&self, module: ModuleName) -> bool { |
| 143 | + let target = module.as_str(); |
| 144 | + self.canonical_modules.iter().any(|imported| { |
| 145 | + let imported_str = imported.as_str(); |
| 146 | + imported_str == target |
| 147 | + || (target.len() > imported_str.len() |
| 148 | + && target.starts_with(imported_str) |
| 149 | + && target.as_bytes()[imported_str.len()] == b'.') |
| 150 | + }) |
| 151 | + } |
| 152 | + |
| 153 | + fn is_boundary(bytes: &[u8], start: usize, end: usize) -> bool { |
| 154 | + (start == 0 || !Self::is_ident(bytes[start - 1])) |
| 155 | + && (end == bytes.len() || !Self::is_ident(bytes[end])) |
| 156 | + } |
| 157 | + |
| 158 | + fn is_ident(byte: u8) -> bool { |
| 159 | + matches!(byte, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_') |
| 160 | + } |
| 161 | +} |
| 162 | + |
| 163 | +/// Produce a user-facing type string (without module qualifiers) together with all referenced modules |
| 164 | +/// (captured with module qualification) so callers can insert the necessary imports. |
| 165 | +pub fn format_type_for_annotation(ty: &Type) -> (String, SmallSet<ModuleName>) { |
| 166 | + // First pass: force module names so referenced_modules collects everything, but ignore the text. |
| 167 | + let mut module_ctx = TypeDisplayContext::new(&[ty]); |
| 168 | + module_ctx.always_display_module_name_except_builtins(); |
| 169 | + let _ = module_ctx.display(ty).to_string(); |
| 170 | + let modules = module_ctx.referenced_modules(); |
| 171 | + |
| 172 | + // Second pass: produce a concise label without module qualifiers. |
| 173 | + let display_ctx = TypeDisplayContext::new(&[ty]); |
| 174 | + let text = display_ctx.display(ty).to_string(); |
| 175 | + (text, modules) |
| 176 | +} |
| 177 | + |
| 178 | +#[cfg(test)] |
| 179 | +mod tests { |
| 180 | + use super::*; |
| 181 | + |
| 182 | + #[test] |
| 183 | + fn aliases_are_applied_at_boundaries_only() { |
| 184 | + let module = ModuleName::from_str("typing"); |
| 185 | + let mut tracker = ImportTracker::default(); |
| 186 | + tracker.alias_modules.push((module, "t".to_owned())); |
| 187 | + assert_eq!(tracker.apply_aliases("typing.Literal"), "t.Literal"); |
| 188 | + // Do not replace inside longer identifiers |
| 189 | + assert_eq!(tracker.apply_aliases("mytyping"), "mytyping"); |
| 190 | + } |
| 191 | + |
| 192 | + #[test] |
| 193 | + fn missing_modules_skips_builtin_and_current() { |
| 194 | + let tracker = ImportTracker::default(); |
| 195 | + let mut modules = SmallSet::new(); |
| 196 | + let current = ModuleName::from_str("pkg.mod"); |
| 197 | + modules.insert(current.dupe()); |
| 198 | + modules.insert(ModuleName::builtins()); |
| 199 | + modules.insert(ModuleName::from_str("typing")); |
| 200 | + let missing = tracker.missing_modules(&modules, current); |
| 201 | + assert!(missing.contains(&ModuleName::from_str("typing"))); |
| 202 | + assert_eq!(missing.len(), 1); |
| 203 | + } |
| 204 | + |
| 205 | + #[test] |
| 206 | + fn format_type_collects_modules_but_returns_short_label() { |
| 207 | + let ty = Type::LiteralString; |
| 208 | + let (text, modules) = format_type_for_annotation(&ty); |
| 209 | + assert_eq!(text, "LiteralString"); |
| 210 | + assert!(modules.contains(&ModuleName::from_str("typing"))); |
| 211 | + } |
| 212 | +} |
0 commit comments