diff --git a/pyrefly/lib/alt/call.rs b/pyrefly/lib/alt/call.rs index 716594001..76ad6e3d7 100644 --- a/pyrefly/lib/alt/call.rs +++ b/pyrefly/lib/alt/call.rs @@ -867,54 +867,253 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { metadata, }, ), - ) => self.callable_infer( - signature, - Some(&metadata.kind), - tparams.as_deref(), - Some(obj), - args, - keywords, - range, - errors, - errors, - context, - hint, - ctor_targs, - ), - CallTarget::Callable(TargetWithTParams(tparams, callable)) => self.callable_infer( - callable, - None, - tparams.as_deref(), - None, - args, - keywords, - range, - errors, - errors, - context, - hint, - ctor_targs, - ), + ) => { + // If we have a hint and type parameters, try with hint first. + // If there are call errors (but NOT BadSpecialization errors which indicate + // legitimate bound violations), retry without hint. + // This prevents outer context hints from causing argument type errors. + if tparams.is_some() && hint.is_some() { + let call_errors = self.error_collector(); + let res = self.callable_infer( + signature.clone(), + Some(&metadata.kind), + tparams.as_deref(), + Some(obj.clone()), + args, + keywords, + range, + errors, + &call_errors, + context, + hint, + None, // Don't pass ctor_targs in trial call + ); + // Only fall back if there are errors that are NOT BadSpecialization. + // BadSpecialization errors indicate legitimate type variable bound violations + // that should be preserved, not worked around. + let should_fallback = !call_errors.is_empty() + && !call_errors.has_error_kind(ErrorKind::BadSpecialization); + if call_errors.is_empty() { + // First try succeeded, but we need to redo with ctor_targs if present + if ctor_targs.is_some() { + self.callable_infer( + signature, + Some(&metadata.kind), + tparams.as_deref(), + Some(obj), + args, + keywords, + range, + errors, + errors, + context, + hint, + ctor_targs, + ) + } else { + errors.extend(call_errors); + res + } + } else if should_fallback { + // Retry without hint - errors are hint-induced, not bound violations + self.callable_infer( + signature, + Some(&metadata.kind), + tparams.as_deref(), + Some(obj), + args, + keywords, + range, + errors, + errors, + context, + None, + ctor_targs, + ) + } else { + // Keep the errors (including bound violations) + errors.extend(call_errors); + res + } + } else { + self.callable_infer( + signature, + Some(&metadata.kind), + tparams.as_deref(), + Some(obj), + args, + keywords, + range, + errors, + errors, + context, + hint, + ctor_targs, + ) + } + } + CallTarget::Callable(TargetWithTParams(tparams, callable)) => { + if tparams.is_some() && hint.is_some() { + let call_errors = self.error_collector(); + let res = self.callable_infer( + callable.clone(), + None, + tparams.as_deref(), + None, + args, + keywords, + range, + errors, + &call_errors, + context, + hint, + None, + ); + // Only fall back if there are errors that are NOT BadSpecialization. + let should_fallback = !call_errors.is_empty() + && !call_errors.has_error_kind(ErrorKind::BadSpecialization); + if call_errors.is_empty() { + if ctor_targs.is_some() { + self.callable_infer( + callable, + None, + tparams.as_deref(), + None, + args, + keywords, + range, + errors, + errors, + context, + hint, + ctor_targs, + ) + } else { + errors.extend(call_errors); + res + } + } else if should_fallback { + // Retry without hint - errors are hint-induced, not bound violations + self.callable_infer( + callable, + None, + tparams.as_deref(), + None, + args, + keywords, + range, + errors, + errors, + context, + None, + ctor_targs, + ) + } else { + // Keep the errors (including bound violations) + errors.extend(call_errors); + res + } + } else { + self.callable_infer( + callable, + None, + tparams.as_deref(), + None, + args, + keywords, + range, + errors, + errors, + context, + hint, + ctor_targs, + ) + } + } CallTarget::Function(TargetWithTParams( tparams, Function { signature: callable, metadata, }, - )) => self.callable_infer( - callable, - Some(&metadata.kind), - tparams.as_deref(), - None, - args, - keywords, - range, - errors, - errors, - context, - hint, - ctor_targs, - ), + )) => { + if tparams.is_some() && hint.is_some() { + let call_errors = self.error_collector(); + let res = self.callable_infer( + callable.clone(), + Some(&metadata.kind), + tparams.as_deref(), + None, + args, + keywords, + range, + errors, + &call_errors, + context, + hint, + None, + ); + // Only fall back if there are errors that are NOT BadSpecialization. + let should_fallback = !call_errors.is_empty() + && !call_errors.has_error_kind(ErrorKind::BadSpecialization); + if call_errors.is_empty() { + if ctor_targs.is_some() { + self.callable_infer( + callable, + Some(&metadata.kind), + tparams.as_deref(), + None, + args, + keywords, + range, + errors, + errors, + context, + hint, + ctor_targs, + ) + } else { + errors.extend(call_errors); + res + } + } else if should_fallback { + // Retry without hint - errors are hint-induced, not bound violations + self.callable_infer( + callable, + Some(&metadata.kind), + tparams.as_deref(), + None, + args, + keywords, + range, + errors, + errors, + context, + None, + ctor_targs, + ) + } else { + // Keep the errors (including bound violations) + errors.extend(call_errors); + res + } + } else { + self.callable_infer( + callable, + Some(&metadata.kind), + tparams.as_deref(), + None, + args, + keywords, + range, + errors, + errors, + context, + hint, + ctor_targs, + ) + } + } CallTarget::FunctionOverload(overloads, metadata) => { self.call_overloads( overloads, metadata, None, args, keywords, range, errors, context, hint, diff --git a/pyrefly/lib/error/collector.rs b/pyrefly/lib/error/collector.rs index a13085193..e859289e2 100644 --- a/pyrefly/lib/error/collector.rs +++ b/pyrefly/lib/error/collector.rs @@ -171,6 +171,11 @@ impl ErrorCollector { self.errors.lock().len() } + /// Check if any error has the given error kind. + pub fn has_error_kind(&self, kind: ErrorKind) -> bool { + self.errors.lock().iter().any(|e| e.error_kind() == kind) + } + pub fn collect_into(&self, error_config: &ErrorConfig, result: &mut CollectedErrors) { let mut errors = self.errors.lock(); if !(self.module_info.is_generated() && error_config.ignore_errors_in_generated_code) { diff --git a/pyrefly/lib/lsp/wasm/inlay_hints.rs b/pyrefly/lib/lsp/wasm/inlay_hints.rs index 77d37b36c..2d197769d 100644 --- a/pyrefly/lib/lsp/wasm/inlay_hints.rs +++ b/pyrefly/lib/lsp/wasm/inlay_hints.rs @@ -9,6 +9,7 @@ use std::iter::once; use std::sync::Arc; use pyrefly_build::handle::Handle; +use pyrefly_graph::index::Idx; use pyrefly_python::ast::Ast; use pyrefly_python::module::TextRangeWithModule; use pyrefly_types::literal::Lit; @@ -29,6 +30,7 @@ use ruff_text_size::TextSize; use crate::binding::binding::Binding; use crate::binding::binding::Key; +use crate::binding::binding::UnpackedPosition; use crate::state::lsp::AllOffPartial; use crate::state::lsp::InlayHintConfig; use crate::state::state::CancellableTransaction; @@ -165,14 +167,21 @@ impl<'a> Transaction<'a> { if inlay_hint_config.variable_types && let Some(ty) = self.get_type(handle, key) => { - let e = match bindings.get(idx) { + // For unpacked values, extract the element expression if available + let (e, is_unpacked) = match bindings.get(idx) { Binding::NameAssign { annotation: None, expr: e, .. - } => Some(&**e), - Binding::Expr(None, e) => Some(e), - _ => None, + } => (Some(&**e), false), + Binding::Expr(None, e) => (Some(e), false), + Binding::UnpackedValue(None, unpack_idx, _, pos) => { + // Try to get the element expression from the unpacked source + let element_expr = + Self::get_unpacked_element_expr(&bindings, *unpack_idx, *pos); + (element_expr, true) + } + _ => (None, false), }; // If the inferred type is a class type w/ no type arguments and the // RHS is a call to a function that's the same name as the inferred class, @@ -184,9 +193,17 @@ impl<'a> Transaction<'a> { } else { None }; - if let Some(e) = e - && is_interesting(e, &ty, class_name) - { + // For unpacked values without a known element expression (e.g., from + // function calls or nested unpacking), show the hint if the type is not Any. + // For regular assignments, require the expression to be interesting. + let should_show = if let Some(e) = e { + is_interesting(e, &ty, class_name) + } else { + // For unpacked values where we couldn't extract the element, + // show hint if type is not Any + is_unpacked && !ty.is_any() + }; + if should_show { // Use get_types_with_locations to get type parts with location info let type_parts = ty.get_types_with_locations(Some(&stdlib)); let label_parts = once((": ".to_owned(), None)) @@ -214,6 +231,43 @@ impl<'a> Transaction<'a> { Some(res) } + /// Helper to extract the element expression from an unpacked source. + /// Returns the expression at the given position if the source is a tuple or list literal. + /// For nested unpacking or function calls, returns None (caller should fall back to + /// showing hints based on type information alone). + fn get_unpacked_element_expr<'b>( + bindings: &'b crate::binding::bindings::Bindings, + unpack_idx: Idx, + pos: UnpackedPosition, + ) -> Option<&'b Expr> { + // Get the binding for the unpacked source + let source_binding = bindings.get(unpack_idx); + // For top-level unpacking, the source is Binding::Expr containing the RHS. + // For nested unpacking, it's Binding::UnpackedValue - we return None in that case. + let source_expr = match source_binding { + Binding::Expr(_, e) => Some(e), + _ => None, + }?; + + // Try to extract elements from tuple or list literals + let elts = match source_expr { + Expr::Tuple(tup) => Some(&tup.elts), + Expr::List(lst) => Some(&lst.elts), + _ => None, + }?; + + // Extract the element at the given position + // This mirrors the logic in solve.rs for Binding::UnpackedValue + match pos { + UnpackedPosition::Index(i) => elts.get(i), + UnpackedPosition::ReverseIndex(i) => { + elts.len().checked_sub(i).and_then(|idx| elts.get(idx)) + } + // For slices (starred unpacking), we can't return a single element + UnpackedPosition::Slice(_, _) => None, + } + } + fn collect_function_calls_from_ast(module: Arc) -> Vec { fn collect_function_calls(x: &Expr, calls: &mut Vec) { if let Expr::Call(call) = x { diff --git a/pyrefly/lib/test/generic_basic.rs b/pyrefly/lib/test/generic_basic.rs index 26106f0c8..c78634b39 100644 --- a/pyrefly/lib/test/generic_basic.rs +++ b/pyrefly/lib/test/generic_basic.rs @@ -106,6 +106,21 @@ v: C[int] = C() append(v, "test") # E: `Literal['test']` is not assignable to parameter `y` with type `int` "#, ); +testcase!( + test_call_hint_does_not_override_arg, + r#" +from typing import Any, reveal_type + +class Map[K, V]: + def set(self, key: K, value: V) -> None: ... + def get[T](self, key: Any, default: T, /) -> V | T: ... + +d_any: Map[str, Any] = Map() + +reveal_type(d_any.get("key", None)) # E: revealed type: Any | None +result: str = reveal_type(d_any.get("key", None)) # E: revealed type: Any | None # E: `Any | None` is not assignable to `str` +"#, +); testcase!( test_generic_default, r#" diff --git a/pyrefly/lib/test/lsp/inlay_hint.rs b/pyrefly/lib/test/lsp/inlay_hint.rs index 5b0d845b1..e1b940539 100644 --- a/pyrefly/lib/test/lsp/inlay_hint.rs +++ b/pyrefly/lib/test/lsp/inlay_hint.rs @@ -113,6 +113,155 @@ imported = ssl.VerifyMode.CERT_NONE ); } +#[test] +fn test_tuple_unpacking_inlay_hint() { + let code = r#" +a = 1 +b = 1 + +x, y = (a, b) +z = a +"#; + // Individual hints for each unpacked variable + assert_eq!( + r#" +# main.py +5 | x, y = (a, b) + ^ inlay-hint: `: Literal[1]` + +5 | x, y = (a, b) + ^ inlay-hint: `: Literal[1]` + +6 | z = a + ^ inlay-hint: `: Literal[1]` +"# + .trim(), + generate_inlay_hint_report(code, Default::default()).trim() + ); +} + +#[test] +fn test_tuple_unpacking_from_function_call() { + let code = r#" +def f() -> tuple[int, str]: + return (1, "test") + +x, y = f() +"#; + // Individual hints for unpacked values from function calls + assert_eq!( + r#" +# main.py +5 | x, y = f() + ^ inlay-hint: `: int` + +5 | x, y = f() + ^ inlay-hint: `: str` +"# + .trim(), + generate_inlay_hint_report(code, Default::default()).trim() + ); +} + +#[test] +fn test_tuple_unpacking_no_hint_for_literals() { + let code = r#" +x, y = (1, 2) +"#; + // No hints when unpacking literal values + assert_eq!( + r#" +# main.py +"# + .trim(), + generate_inlay_hint_report(code, Default::default()).trim() + ); +} + +#[test] +fn test_tuple_unpacking_with_prior_annotation() { + let code = r#" +x: int +y: str +x, y = (1, "test") +"#; + // No hints because variables already have annotations + assert_eq!( + r#" +# main.py +"# + .trim(), + generate_inlay_hint_report(code, Default::default()).trim() + ); +} + +#[test] +fn test_nested_tuple_unpacking() { + let code = r#" +def f() -> tuple[int, str]: + return (1, "test") + +(a, b), c = f(), 3 +"#; + // Individual hints for nested unpacked values from function call. + // No hint for c because it's unpacked from a literal (3). + assert_eq!( + r#" +# main.py +5 | (a, b), c = f(), 3 + ^ inlay-hint: `: int` + +5 | (a, b), c = f(), 3 + ^ inlay-hint: `: str` +"# + .trim(), + generate_inlay_hint_report(code, Default::default()).trim() + ); +} + +#[test] +fn test_starred_unpacking_from_function() { + let code = r#" +def get_list() -> list[int]: + return [1, 2, 3, 4] + +a, *b, c = get_list() +"#; + // All variables get hints since we can't determine if elements are literals + assert_eq!( + r#" +# main.py +5 | a, *b, c = get_list() + ^ inlay-hint: `: int` + +5 | a, *b, c = get_list() + ^ inlay-hint: `: list[int]` + +5 | a, *b, c = get_list() + ^ inlay-hint: `: int` +"# + .trim(), + generate_inlay_hint_report(code, Default::default()).trim() + ); +} + +#[test] +fn test_starred_unpacking_from_literal() { + let code = r#" +a, *b, c = [1, 2, 3, 4] +"#; + // No hints for a and c (literals), but b gets hint since we can't extract slice elements + assert_eq!( + r#" +# main.py +2 | a, *b, c = [1, 2, 3, 4] + ^ inlay-hint: `: list[int]` +"# + .trim(), + generate_inlay_hint_report(code, Default::default()).trim() + ); +} + #[test] fn test_parameter_name_hints() { let code = r#"