Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion pyrefly/lib/lsp/wasm/hover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ use pyrefly_types::callable::Required;
use pyrefly_types::display::LspDisplayMode;
use pyrefly_types::types::Type;
use pyrefly_util::lined_buffer::LineNumber;
use pyrefly_util::visit::Visit;
use ruff_python_ast::AnyNodeRef;
use ruff_python_ast::Stmt;
use ruff_python_ast::name::Name;
Expand All @@ -41,6 +42,8 @@ use ruff_text_size::TextSize;
use crate::alt::answers_solver::AnswersSolver;
use crate::error::error::Error;
use crate::lsp::module_helpers::collect_symbol_def_paths;
use crate::lsp::wasm::signature_help::is_constructor_call;
use crate::lsp::wasm::signature_help::override_constructor_return_type;
use crate::state::lsp::DefinitionMetadata;
use crate::state::lsp::FindDefinitionItemWithDocstring;
use crate::state::lsp::FindPreference;
Expand Down Expand Up @@ -503,7 +506,41 @@ pub fn get_hover(
}

// Otherwise, fall through to the existing type hover logic
let type_ = transaction.get_type_at(handle, position)?;
let mut type_ = transaction.get_type_at(handle, position)?;

// Helper function to check if we're hovering over a callee and get its range
let find_callee_range_at_position = || -> Option<TextRange> {
use ruff_python_ast::Expr;
let mod_module = transaction.get_ast(handle)?;
let mut result = None;
mod_module.visit(&mut |expr: &Expr| {
if let Expr::Call(call) = expr {
// Check if position is within the callee (func) range
if call.func.range().contains(position) {
result = Some(call.func.range());
}
}
});
result
};

// Check both: hovering in arguments area OR hovering over the callee itself
let callee_range_opt = transaction
.get_callables_from_call(handle, position)
.map(|(_, _, _, range)| range)
.or_else(find_callee_range_at_position);

if let Some(callee_range) = callee_range_opt {
let is_constructor = transaction
.get_answers(handle)
.and_then(|ans| ans.get_type_trace(callee_range))
.is_some_and(is_constructor_call);

if is_constructor && let Some(new_type) = override_constructor_return_type(type_.clone()) {
type_ = new_type;
}
}

let fallback_name_from_type = fallback_hover_name_from_type(&type_);
let (kind, name, docstring_range, module) = if let Some(FindDefinitionItemWithDocstring {
metadata,
Expand Down
48 changes: 46 additions & 2 deletions pyrefly/lib/lsp/wasm/signature_help.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,34 @@ use crate::types::callable::Param;
use crate::types::callable::Params;
use crate::types::types::Type;

pub(crate) fn is_constructor_call(callee_type: Type) -> bool {
matches!(callee_type, Type::ClassDef(_))
|| matches!(callee_type, Type::Type(inner) if matches!(inner.as_ref(), Type::ClassType(_) | Type::ClassDef(_)))
}

pub(crate) fn override_constructor_return_type(type_: Type) -> Option<Type> {
let mut callable = type_.clone().to_callable()?;
if !callable.ret.is_none() {
return None;
}

let mut should_override = false;
if let Params::List(ref params_list) = callable.params
&& let Some(Param::Pos(name, self_type, _) | Param::PosOnly(Some(name), self_type, _)) =
params_list.items().first()
&& (name.as_str() == "self" || name.as_str() == "cls")
{
callable.ret = self_type.clone();
should_override = true;
}

if should_override {
Some(Type::Callable(Box::new(callable)))
} else {
None
}
}

/// The currently active argument in a function call for signature help.
#[derive(Debug)]
pub(crate) enum ActiveArgument {
Expand Down Expand Up @@ -280,11 +308,20 @@ impl Transaction<'_> {
active_argument: &ActiveArgument,
parameter_docs: Option<&HashMap<String, String>>,
function_docstring: Option<&Docstring>,
is_constructor_call: bool,
) -> SignatureInformation {
let type_ = type_.deterministic_printing();
let label = type_.as_lsp_string(LspDisplayMode::SignatureHelp);

// Display the return type as the class instance type instead of None
let display_type = if is_constructor_call {
override_constructor_return_type(type_.clone()).unwrap_or(type_)
} else {
type_
};

let label = display_type.as_lsp_string(LspDisplayMode::SignatureHelp);
let (parameters, active_parameter) = if let Some(params) =
Self::normalize_singleton_function_type_into_params(type_)
Self::normalize_singleton_function_type_into_params(display_type)
{
// Create a type display context for consistent parameter formatting
let param_types: Vec<&Type> = params.iter().map(|p| p.as_type()).collect();
Expand Down Expand Up @@ -335,6 +372,12 @@ impl Transaction<'_> {
|(callables, chosen_overload_index, active_argument, callee_range)| {
let parameter_docs = self.parameter_documentation_for_callee(handle, callee_range);
let function_docstring = self.function_docstring_for_callee(handle, callee_range);

let is_constructor_call = self
.get_answers(handle)
.and_then(|ans| ans.get_type_trace(callee_range))
.is_some_and(is_constructor_call);

let signatures = callables
.into_iter()
.map(|t| {
Expand All @@ -343,6 +386,7 @@ impl Transaction<'_> {
&active_argument,
parameter_docs.as_ref(),
function_docstring.as_ref(),
is_constructor_call,
)
})
.collect_vec();
Expand Down
114 changes: 114 additions & 0 deletions pyrefly/lib/test/lsp/hover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use ruff_text_size::TextSize;
use crate::lsp::wasm::hover::get_hover;
use crate::state::state::State;
use crate::test::util::get_batched_lsp_operations_report;
use crate::test::util::get_batched_lsp_operations_report_allow_error;

fn get_test_report(state: &State, handle: &Handle, position: TextSize) -> String {
match get_hover(&state.transaction(), handle, position, true) {
Expand Down Expand Up @@ -904,6 +905,23 @@ from mymod.submod.deep import Bar
);
}

#[test]
fn hover_on_constructor_shows_instance_type() {
let code = r#"
class Person:
def __init__(self, name: str, age: int) -> None: ...

Person()
#^
"#;
let report = get_batched_lsp_operations_report_allow_error(&[("main", code)], get_test_report);
assert!(
report
.contains("def Person(\n self: Person,\n name: str,\n age: int\n) -> Person"),
"Expected constructor hover to show complete signature with -> Person, got: {report}"
);
}

#[test]
fn hover_over_in_operator_shows_contains_dunder() {
let code = r#"
Expand All @@ -922,6 +940,22 @@ c = Container()
);
}

#[test]
fn hover_on_constructor_with_arguments() {
let code = r#"
class Person:
def __init__(self, name: str, age: int) -> None: ...

Person("Alice", 25)
#^
"#;
let report = get_batched_lsp_operations_report_allow_error(&[("main", code)], get_test_report);
assert!(
report.contains("-> Person"),
"Expected constructor hover to show -> Person, got: {report}"
);
}

#[test]
fn hover_over_in_keyword_in_for_loop() {
let code = r#"
Expand All @@ -937,6 +971,27 @@ for x in [1, 2, 3]:
);
}

#[test]
fn hover_on_direct_init_call_shows_none() {
let code = r#"
class Person:
def __init__(self, name: str) -> None: ...

p = Person.__new__(Person)
Person.__init__(p, "Alice")
# ^
"#;
let report = get_batched_lsp_operations_report_allow_error(&[("main", code)], get_test_report);
assert!(
report.contains("-> None"),
"Expected direct __init__ call to show -> None, got: {report}"
);
assert!(
!report.contains("-> Person") || report.contains("__init__"),
"Direct __init__ call should show -> None, got: {report}"
);
}

#[test]
fn hover_over_in_keyword_in_list_comprehension() {
let code = r#"
Expand All @@ -951,6 +1006,65 @@ result = [x for x in [1, 2, 3] if x in [1]]
);
}

#[test]
fn hover_on_method_call_unchanged() {
let code = r#"
class Foo:
def method(self) -> str: ...

foo = Foo()
foo.method()
# ^
"#;
let report = get_batched_lsp_operations_report_allow_error(&[("main", code)], get_test_report);
assert!(
report.contains("-> str"),
"Expected method hover to show -> str, got: {report}"
);
}

#[test]
fn hover_on_argument_shows_argument_type() {
let code = r#"
class Person:
def __init__(self, name: str) -> None: ...

Person("Alice")
# ^
"#;
let report = get_batched_lsp_operations_report_allow_error(&[("main", code)], get_test_report);
// Hovering over a string literal shows its literal type
assert!(
report.contains("Literal['Alice']") || report.contains("str"),
"Expected argument hover to show literal type or str, got: {report}"
);
// The argument hover should not show the constructor signature
assert!(
!report.contains("__init__") || !report.contains("name: str"),
"Argument hover should show argument type, not constructor, got: {report}"
);
}

#[test]
fn hover_on_generic_constructor() {
let code = r#"
from typing import Generic, TypeVar

T = TypeVar("T")

class Box(Generic[T]):
def __init__(self, value: T) -> None: ...

Box[str]("hello")
#^
"#;
let report = get_batched_lsp_operations_report_allow_error(&[("main", code)], get_test_report);
assert!(
report.contains("Box[str]"),
"Expected generic constructor to show Box[str], got: {report}"
);
}

#[test]
fn hover_over_in_keyword_for_membership_in_comprehension() {
let code = r#"
Expand Down
101 changes: 101 additions & 0 deletions pyrefly/lib/test/lsp/signature_help.rs
Original file line number Diff line number Diff line change
Expand Up @@ -903,3 +903,104 @@ Signature Help Result: active=0
report.trim(),
);
}

#[test]
fn constructor_signature_shows_instance_type() {
let code = r#"
class Person:
def __init__(self, name: str, age: int) -> None: ...

Person()
# ^
Person("Alice", )
# ^
"#;
let report = get_batched_lsp_operations_report_allow_error(&[("main", code)], get_test_report);
assert_eq!(
r#"
# main.py
5 | Person()
^
Signature Help Result: active=0
- (self: Person, name: str, age: int) -> Person, parameters=[name: str, age: int], active parameter = 0

7 | Person("Alice", )
^
Signature Help Result: active=0
- (self: Person, name: str, age: int) -> Person, parameters=[name: str, age: int], active parameter = 1
"#
.trim(),
report.trim(),
);
}

#[test]
fn direct_init_call_shows_none() {
let code = r#"
class Person:
def __init__(self, name: str) -> None: ...

p = Person.__new__(Person)
Person.__init__(p, )
# ^
"#;
let report = get_batched_lsp_operations_report_allow_error(&[("main", code)], get_test_report);
// Direct __init__ call should still show -> None
assert!(
report.contains("-> None"),
"Expected direct __init__ call to show -> None, got: {report}"
);
assert!(
report.contains("parameters=[name: str]"),
"Expected parameters, got: {report}"
);
}

#[test]
fn generic_constructor_signature() {
let code = r#"
from typing import Generic, TypeVar

T = TypeVar("T")

class Box(Generic[T]):
def __init__(self, value: T) -> None: ...

Box[str]()
# ^
Box[int](42)
# ^
"#;
let report = get_batched_lsp_operations_report_allow_error(&[("main", code)], get_test_report);
// Generic constructors should show the specialized instance type
assert!(
report.contains("-> Box[str]") || report.contains("Box[str]"),
"Expected generic constructor to show Box[str], got: {report}"
);
assert!(
report.contains("-> Box[int]") || report.contains("Box[int]"),
"Expected generic constructor to show Box[int], got: {report}"
);
}

#[test]
fn method_call_signature_unchanged() {
let code = r#"
class Foo:
def method(self, x: int) -> str: ...

foo = Foo()
foo.method()
# ^
"#;
let report = get_batched_lsp_operations_report_allow_error(&[("main", code)], get_test_report);
// Method calls should still show their original return type
assert!(
report.contains("-> str"),
"Expected method signature to show -> str, got: {report}"
);
assert!(
report.contains("parameters=[x: int]"),
"Expected parameters, got: {report}"
);
}
Loading