diff --git a/pyrefly/lib/lsp/wasm/semantic_tokens.rs b/pyrefly/lib/lsp/wasm/semantic_tokens.rs index 5877426696..12e4417329 100644 --- a/pyrefly/lib/lsp/wasm/semantic_tokens.rs +++ b/pyrefly/lib/lsp/wasm/semantic_tokens.rs @@ -5,10 +5,18 @@ * LICENSE file in the root directory of this source tree. */ +use std::collections::HashMap; + use lsp_types::SemanticToken; use pyrefly_build::handle::Handle; +use pyrefly_python::module_name::ModuleName; +use pyrefly_python::short_identifier::ShortIdentifier; +use pyrefly_python::symbol_kind::SymbolKind; +use pyrefly_util::visit::Visit; +use ruff_python_ast::Stmt; use ruff_text_size::TextRange; +use crate::binding::binding::Binding; use crate::binding::binding::Key; use crate::binding::bindings::Bindings; use crate::export::exports::Export; @@ -70,6 +78,7 @@ impl Transaction<'_> { let legends = SemanticTokensLegends::new(); let disabled_ranges = disabled_ranges_for_module(ast.as_ref(), handle.sys_info()); let mut builder = SemanticTokenBuilder::new(limit_range, disabled_ranges); + let mut symbol_kinds: HashMap = HashMap::new(); for NamedBinding { definition_handle, definition_export, @@ -81,9 +90,20 @@ impl Transaction<'_> { .. } = definition_export { - builder.process_key(&key, definition_handle.module(), symbol_kind) + let binding = bindings.get(bindings.key_to_idx(&key)); + let definition_module = match binding { + Binding::Import(module, _, _) | Binding::Module(module, ..) => *module, + _ => definition_handle.module(), + }; + if let Key::Definition(short) = &key { + symbol_kinds.insert(short.clone(), (definition_module, symbol_kind)); + } + builder.process_key(&key, definition_module, symbol_kind); } } + for stmt in &ast.body { + add_import_from_alias_tokens(&mut builder, stmt, &symbol_kinds); + } builder.process_ast(&ast, &|range| self.get_type_trace(handle, range)); Some(legends.convert_tokens_into_lsp_semantic_tokens( &builder.all_tokens_sorted(), @@ -92,3 +112,21 @@ impl Transaction<'_> { )) } } + +fn add_import_from_alias_tokens( + builder: &mut SemanticTokenBuilder, + stmt: &Stmt, + symbol_kinds: &HashMap, +) { + if let Stmt::ImportFrom(import_from) = stmt { + for alias in &import_from.names { + if let Some(asname) = &alias.asname { + let key = ShortIdentifier::new(asname); + if let Some((definition_module, symbol_kind)) = symbol_kinds.get(&key) { + builder.process_range(alias.name.range, *definition_module, *symbol_kind); + } + } + } + } + stmt.recurse(&mut |inner| add_import_from_alias_tokens(builder, inner, symbol_kinds)); +} diff --git a/pyrefly/lib/state/semantic_tokens.rs b/pyrefly/lib/state/semantic_tokens.rs index 63042c9864..f154137e8e 100644 --- a/pyrefly/lib/state/semantic_tokens.rs +++ b/pyrefly/lib/state/semantic_tokens.rs @@ -223,13 +223,12 @@ impl SemanticTokenBuilder { .any(|disabled| disabled.contains_range(range)) } - pub fn process_key( + fn push_symbol_range( &mut self, - key: &Key, + reference_range: TextRange, definition_module: ModuleName, symbol_kind: SymbolKind, ) { - let reference_range = key.range(); let (token_type, mut token_modifiers) = symbol_kind.to_lsp_semantic_token_type_with_modifiers(); let is_default_library = { @@ -244,6 +243,24 @@ impl SemanticTokenBuilder { self.push_if_in_range(reference_range, token_type, token_modifiers); } + pub fn process_key( + &mut self, + key: &Key, + definition_module: ModuleName, + symbol_kind: SymbolKind, + ) { + self.push_symbol_range(key.range(), definition_module, symbol_kind); + } + + pub fn process_range( + &mut self, + range: TextRange, + definition_module: ModuleName, + symbol_kind: SymbolKind, + ) { + self.push_symbol_range(range, definition_module, symbol_kind); + } + fn process_arguments(&mut self, args: &Arguments) { for keyword in &args.keywords { if let Some(arg) = &keyword.arg { @@ -341,19 +358,10 @@ impl SemanticTokenBuilder { } } } - Stmt::ImportFrom(StmtImportFrom { module, names, .. }) => { + Stmt::ImportFrom(StmtImportFrom { module, .. }) => { if let Some(module) = module { self.push_if_in_range(module.range, SemanticTokenType::NAMESPACE, vec![]); } - for alias in names { - if alias.asname.is_some() { - self.push_if_in_range( - alias.name.range, - SemanticTokenType::NAMESPACE, - vec![], - ); - } - } } Stmt::AnnAssign(ann_assign) => { if let Expr::Name(name) = &*ann_assign.target { diff --git a/pyrefly/lib/test/lsp/semantic_tokens.rs b/pyrefly/lib/test/lsp/semantic_tokens.rs index 22418a67b4..66880a65cb 100644 --- a/pyrefly/lib/test/lsp/semantic_tokens.rs +++ b/pyrefly/lib/test/lsp/semantic_tokens.rs @@ -883,7 +883,7 @@ line: 1, column: 5, length: 3, text: lib token-type: namespace line: 1, column: 16, length: 4, text: func -token-type: namespace +token-type: function line: 1, column: 24, length: 4, text: func token-type: function @@ -913,7 +913,7 @@ line: 1, column: 5, length: 3, text: foo token-type: namespace line: 1, column: 16, length: 3, text: bar -token-type: namespace +token-type: function line: 1, column: 23, length: 3, text: baz token-type: function