diff --git a/pyrefly/lib/lsp/wasm/hover.rs b/pyrefly/lib/lsp/wasm/hover.rs index f2c8b83671..fff3fc3e59 100644 --- a/pyrefly/lib/lsp/wasm/hover.rs +++ b/pyrefly/lib/lsp/wasm/hover.rs @@ -31,6 +31,7 @@ use pyrefly_types::callable::Required; use pyrefly_types::display::LspDisplayMode; use pyrefly_types::types::Type; use pyrefly_util::lined_buffer::LineNumber; +use pyrefly_util::visit::Visit; use ruff_python_ast::AnyNodeRef; use ruff_python_ast::Stmt; use ruff_python_ast::name::Name; @@ -41,6 +42,8 @@ use ruff_text_size::TextSize; use crate::alt::answers_solver::AnswersSolver; use crate::error::error::Error; use crate::lsp::module_helpers::collect_symbol_def_paths; +use crate::lsp::wasm::signature_help::is_constructor_call; +use crate::lsp::wasm::signature_help::override_constructor_return_type; use crate::state::lsp::DefinitionMetadata; use crate::state::lsp::FindDefinitionItemWithDocstring; use crate::state::lsp::FindPreference; @@ -503,7 +506,41 @@ pub fn get_hover( } // Otherwise, fall through to the existing type hover logic - let type_ = transaction.get_type_at(handle, position)?; + let mut type_ = transaction.get_type_at(handle, position)?; + + // Helper function to check if we're hovering over a callee and get its range + let find_callee_range_at_position = || -> Option { + use ruff_python_ast::Expr; + let mod_module = transaction.get_ast(handle)?; + let mut result = None; + mod_module.visit(&mut |expr: &Expr| { + if let Expr::Call(call) = expr { + // Check if position is within the callee (func) range + if call.func.range().contains(position) { + result = Some(call.func.range()); + } + } + }); + result + }; + + // Check both: hovering in arguments area OR hovering over the callee itself + let callee_range_opt = transaction + .get_callables_from_call(handle, position) + .map(|(_, _, _, range)| range) + .or_else(find_callee_range_at_position); + + if let Some(callee_range) = callee_range_opt { + let is_constructor = transaction + .get_answers(handle) + .and_then(|ans| ans.get_type_trace(callee_range)) + .is_some_and(is_constructor_call); + + if is_constructor && let Some(new_type) = override_constructor_return_type(type_.clone()) { + type_ = new_type; + } + } + let fallback_name_from_type = fallback_hover_name_from_type(&type_); let (kind, name, docstring_range, module) = if let Some(FindDefinitionItemWithDocstring { metadata, diff --git a/pyrefly/lib/lsp/wasm/signature_help.rs b/pyrefly/lib/lsp/wasm/signature_help.rs index 594c6e9f59..dc75c6f926 100644 --- a/pyrefly/lib/lsp/wasm/signature_help.rs +++ b/pyrefly/lib/lsp/wasm/signature_help.rs @@ -37,6 +37,34 @@ use crate::types::callable::Param; use crate::types::callable::Params; use crate::types::types::Type; +pub(crate) fn is_constructor_call(callee_type: Type) -> bool { + matches!(callee_type, Type::ClassDef(_)) + || matches!(callee_type, Type::Type(inner) if matches!(inner.as_ref(), Type::ClassType(_) | Type::ClassDef(_))) +} + +pub(crate) fn override_constructor_return_type(type_: Type) -> Option { + let mut callable = type_.clone().to_callable()?; + if !callable.ret.is_none() { + return None; + } + + let mut should_override = false; + if let Params::List(ref params_list) = callable.params + && let Some(Param::Pos(name, self_type, _) | Param::PosOnly(Some(name), self_type, _)) = + params_list.items().first() + && (name.as_str() == "self" || name.as_str() == "cls") + { + callable.ret = self_type.clone(); + should_override = true; + } + + if should_override { + Some(Type::Callable(Box::new(callable))) + } else { + None + } +} + /// The currently active argument in a function call for signature help. #[derive(Debug)] pub(crate) enum ActiveArgument { @@ -280,11 +308,20 @@ impl Transaction<'_> { active_argument: &ActiveArgument, parameter_docs: Option<&HashMap>, function_docstring: Option<&Docstring>, + is_constructor_call: bool, ) -> SignatureInformation { let type_ = type_.deterministic_printing(); - let label = type_.as_lsp_string(LspDisplayMode::SignatureHelp); + + // Display the return type as the class instance type instead of None + let display_type = if is_constructor_call { + override_constructor_return_type(type_.clone()).unwrap_or(type_) + } else { + type_ + }; + + let label = display_type.as_lsp_string(LspDisplayMode::SignatureHelp); let (parameters, active_parameter) = if let Some(params) = - Self::normalize_singleton_function_type_into_params(type_) + Self::normalize_singleton_function_type_into_params(display_type) { // Create a type display context for consistent parameter formatting let param_types: Vec<&Type> = params.iter().map(|p| p.as_type()).collect(); @@ -335,6 +372,12 @@ impl Transaction<'_> { |(callables, chosen_overload_index, active_argument, callee_range)| { let parameter_docs = self.parameter_documentation_for_callee(handle, callee_range); let function_docstring = self.function_docstring_for_callee(handle, callee_range); + + let is_constructor_call = self + .get_answers(handle) + .and_then(|ans| ans.get_type_trace(callee_range)) + .is_some_and(is_constructor_call); + let signatures = callables .into_iter() .map(|t| { @@ -343,6 +386,7 @@ impl Transaction<'_> { &active_argument, parameter_docs.as_ref(), function_docstring.as_ref(), + is_constructor_call, ) }) .collect_vec(); diff --git a/pyrefly/lib/test/lsp/hover.rs b/pyrefly/lib/test/lsp/hover.rs index c40c699b9c..ddab560af7 100644 --- a/pyrefly/lib/test/lsp/hover.rs +++ b/pyrefly/lib/test/lsp/hover.rs @@ -14,6 +14,7 @@ use ruff_text_size::TextSize; use crate::lsp::wasm::hover::get_hover; use crate::state::state::State; use crate::test::util::get_batched_lsp_operations_report; +use crate::test::util::get_batched_lsp_operations_report_allow_error; fn get_test_report(state: &State, handle: &Handle, position: TextSize) -> String { match get_hover(&state.transaction(), handle, position, true) { @@ -904,6 +905,23 @@ from mymod.submod.deep import Bar ); } +#[test] +fn hover_on_constructor_shows_instance_type() { + let code = r#" +class Person: + def __init__(self, name: str, age: int) -> None: ... + +Person() +#^ +"#; + let report = get_batched_lsp_operations_report_allow_error(&[("main", code)], get_test_report); + assert!( + report + .contains("def Person(\n self: Person,\n name: str,\n age: int\n) -> Person"), + "Expected constructor hover to show complete signature with -> Person, got: {report}" + ); +} + #[test] fn hover_over_in_operator_shows_contains_dunder() { let code = r#" @@ -922,6 +940,22 @@ c = Container() ); } +#[test] +fn hover_on_constructor_with_arguments() { + let code = r#" +class Person: + def __init__(self, name: str, age: int) -> None: ... + +Person("Alice", 25) +#^ +"#; + let report = get_batched_lsp_operations_report_allow_error(&[("main", code)], get_test_report); + assert!( + report.contains("-> Person"), + "Expected constructor hover to show -> Person, got: {report}" + ); +} + #[test] fn hover_over_in_keyword_in_for_loop() { let code = r#" @@ -937,6 +971,27 @@ for x in [1, 2, 3]: ); } +#[test] +fn hover_on_direct_init_call_shows_none() { + let code = r#" +class Person: + def __init__(self, name: str) -> None: ... + +p = Person.__new__(Person) +Person.__init__(p, "Alice") +# ^ +"#; + let report = get_batched_lsp_operations_report_allow_error(&[("main", code)], get_test_report); + assert!( + report.contains("-> None"), + "Expected direct __init__ call to show -> None, got: {report}" + ); + assert!( + !report.contains("-> Person") || report.contains("__init__"), + "Direct __init__ call should show -> None, got: {report}" + ); +} + #[test] fn hover_over_in_keyword_in_list_comprehension() { let code = r#" @@ -951,6 +1006,65 @@ result = [x for x in [1, 2, 3] if x in [1]] ); } +#[test] +fn hover_on_method_call_unchanged() { + let code = r#" +class Foo: + def method(self) -> str: ... + +foo = Foo() +foo.method() +# ^ +"#; + let report = get_batched_lsp_operations_report_allow_error(&[("main", code)], get_test_report); + assert!( + report.contains("-> str"), + "Expected method hover to show -> str, got: {report}" + ); +} + +#[test] +fn hover_on_argument_shows_argument_type() { + let code = r#" +class Person: + def __init__(self, name: str) -> None: ... + +Person("Alice") +# ^ +"#; + let report = get_batched_lsp_operations_report_allow_error(&[("main", code)], get_test_report); + // Hovering over a string literal shows its literal type + assert!( + report.contains("Literal['Alice']") || report.contains("str"), + "Expected argument hover to show literal type or str, got: {report}" + ); + // The argument hover should not show the constructor signature + assert!( + !report.contains("__init__") || !report.contains("name: str"), + "Argument hover should show argument type, not constructor, got: {report}" + ); +} + +#[test] +fn hover_on_generic_constructor() { + let code = r#" +from typing import Generic, TypeVar + +T = TypeVar("T") + +class Box(Generic[T]): + def __init__(self, value: T) -> None: ... + +Box[str]("hello") +#^ +"#; + let report = get_batched_lsp_operations_report_allow_error(&[("main", code)], get_test_report); + assert!( + report.contains("Box[str]"), + "Expected generic constructor to show Box[str], got: {report}" + ); +} + #[test] fn hover_over_in_keyword_for_membership_in_comprehension() { let code = r#" diff --git a/pyrefly/lib/test/lsp/signature_help.rs b/pyrefly/lib/test/lsp/signature_help.rs index 1699d09834..4d5b0356c6 100644 --- a/pyrefly/lib/test/lsp/signature_help.rs +++ b/pyrefly/lib/test/lsp/signature_help.rs @@ -903,3 +903,104 @@ Signature Help Result: active=0 report.trim(), ); } + +#[test] +fn constructor_signature_shows_instance_type() { + let code = r#" +class Person: + def __init__(self, name: str, age: int) -> None: ... + +Person() +# ^ +Person("Alice", ) +# ^ +"#; + let report = get_batched_lsp_operations_report_allow_error(&[("main", code)], get_test_report); + assert_eq!( + r#" +# main.py +5 | Person() + ^ +Signature Help Result: active=0 +- (self: Person, name: str, age: int) -> Person, parameters=[name: str, age: int], active parameter = 0 + +7 | Person("Alice", ) + ^ +Signature Help Result: active=0 +- (self: Person, name: str, age: int) -> Person, parameters=[name: str, age: int], active parameter = 1 +"# + .trim(), + report.trim(), + ); +} + +#[test] +fn direct_init_call_shows_none() { + let code = r#" +class Person: + def __init__(self, name: str) -> None: ... + +p = Person.__new__(Person) +Person.__init__(p, ) +# ^ +"#; + let report = get_batched_lsp_operations_report_allow_error(&[("main", code)], get_test_report); + // Direct __init__ call should still show -> None + assert!( + report.contains("-> None"), + "Expected direct __init__ call to show -> None, got: {report}" + ); + assert!( + report.contains("parameters=[name: str]"), + "Expected parameters, got: {report}" + ); +} + +#[test] +fn generic_constructor_signature() { + let code = r#" +from typing import Generic, TypeVar + +T = TypeVar("T") + +class Box(Generic[T]): + def __init__(self, value: T) -> None: ... + +Box[str]() +# ^ +Box[int](42) +# ^ +"#; + let report = get_batched_lsp_operations_report_allow_error(&[("main", code)], get_test_report); + // Generic constructors should show the specialized instance type + assert!( + report.contains("-> Box[str]") || report.contains("Box[str]"), + "Expected generic constructor to show Box[str], got: {report}" + ); + assert!( + report.contains("-> Box[int]") || report.contains("Box[int]"), + "Expected generic constructor to show Box[int], got: {report}" + ); +} + +#[test] +fn method_call_signature_unchanged() { + let code = r#" +class Foo: + def method(self, x: int) -> str: ... + +foo = Foo() +foo.method() +# ^ +"#; + let report = get_batched_lsp_operations_report_allow_error(&[("main", code)], get_test_report); + // Method calls should still show their original return type + assert!( + report.contains("-> str"), + "Expected method signature to show -> str, got: {report}" + ); + assert!( + report.contains("parameters=[x: int]"), + "Expected parameters, got: {report}" + ); +}