diff --git a/pyrefly/lib/alt/answers_solver.rs b/pyrefly/lib/alt/answers_solver.rs index 9b639b469..8dbe008ff 100644 --- a/pyrefly/lib/alt/answers_solver.rs +++ b/pyrefly/lib/alt/answers_solver.rs @@ -857,4 +857,28 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { pub fn error_swallower(&self) -> ErrorCollector { ErrorCollector::new(self.module().dupe(), ErrorStyle::Never) } + + pub fn prefer_union_branch_without_vars(&self, ty: &Type) -> Option { + if let Type::Union(options) = ty { + let mut reordered = options.clone(); + reordered.sort_by_key(|option| self.type_contains_var(option)); + if reordered == *options { + None + } else { + Some(Type::Union(reordered)) + } + } else { + None + } + } + + pub(crate) fn type_contains_var(&self, ty: &Type) -> bool { + let mut has_var = false; + ty.universe(&mut |t| { + if matches!(t, Type::Var(_)) { + has_var = true; + } + }); + has_var + } } diff --git a/pyrefly/lib/alt/call.rs b/pyrefly/lib/alt/call.rs index 157a208c2..c61777e11 100644 --- a/pyrefly/lib/alt/call.rs +++ b/pyrefly/lib/alt/call.rs @@ -1293,7 +1293,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { None, ) } - }) + }) } } diff --git a/pyrefly/lib/alt/callable.rs b/pyrefly/lib/alt/callable.rs index 834e34ed6..618bb7a4f 100644 --- a/pyrefly/lib/alt/callable.rs +++ b/pyrefly/lib/alt/callable.rs @@ -312,6 +312,7 @@ impl CallArgPreEval<'_> { solver: &AnswersSolver, callable_name: Option<&FunctionKind>, hint: &Type, + use_hint: bool, param_name: Option<&Name>, vararg: bool, range: TextRange, @@ -334,11 +335,16 @@ impl CallArgPreEval<'_> { } Self::Expr(x, done) => { *done = true; - solver.expr_with_separate_check_errors( - x, - Some((hint, call_errors, tcc)), - arg_errors, - ); + if use_hint && !hint.is_any() { + solver.expr_with_separate_check_errors( + x, + Some((hint, call_errors, tcc)), + arg_errors, + ); + } else { + let ty = solver.expr_infer(x, arg_errors); + solver.check_type(&ty, hint, range, call_errors, tcc); + } } Self::Star(ty, done) => { *done = vararg; @@ -541,17 +547,39 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { // We ignore positional-only parameters because they can't be passed in by name. seen_names.insert(name, ty); } + let ty = if let Some(reordered) = self.prefer_union_branch_without_vars(ty) + { + type_owner.push(reordered) + } else { + ty + }; + let expanded = self.solver().expand_vars((*ty).clone()); + let (hint_ty, mut use_hint) = if expanded == *ty { + (ty, false) + } else { + (type_owner.push(expanded), true) + }; + if !use_hint { + if !self.type_contains_var(hint_ty) { + use_hint = true; + } else if let Type::Union(options) = hint_ty { + if options.iter().any(|option| !self.type_contains_var(option)) { + use_hint = true; + } + } + } arg_pre.post_check( self, callable_name, - ty, + hint_ty, + use_hint, name, false, arg.range(), arg_errors, call_errors, context, - ) + ); } Some(PosParam { ty, @@ -568,17 +596,41 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { ty, name, kind: PosParamKind::Variadic, - }) => arg_pre.post_check( - self, - callable_name, - ty, - name, - true, - arg.range(), - arg_errors, - call_errors, - context, - ), + }) => { + let ty = if let Some(reordered) = self.prefer_union_branch_without_vars(ty) + { + type_owner.push(reordered) + } else { + ty + }; + let expanded = self.solver().expand_vars((*ty).clone()); + let (hint_ty, mut use_hint) = if expanded == *ty { + (ty, false) + } else { + (type_owner.push(expanded), true) + }; + if !use_hint { + if !self.type_contains_var(hint_ty) { + use_hint = true; + } else if let Type::Union(options) = hint_ty { + if options.iter().any(|option| !self.type_contains_var(option)) { + use_hint = true; + } + } + } + arg_pre.post_check( + self, + callable_name, + hint_ty, + use_hint, + name, + true, + arg.range(), + arg_errors, + call_errors, + context, + ) + } None => { arg_pre.post_infer(self, arg_errors); if !arg_pre.is_star() { diff --git a/pyrefly/lib/alt/class/typed_dict.rs b/pyrefly/lib/alt/class/typed_dict.rs index 65f6d0d81..cf85a3706 100644 --- a/pyrefly/lib/alt/class/typed_dict.rs +++ b/pyrefly/lib/alt/class/typed_dict.rs @@ -28,7 +28,6 @@ use crate::alt::class::class_field::ClassField; use crate::alt::types::class_metadata::ClassMetadata; use crate::alt::types::class_metadata::ClassSynthesizedField; use crate::alt::types::class_metadata::ClassSynthesizedFields; -use crate::alt::unwrap::HintRef; use crate::binding::binding::ClassFieldDefinition; use crate::config::error_kind::ErrorKind; use crate::error::collector::ErrorCollector; @@ -149,11 +148,9 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { // This is an unpacked item (`**some_dict`). has_expansion = true; let partial_td_ty = Type::PartialTypedDict(typed_dict.clone()); - let item_ty = self.expr_infer_with_hint( - &x.value, - Some(HintRef::soft(&partial_td_ty)), - item_errors, - ); + let partial_hint = self.hint_from_type(partial_td_ty.clone(), None); + let item_ty = + self.expr_infer_with_hint(&x.value, Some(partial_hint.as_ref()), item_errors); let subset_result = self.is_subset_eq_with_reason(&item_ty, &partial_td_ty); if let Some(subset_error) = subset_result.err() { let tcc: &dyn Fn() -> TypeCheckContext = diff --git a/pyrefly/lib/alt/expr.rs b/pyrefly/lib/alt/expr.rs index c0ce4388b..56e771e0f 100644 --- a/pyrefly/lib/alt/expr.rs +++ b/pyrefly/lib/alt/expr.rs @@ -20,7 +20,6 @@ use pyrefly_types::callable::FunctionKind; use pyrefly_types::typed_dict::ExtraItems; use pyrefly_util::owner::Owner; use pyrefly_util::prelude::SliceExt; -use pyrefly_util::prelude::VecExt; use pyrefly_util::visit::Visit; use ruff_python_ast::Arguments; use ruff_python_ast::BoolOp; @@ -29,6 +28,9 @@ use ruff_python_ast::DictItem; use ruff_python_ast::Expr; use ruff_python_ast::ExprCall; use ruff_python_ast::ExprGenerator; +use ruff_python_ast::ExprLambda; +use ruff_python_ast::ExprList; +use ruff_python_ast::ExprName; use ruff_python_ast::ExprNumberLiteral; use ruff_python_ast::ExprStarred; use ruff_python_ast::ExprStringLiteral; @@ -42,7 +44,6 @@ use ruff_text_size::Ranged; use ruff_text_size::TextRange; use starlark_map::Hashed; use vec1::Vec1; -use vec1::vec1; use crate::alt::answers::LookupAnswer; use crate::alt::answers_solver::AnswersSolver; @@ -50,6 +51,7 @@ use crate::alt::callable::CallArg; use crate::alt::solve::TypeFormContext; use crate::alt::unwrap::Hint; use crate::alt::unwrap::HintRef; +use crate::binding::binding::Binding; use crate::binding::binding::Key; use crate::binding::binding::KeyYield; use crate::binding::binding::KeyYieldFrom; @@ -58,6 +60,7 @@ use crate::error::collector::ErrorCollector; use crate::error::context::ErrorContext; use crate::error::context::ErrorInfo; use crate::error::context::TypeCheckContext; +use crate::graph::index::Idx; use crate::types::callable::Callable; use crate::types::callable::Param; use crate::types::callable::ParamList; @@ -78,6 +81,7 @@ use crate::types::type_var::TypeVar; use crate::types::type_var_tuple::TypeVarTuple; use crate::types::types::AnyStyle; use crate::types::types::Type; +use crate::types::types::Var; #[derive(Debug, Clone, Copy)] pub enum TypeOrExpr<'a> { @@ -256,13 +260,23 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { hint: Option, errors: &ErrorCollector, ) -> TypeInfo { + if let Some(hint_ref) = hint + && hint_ref.branches().len() > 1 + { + return self.expr_infer_type_info_with_union_hint(x, hint_ref, errors); + } if let Some(self_type_annotation) = self.intercept_typing_self_use(x) { return self_type_annotation; } let res = match x { - Expr::Name(x) => self - .get(&Key::BoundName(ShortIdentifier::expr_name(x))) - .arc_clone(), + Expr::Name(x) => { + if let Some(info) = self.lambda_param_type_info(x) { + info + } else { + self.get(&Key::BoundName(ShortIdentifier::expr_name(x))) + .arc_clone() + } + } Expr::Attribute(x) => { let base = self.expr_infer_type_info_with_hint(&x.value, None, errors); self.record_external_attribute_definition_index( @@ -304,6 +318,45 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { res } + fn expr_infer_type_info_with_union_hint( + &self, + x: &Expr, + hint: HintRef, + errors: &ErrorCollector, + ) -> TypeInfo { + let mut first_error = None; + for branch in hint.branches() { + let branch_hint = self + .hint_from_type(branch.clone(), hint.errors()) + .with_source_branches(hint.source_branches()); + let branch_errors = self.error_collector(); + let info = + self.expr_infer_type_info_with_hint(x, Some(branch_hint.as_ref()), &branch_errors); + if branch_errors.is_empty() && self.is_subset_eq(info.ty(), branch) { + errors.extend(branch_errors); + return info; + } + if first_error.is_none() { + first_error = Some(branch_errors); + } + } + if let Some(errs) = first_error { + errors.extend(errs); + } + let fallback_errors = self.error_collector(); + let fallback_hint = Vec1::try_from_vec(vec![hint.ty().clone()]) + .ok() + .map(|branches| { + self.hint_from_branches(branches, hint.errors()) + .with_source_branches(hint.source_branches()) + }); + self.expr_infer_type_info_with_hint( + x, + fallback_hint.as_ref().map(|hint| hint.as_ref()), + &fallback_errors, + ) + } + fn expr_type_info( &self, x: &Expr, @@ -325,11 +378,8 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { ) -> TypeInfo { match check { Some((hint, hint_errors, tcc)) if !hint.is_any() => { - let got = self.expr_infer_type_info_with_hint( - x, - Some(HintRef::new(hint, Some(hint_errors))), - errors, - ); + let owned_hint = self.hint_from_type(hint.clone(), Some(hint_errors)); + let got = self.expr_infer_type_info_with_hint(x, Some(owned_hint.as_ref()), errors); self.check_and_return_type_info(got, hint, x.range(), hint_errors, tcc) } _ => self.expr_infer_type_info_with_hint(x, None, errors), @@ -371,75 +421,11 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { Expr::BoolOp(x) => self.boolop(&x.values, x.op, hint, errors), Expr::BinOp(x) => self.binop_infer(x, hint, errors), Expr::UnaryOp(x) => self.unop_infer(x, errors), - Expr::Lambda(lambda) => { - let param_vars = if let Some(parameters) = &lambda.parameters { - parameters - .iter_non_variadic_params() - .map(|x| (&x.name().id, self.bindings().get_lambda_param(x.name()))) - .collect() - } else { - Vec::new() - }; - // Pass any contextual information to the parameter bindings used in the lambda body as a side - // effect, by setting an answer for the vars created at binding time. - let return_hint = hint.and_then(|hint| self.decompose_lambda(hint, ¶m_vars)); - - let mut params = param_vars.into_map(|(name, var)| { - Param::Pos( - name.clone(), - self.solver().force_var(var), - Required::Required, - ) - }); - if let Some(parameters) = &lambda.parameters { - params.extend(parameters.vararg.iter().map(|x| { - Param::VarArg( - Some(x.name.id.clone()), - self.solver() - .force_var(self.bindings().get_lambda_param(&x.name)), - ) - })); - params.extend(parameters.kwarg.iter().map(|x| { - Param::Kwargs( - Some(x.name.id.clone()), - self.solver() - .force_var(self.bindings().get_lambda_param(&x.name)), - ) - })); - } - let params = Params::List(ParamList::new(params)); - let ret = self.expr_infer_type_no_trace( - &lambda.body, - return_hint.as_ref().map(|hint| hint.as_ref()), - errors, - ); - Type::Callable(Box::new(Callable { params, ret })) - } + Expr::Lambda(lambda) => self.lambda_infer(lambda, hint, errors), Expr::Tuple(x) => self.tuple_infer(x, hint, errors), Expr::List(x) => { - let elt_hint = hint.and_then(|ty| self.decompose_list(ty)); - if x.is_empty() { - let elem_ty = elt_hint.map_or_else( - || { - if !self.solver().infer_with_first_use { - self.error( - errors, - x.range(), - ErrorInfo::Kind(ErrorKind::ImplicitAny), - "This expression is implicitly inferred to be `list[Any]`. Please provide an explicit type annotation.".to_owned(), - ); - Type::any_implicit() - } else { - self.solver().fresh_contained(self.uniques).to_type() - } - }, - |hint| hint.to_type(), - ); - self.stdlib.list(elem_ty).to_type() - } else { - let elem_tys = self.elts_infer(&x.elts, elt_hint, errors); - self.stdlib.list(self.unions(elem_tys)).to_type() - } + let elt_hint = hint.and_then(|hint| self.decompose_list(hint)); + self.list_with_hint(x, elt_hint, errors) } Expr::Dict(x) => self.dict_infer(&x.items, hint, x.range, errors), Expr::Set(x) => { @@ -621,33 +607,39 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { if let Some(want) = hint && self.is_subset_eq(&ty, want.ty()) { - want.ty().clone() + let mut matched_branch = None; + for branch in want.branches() { + if self.is_subset_eq(&ty, branch) { + matched_branch = Some(branch.clone()); + break; + } + } + matched_branch.unwrap_or_else(|| want.ty().clone()) } else { ty.promote_literals(self.stdlib) } } fn tuple_infer(&self, x: &ExprTuple, hint: Option, errors: &ErrorCollector) -> Type { - let owner = Owner::new(); let (hint_ts, default_hint) = if let Some(hint) = &hint { let (tuples, nontuples) = self.split_tuple_hint(hint.ty()); // Combine hints from multiple tuples. - let mut element_hints: Vec> = Vec::new(); - let mut default_hint = Vec::new(); + let mut element_hints: Vec> = Vec::new(); + let mut default_hint: Vec = Vec::new(); for tuple in tuples { let (cur_element_hints, cur_default_hint) = self.tuple_to_element_hints(tuple); if let Some(cur_default_hint) = cur_default_hint { // Use the default hint for any elements that this tuple doesn't provide per-element hints for. for ts in element_hints.iter_mut().skip(cur_element_hints.len()) { - ts.push(cur_default_hint); + ts.push(cur_default_hint.clone()); } - default_hint.push(cur_default_hint); + default_hint.push(cur_default_hint.clone()); } for (i, element_hint) in cur_element_hints.into_iter().enumerate() { if i < element_hints.len() { - element_hints[i].push(element_hint); + element_hints[i].push(element_hint.clone()); } else { - element_hints.push(vec1![element_hint]); + element_hints.push(vec![element_hint.clone()]); } } } @@ -661,21 +653,25 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { .cloned() .collect(), ); - let nontuple_element_hint = - self.decompose_tuple(HintRef::new(&nontuple_hint, hint.errors())); + let nontuple_hint = self.hint_from_type(nontuple_hint, hint.errors()); + let nontuple_element_hint = self.decompose_tuple(nontuple_hint.as_ref()); if let Some(nontuple_element_hint) = nontuple_element_hint { - let nontuple_element_hint = owner.push(nontuple_element_hint.to_type()); + let nontuple_element_hint = nontuple_element_hint.to_type(); for ts in element_hints.iter_mut() { - ts.push(nontuple_element_hint); + ts.push(nontuple_element_hint.clone()); } default_hint.push(nontuple_element_hint); } } ( - element_hints.into_map(|ts| self.types_to_hint(ts, hint.errors(), &owner)), + element_hints + .into_iter() + .map(|ts| Vec1::try_from_vec(ts).expect("non-empty element hint")) + .map(|ts| self.types_to_hint(ts, hint.errors(), hint.source_branches())) + .collect(), Vec1::try_from_vec(default_hint) .ok() - .map(|ts| self.types_to_hint(ts, hint.errors(), &owner)), + .map(|ts| self.types_to_hint(ts, hint.errors(), hint.source_branches())), ) } else { (Vec::new(), None) @@ -730,13 +726,14 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { } } _ => { + let elem_hint_owned = if unbounded.is_empty() { + hint_ts_iter.next().or_else(|| default_hint.clone()) + } else { + None + }; let ty = self.expr_infer_type_no_trace( elt, - if unbounded.is_empty() { - hint_ts_iter.next().or(default_hint) - } else { - None - }, + elem_hint_owned.as_ref().map(|hint| hint.as_ref()), errors, ); if unbounded.is_empty() { @@ -797,21 +794,14 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { } } - fn types_to_hint<'b>( + fn types_to_hint( &self, - ts: Vec1<&'b Type>, - errors: Option<&'b ErrorCollector>, - owner: &'b Owner, - ) -> HintRef<'b, 'b> { - if ts.len() == 1 { - let (t, _) = ts.split_off_first(); - HintRef::new(t, errors) - } else { - HintRef::new( - owner.push(self.unions(ts.into_iter().cloned().collect())), - errors, - ) - } + ts: Vec1, + errors: Option<&'a ErrorCollector>, + source_branches: usize, + ) -> Hint<'a> { + self.hint_from_branches(ts, errors) + .with_source_branches(source_branches) } fn dict_infer( @@ -822,41 +812,36 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { errors: &ErrorCollector, ) -> Type { let flattened_items = Ast::flatten_dict_items(items); - let hints = hint.as_ref().map_or(Vec::new(), |hint| match hint.ty() { - Type::Union(ts) => ts - .iter() - .map(|ty| HintRef::new(ty, hint.errors())) - .collect(), - _ => vec![*hint], - }); - for hint in hints.iter() { - let (typed_dict, is_update) = match hint.ty() { - Type::TypedDict(td) => (td, false), - Type::PartialTypedDict(td) => (td, true), - _ => continue, - }; - let check_errors = self.error_collector(); - let item_errors = self.error_collector(); - self.check_dict_items_against_typed_dict( - &flattened_items, - typed_dict, - is_update, - range, - &check_errors, - &item_errors, - ); + if let Some(hint_ref) = hint { + for branch in hint_ref.branches() { + let (typed_dict, is_update) = match branch { + Type::TypedDict(td) => (td, false), + Type::PartialTypedDict(td) => (td, true), + _ => continue, + }; + let check_errors = self.error_collector(); + let item_errors = self.error_collector(); + self.check_dict_items_against_typed_dict( + &flattened_items, + typed_dict, + is_update, + range, + &check_errors, + &item_errors, + ); - // We use the TypedDict hint if it successfully matched or if there is only one hint, unless - // this is a "soft" type hint, in which case we don't want to raise any check errors. - if check_errors.is_empty() - || hints.len() == 1 - && hint - .errors() - .inspect(|errors| errors.extend(check_errors)) - .is_some() - { - errors.extend(item_errors); - return (*hint.ty()).clone(); + // We use the TypedDict hint if it successfully matched or if there is only one hint, unless + // this is a "soft" type hint, in which case we don't want to raise any check errors. + if check_errors.is_empty() + || hint_ref.source_branches() == 1 + && hint_ref + .errors() + .inspect(|errors| errors.extend(check_errors)) + .is_some() + { + errors.extend(item_errors); + return branch.clone(); + } } } // Note that we don't need to filter out the TypedDict options here; any non-`dict` options @@ -872,90 +857,114 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { hint: Option, errors: &ErrorCollector, ) -> Type { - let (key_hint, value_hint) = hint.map_or((None, None), |ty| self.decompose_dict(ty)); + let (key_hint, value_hint) = if let Some(hint) = hint { + let (key_hint, value_hint) = self.decompose_dict(hint); + let key_hint = key_hint.or_else(|| self.decompose_iterable(hint)); + (key_hint, value_hint) + } else { + (None, None) + }; if items.is_empty() { - let key_ty = key_hint.map_or_else( + let key_ty = key_hint.as_ref().map_or_else( || { if !self.solver().infer_with_first_use { + self.error( + errors, + range, + ErrorInfo::Kind(ErrorKind::ImplicitAny), + "This expression is implicitly inferred to be `dict[Any, Any]`. Please provide an explicit type annotation.".to_owned(), + ); Type::any_implicit() } else { self.solver().fresh_contained(self.uniques).to_type() } }, - |ty| ty.to_type(), + |hint| hint.to_type(), ); - let value_ty = value_hint.map_or_else( + let value_ty = value_hint.as_ref().map_or_else( || { if !self.solver().infer_with_first_use { + self.error( + errors, + range, + ErrorInfo::Kind(ErrorKind::ImplicitAny), + "This expression is implicitly inferred to be `dict[Any, Any]`. Please provide an explicit type annotation.".to_owned(), + ); Type::any_implicit() } else { self.solver().fresh_contained(self.uniques).to_type() } }, - |ty| ty.to_type(), + |hint| hint.to_type(), ); - if hint.is_none() && !self.solver().infer_with_first_use { - self.error( - errors, - range, - ErrorInfo::Kind(ErrorKind::ImplicitAny), - "This expression is implicitly inferred to be `dict[Any, Any]`. Please provide an explicit type annotation.".to_owned(), - ); - } self.stdlib.dict(key_ty, value_ty).to_type() } else { let mut key_tys = Vec::new(); let mut value_tys = Vec::new(); - items.iter().for_each(|x| match &x.key { - Some(key) => { - let key_t = self.expr_infer_with_hint_promote( - key, - key_hint.as_ref().map(|hint| hint.as_ref()), - errors, - ); - let value_t = self.expr_infer_with_hint_promote( - &x.value, - value_hint.as_ref().map(|hint| hint.as_ref()), - errors, - ); - if !key_t.is_error() { - key_tys.push(key_t); - } - if !value_t.is_error() { - value_tys.push(value_t); - } - } - None => { - let ty = self.expr_infer(&x.value, errors); - if let Some((key_t, value_t)) = self.unwrap_mapping(&ty) { + for item in items { + match &item.key { + Some(key) => { + let key_t = self.expr_infer_with_hint_promote( + key, + key_hint.as_ref().map(|hint| hint.as_ref()), + errors, + ); + let value_t = self.expr_infer_with_hint_promote( + &item.value, + value_hint.as_ref().map(|hint| hint.as_ref()), + errors, + ); if !key_t.is_error() { if let Some(key_hint) = &key_hint - && self.is_subset_eq(&key_t, key_hint.ty()) + && self.is_subset_eq(&key_t, key_hint.union()) { - key_tys.push(key_hint.ty().clone()); + key_tys.push(key_hint.union().clone()); } else { key_tys.push(key_t); } } if !value_t.is_error() { if let Some(value_hint) = &value_hint - && self.is_subset_eq(&value_t, value_hint.ty()) + && self.is_subset_eq(&value_t, value_hint.union()) { - value_tys.push(value_hint.ty().clone()); + value_tys.push(value_hint.union().clone()); } else { value_tys.push(value_t); } } - } else { - self.error( - errors, - x.value.range(), - ErrorInfo::Kind(ErrorKind::InvalidArgument), - format!("Expected a mapping, got {}", self.for_display(ty)), - ); + } + None => { + let ty = self.expr_infer(&item.value, errors); + if let Some((key_t, value_t)) = self.unwrap_mapping(&ty) { + if !key_t.is_error() { + if let Some(key_hint) = &key_hint + && self.is_subset_eq(&key_t, key_hint.union()) + { + key_tys.push(key_hint.union().clone()); + } else { + key_tys.push(key_t); + } + } + if !value_t.is_error() { + if let Some(value_hint) = &value_hint + && self.is_subset_eq(&value_t, value_hint.union()) + { + value_tys.push(value_hint.union().clone()); + } else { + value_tys.push(value_t); + } + } + } else { + self.error( + errors, + item.value.range(), + ErrorInfo::Kind(ErrorKind::InvalidArgument), + format!("Expected a mapping, got {}", self.for_display(ty)), + ); + } } } - }); + } if key_tys.is_empty() { key_tys.push(Type::any_error()) } @@ -967,7 +976,6 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { self.stdlib.dict(key_ty, value_ty).to_type() } } - /// If this is a `dict` call that can be converted to an equivalent dict literal (e.g., `dict(x=1)` => `{'x': 1}`), /// return the items in the converted dict. fn call_to_dict(&self, callee_ty: &Type, args: &Arguments) -> Option> { @@ -1083,14 +1091,16 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { for (i, value) in values.iter().enumerate() { // If there isn't a hint for the overall expression, use the preceding branches as a "soft" hint // for the next one. Most useful for expressions like `optional_list or []`. - let hint = hint.or_else(|| { + let mut soft_hint = None; + let hint_ref = hint.clone().or_else(|| { if t_acc.is_never() { None } else { - Some(HintRef::soft(&t_acc)) + soft_hint = Some(self.soft_hint_from_type(t_acc.clone())); + Some(soft_hint.as_ref().unwrap().as_ref()) } }); - let mut t = self.expr_infer_with_hint(value, hint, errors); + let mut t = self.expr_infer_with_hint(value, hint_ref, errors); self.expand_vars_mut(&mut t); // If this is not the last entry, we have to make a type-dependent decision and also narrow the // result; both operations require us to force `Var` first or they become unpredictable. @@ -1122,6 +1132,85 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { t_acc } + fn lambda_infer( + &self, + lambda: &ExprLambda, + hint: Option, + errors: &ErrorCollector, + ) -> Type { + let param_vars: Vec<(Name, Var)> = if let Some(parameters) = &lambda.parameters { + parameters + .iter_non_variadic_params() + .map(|x| { + ( + x.name().id.clone(), + self.bindings().get_lambda_param(x.name()), + ) + }) + .collect() + } else { + Vec::new() + }; + let param_states = param_vars + .iter() + .map(|(_, var)| self.solver().snapshot_unwrap_var(*var)) + .collect::>(); + let return_hint = hint.and_then(|hint| self.decompose_lambda(hint, ¶m_vars)); + + let ret = self.expr_infer_type_no_trace( + &lambda.body, + return_hint.as_ref().map(|hint| hint.as_ref()), + errors, + ); + let ret = self.solver().expand_vars(ret); + + let mut params = param_vars + .iter() + .map(|(name, var)| { + Param::Pos( + name.clone(), + self.solver().force_var(*var), + Required::Required, + ) + }) + .collect::>(); + if let Some(parameters) = &lambda.parameters { + params.extend(parameters.vararg.iter().map(|x| { + Param::VarArg( + Some(x.name.id.clone()), + self.solver() + .force_var(self.bindings().get_lambda_param(&x.name)), + ) + })); + params.extend(parameters.kwarg.iter().map(|x| { + Param::Kwargs( + Some(x.name.id.clone()), + self.solver() + .force_var(self.bindings().get_lambda_param(&x.name)), + ) + })); + } + let params = Params::List(ParamList::new(params)); + for ((_, var), state) in param_vars.iter().zip(param_states.into_iter()) { + self.solver().restore_unwrap_var(*var, state); + } + Type::Callable(Box::new(Callable { params, ret })) + } + + fn lambda_param_type_info(&self, name: &ExprName) -> Option { + let key = Key::BoundName(ShortIdentifier::expr_name(name)); + let idx = self.bindings().key_to_idx(&key); + self.lambda_param_type_info_from_idx(idx) + } + + fn lambda_param_type_info_from_idx(&self, idx: Idx) -> Option { + match self.bindings().get(idx) { + Binding::Forward(next) => self.lambda_param_type_info_from_idx(*next), + Binding::LambdaParameter(var) => Some(TypeInfo::of_ty(var.to_type())), + _ => None, + } + } + /// Infers types for `if` clauses in the given comprehensions. /// This is for error detection only; the types are not used. fn ifs_infer(&self, comps: &[Comprehension], errors: &ErrorCollector) { @@ -1780,8 +1869,9 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { ) -> Vec { let star_hint = LazyCell::new(|| { elt_hint.as_ref().map(|hint| { - hint.as_ref() - .map_ty(|ty| self.stdlib.iterable(ty.clone()).to_type()) + self.hint_map(hint.as_ref(), |ty| { + self.stdlib.iterable(ty.clone()).to_type() + }) }) }); elts.map(|x| match x { @@ -1813,6 +1903,36 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { }) } + fn list_with_hint( + &self, + x: &ExprList, + elt_hint: Option, + errors: &ErrorCollector, + ) -> Type { + if x.is_empty() { + let elem_ty = elt_hint.as_ref().map_or_else( + || { + if !self.solver().infer_with_first_use { + self.error( + errors, + x.range(), + ErrorInfo::Kind(ErrorKind::ImplicitAny), + "This expression is implicitly inferred to be `list[Any]`. Please provide an explicit type annotation.".to_owned(), + ); + Type::any_implicit() + } else { + self.solver().fresh_contained(self.uniques).to_type() + } + }, + |hint| hint.to_type(), + ); + self.stdlib.list(elem_ty).to_type() + } else { + let elem_tys = self.elts_infer(&x.elts, elt_hint, errors); + self.stdlib.list(self.unions(elem_tys)).to_type() + } + } + fn intercept_typing_self_use(&self, x: &Expr) -> Option { match x { Expr::Name(..) | Expr::Attribute(..) => { diff --git a/pyrefly/lib/alt/operators.rs b/pyrefly/lib/alt/operators.rs index 8eb32fea5..e91acf66c 100644 --- a/pyrefly/lib/alt/operators.rs +++ b/pyrefly/lib/alt/operators.rs @@ -237,7 +237,6 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { lhs = self.expr_infer(&x.left, errors); rhs = self.expr_infer(&x.right, errors); } - // Optimisation: If we have `Union[a, b] | Union[c, d]`, instead of unioning // (a | c) | (a | d) | (b | c) | (b | d), we can just do one union. if x.op == Operator::BitOr diff --git a/pyrefly/lib/alt/solve.rs b/pyrefly/lib/alt/solve.rs index 659036ca0..adb164afc 100644 --- a/pyrefly/lib/alt/solve.rs +++ b/pyrefly/lib/alt/solve.rs @@ -49,7 +49,6 @@ use crate::alt::types::decorated_function::UndecoratedFunction; use crate::alt::types::legacy_lookup::LegacyTypeParameterLookup; use crate::alt::types::yields::YieldFromResult; use crate::alt::types::yields::YieldResult; -use crate::alt::unwrap::HintRef; use crate::binding::binding::AnnAssignHasValue; use crate::binding::binding::AnnotationStyle; use crate::binding::binding::AnnotationTarget; @@ -3265,9 +3264,10 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { x.ty(self.stdlib) .map(|ty| self.stdlib.async_iterable(ty.clone()).to_type()) }); + let infer_hint = infer_hint.map(|ty| self.hint_from_type(ty, None)); let iterable = self.expr_infer_with_hint( e, - infer_hint.as_ref().map(HintRef::soft), + infer_hint.as_ref().map(|hint| hint.as_ref()), errors, ); self.async_iterate(&iterable, e.range(), errors) @@ -3276,9 +3276,10 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { x.ty(self.stdlib) .map(|ty| self.stdlib.iterable(ty.clone()).to_type()) }); + let infer_hint = infer_hint.map(|ty| self.hint_from_type(ty, None)); let iterable = self.expr_infer_with_hint( e, - infer_hint.as_ref().map(HintRef::soft), + infer_hint.as_ref().map(|hint| hint.as_ref()), errors, ); self.iterate(&iterable, e.range(), errors, None) diff --git a/pyrefly/lib/alt/unwrap.rs b/pyrefly/lib/alt/unwrap.rs index 6ba1c0684..2b42ea742 100644 --- a/pyrefly/lib/alt/unwrap.rs +++ b/pyrefly/lib/alt/unwrap.rs @@ -6,6 +6,7 @@ */ use ruff_python_ast::name::Name; +use vec1::Vec1; use crate::alt::answers::LookupAnswer; use crate::alt::answers_solver::AnswersSolver; @@ -17,56 +18,79 @@ use crate::types::tuple::Tuple; use crate::types::types::Type; use crate::types::types::Var; -// The error collector is None for a "soft" type hint, where we try to -// match an expression against a hint, but fall back to the inferred type -// without any errors if the hint is incompatible. -// Soft type hints are used for `e1 or e1` expressions. -pub struct Hint<'a>(Type, Option<&'a ErrorCollector>); +#[derive(Clone, Debug)] +pub struct Hint<'a> { + union: Type, + branches: Vec1, + errors: Option<&'a ErrorCollector>, + source_branches: usize, +} #[derive(Clone, Copy, Debug)] -pub struct HintRef<'a, 'b>(&'b Type, Option<&'a ErrorCollector>); +pub struct HintRef<'a, 'b> { + union: &'b Type, + branches: &'b [Type], + errors: Option<&'a ErrorCollector>, + source_branches: usize, +} impl<'a> Hint<'a> { - pub fn as_ref<'b>(&'a self) -> HintRef<'a, 'b> - where - 'a: 'b, - { - HintRef(&self.0, self.1) + pub fn new(union: Type, branches: Vec1, errors: Option<&'a ErrorCollector>) -> Self { + let source_branches = branches.len(); + Self { + union, + branches, + errors, + source_branches, + } } - pub fn ty(&self) -> &Type { - &self.0 + pub fn as_ref(&self) -> HintRef<'a, '_> { + HintRef { + union: &self.union, + branches: self.branches.as_slice(), + errors: self.errors, + source_branches: self.source_branches, + } } - pub fn to_type(self) -> Type { - self.0 + pub fn errors(&self) -> Option<&'a ErrorCollector> { + self.errors } -} -impl<'a, 'b> HintRef<'a, 'b> { - pub fn new(hint: &'b Type, errors: Option<&'a ErrorCollector>) -> Self { - Self(hint, errors) + pub fn to_type(&self) -> Type { + self.union.clone() + } + + pub fn union(&self) -> &Type { + &self.union + } + + pub fn source_branches(&self) -> usize { + self.source_branches } - /// Construct a "soft" type hint that doesn't report an error when the hint is incompatible. - pub fn soft(hint: &'b Type) -> Self { - Self(hint, None) + pub fn with_source_branches(mut self, count: usize) -> Self { + self.source_branches = count.max(1); + self } +} +impl<'a, 'b> HintRef<'a, 'b> { pub fn ty(&self) -> &Type { - self.0 + self.union } - pub fn errors(&self) -> Option<&ErrorCollector> { - self.1 + pub fn errors(&self) -> Option<&'a ErrorCollector> { + self.errors } - pub fn map_ty(&self, f: impl FnOnce(&Type) -> Type) -> Hint<'a> { - Hint(f(self.0), self.1) + pub fn branches(&self) -> &'b [Type] { + self.branches } - pub fn map_ty_opt(&self, f: impl FnOnce(&Type) -> Option) -> Option> { - f(self.0).map(|ty| Hint(ty, self.1)) + pub fn source_branches(&self) -> usize { + self.source_branches.max(1) } } @@ -119,6 +143,80 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { } } + fn collect_hint_branches(&self, ty: &Type) -> Vec { + let mut branches = Vec::new(); + self.map_over_union(ty, |branch| branches.push(branch.clone())); + if branches.is_empty() { + branches.push(ty.clone()); + } + branches.sort_by_key(|branch| self.type_contains_var(branch)); + branches + } + + pub fn hint_from_type(&self, ty: Type, errors: Option<&'a ErrorCollector>) -> Hint<'a> { + let branches = + Vec1::try_from_vec(self.collect_hint_branches(&ty)).expect("hint branches non-empty"); + Hint::new(ty, branches, errors) + } + + pub fn soft_hint_from_type(&self, ty: Type) -> Hint<'a> { + self.hint_from_type(ty, None) + } + + fn hint_from_branches_vec( + &self, + mut branches: Vec, + errors: Option<&'a ErrorCollector>, + ) -> Option> { + if branches.is_empty() { + return None; + } + branches.sort_by_key(|branch| self.type_contains_var(branch)); + let branches = Vec1::try_from_vec(branches).ok()?; + let union = if branches.len() == 1 { + branches.first().clone() + } else { + self.unions(branches.clone().into_vec()) + }; + Some(Hint::new(union, branches, errors)) + } + + pub fn hint_from_branches( + &self, + branches: Vec1, + errors: Option<&'a ErrorCollector>, + ) -> Hint<'a> { + let union = if branches.len() == 1 { + branches.first().clone() + } else { + self.unions(branches.clone().into_vec()) + }; + Hint::new(union, branches, errors) + } + + pub(crate) fn hint_map(&self, hint: HintRef<'a, '_>, f: impl Fn(&Type) -> Type) -> Hint<'a> { + let mapped = hint.branches().iter().map(|ty| f(ty)).collect(); + let source_branches = hint.source_branches(); + self.hint_from_branches_vec(mapped, hint.errors()) + .expect("hint branches non-empty") + .with_source_branches(source_branches) + } + + pub(crate) fn hint_filter_map( + &self, + hint: HintRef<'a, '_>, + mut f: impl FnMut(&Type) -> Option, + ) -> Option> { + let mapped = hint + .branches() + .iter() + .filter_map(|ty| f(ty)) + .collect::>(); + let source_branches = hint.source_branches(); + self.hint_from_branches_vec(mapped, hint.errors()) + .map(|hint| hint.with_source_branches(source_branches)) + } + pub fn unwrap_mapping(&self, ty: &Type) -> Option<(Type, Type)> { let key = self.fresh_var(); let value = self.fresh_var(); @@ -222,70 +320,104 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { pub fn decompose_dict<'b>( &self, - hint: HintRef<'b, '_>, - ) -> (Option>, Option>) { - let key = self.fresh_var(); - let value = self.fresh_var(); - let dict_type = self.stdlib.dict(key.to_type(), value.to_type()).to_type(); - if self.is_subset_eq(&dict_type, hint.ty()) { - let key = hint.map_ty_opt(|ty| self.resolve_var_opt(ty, key)); - let value = hint.map_ty_opt(|ty| self.resolve_var_opt(ty, value)); - (key, value) - } else { - (None, None) + hint: HintRef<'a, 'b>, + ) -> (Option>, Option>) { + let mut key_types = Vec::new(); + let mut value_types = Vec::new(); + let source_branches = hint.source_branches(); + for branch in hint.branches() { + let key = self.fresh_var(); + let value = self.fresh_var(); + let dict_type = self.stdlib.dict(key.to_type(), value.to_type()).to_type(); + if self.is_subset_eq(&dict_type, branch) { + if let (Some(key_ty), Some(value_ty)) = ( + self.resolve_var_opt(branch, key), + self.resolve_var_opt(branch, value), + ) { + key_types.push(key_ty); + value_types.push(value_ty); + } + } } + let key = self + .hint_from_branches_vec(key_types, hint.errors()) + .map(|hint| hint.with_source_branches(source_branches)); + let value = self + .hint_from_branches_vec(value_types, hint.errors()) + .map(|hint| hint.with_source_branches(source_branches)); + (key, value) } - pub fn decompose_set<'b>(&self, hint: HintRef<'b, '_>) -> Option> { - let elem = self.fresh_var(); - let set_type = self.stdlib.set(elem.to_type()).to_type(); - if self.is_subset_eq(&set_type, hint.ty()) { - hint.map_ty_opt(|ty| self.resolve_var_opt(ty, elem)) - } else { - None - } + pub fn decompose_set<'b>(&self, hint: HintRef<'a, 'b>) -> Option> { + self.hint_filter_map(hint, move |branch| { + let elem = self.fresh_var(); + let set_type = self.stdlib.set(elem.to_type()).to_type(); + if self.is_subset_eq(&set_type, branch) { + self.resolve_var_opt(branch, elem) + } else { + None + } + }) } - pub fn decompose_list<'b>(&self, hint: HintRef<'b, '_>) -> Option> { - let elem = self.fresh_var(); - let list_type = self.stdlib.list(elem.to_type()).to_type(); - if self.is_subset_eq(&list_type, hint.ty()) { - hint.map_ty_opt(|ty| self.resolve_var_opt(ty, elem)) - } else { - None - } + pub fn decompose_iterable<'b>(&self, hint: HintRef<'a, 'b>) -> Option> { + self.hint_filter_map(hint, move |branch| { + let elem = self.fresh_var(); + let iterable_type = self.stdlib.iterable(elem.to_type()).to_type(); + if self.is_subset_eq(&iterable_type, branch) { + self.resolve_var_opt(branch, elem) + } else { + None + } + }) } - pub fn decompose_tuple<'b>(&self, hint: HintRef<'b, '_>) -> Option> { - let elem = self.fresh_var(); - let tuple_type = self.stdlib.tuple(elem.to_type()).to_type(); - if self.is_subset_eq(&tuple_type, hint.ty()) { - hint.map_ty_opt(|ty| self.resolve_var_opt(ty, elem)) - } else { - None - } + pub fn decompose_list<'b>(&self, hint: HintRef<'a, 'b>) -> Option> { + self.hint_filter_map(hint, move |branch| { + let elem = self.fresh_var(); + let list_type = self.stdlib.list(elem.to_type()).to_type(); + if self.is_subset_eq(&list_type, branch) { + self.resolve_var_opt(branch, elem) + } else { + None + } + }) + } + + pub fn decompose_tuple<'b>(&self, hint: HintRef<'a, 'b>) -> Option> { + self.hint_filter_map(hint, move |branch| { + let elem = self.fresh_var(); + let tuple_type = self.stdlib.tuple(elem.to_type()).to_type(); + if self.is_subset_eq(&tuple_type, branch) { + self.resolve_var_opt(branch, elem) + } else { + None + } + }) } pub fn decompose_lambda<'b>( &self, - hint: HintRef<'b, '_>, - param_vars: &[(&Name, Var)], - ) -> Option> { + hint: HintRef<'a, 'b>, + param_vars: &[(Name, Var)], + ) -> Option> { let return_ty = self.fresh_var(); let params = param_vars .iter() - .map(|(name, var)| Param::Pos((*name).clone(), var.to_type(), Required::Required)) + .map(|(name, var)| Param::Pos(name.clone(), var.to_type(), Required::Required)) .collect::>(); let callable_ty = Type::callable(params, return_ty.to_type()); - if self.is_subset_eq(&callable_ty, hint.ty()) { - hint.map_ty_opt(|ty| self.resolve_var_opt(ty, return_ty)) - } else { - None - } + self.hint_filter_map(hint, move |branch| { + if self.is_subset_eq(&callable_ty, branch) { + self.resolve_var_opt(branch, return_ty) + } else { + None + } + }) } - pub fn decompose_generator_yield<'b>(&self, hint: HintRef<'b, '_>) -> Option> { + pub fn decompose_generator_yield<'b>(&self, hint: HintRef<'a, 'b>) -> Option> { let yield_ty = self.fresh_var(); let generator_ty = self .stdlib @@ -295,11 +427,13 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { self.fresh_var().to_type(), ) .to_type(); - if self.is_subset_eq(&generator_ty, hint.ty()) { - hint.map_ty_opt(|ty| self.resolve_var_opt(ty, yield_ty)) - } else { - None - } + self.hint_filter_map(hint, move |branch| { + if self.is_subset_eq(&generator_ty, branch) { + self.resolve_var_opt(branch, yield_ty) + } else { + None + } + }) } pub fn decompose_generator(&self, ty: &Type) -> Option<(Type, Type, Type)> { diff --git a/pyrefly/lib/solver/solver.rs b/pyrefly/lib/solver/solver.rs index 4f849d55f..867e19be6 100644 --- a/pyrefly/lib/solver/solver.rs +++ b/pyrefly/lib/solver/solver.rs @@ -298,6 +298,12 @@ pub struct Solver { pub infer_with_first_use: bool, } +#[derive(Clone)] +pub enum UnwrapVarState { + Unwrap, + Answer(Type), +} + impl Display for Solver { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { for (x, y) in self.variables.lock().iter() { @@ -325,6 +331,24 @@ impl Solver { self.variables.lock().recurse(var, recurser) } + pub fn snapshot_unwrap_var(&self, var: Var) -> UnwrapVarState { + let variables = self.variables.lock(); + match &*variables.get(var) { + Variable::Unwrap => UnwrapVarState::Unwrap, + Variable::Answer(ty) => UnwrapVarState::Answer(ty.clone()), + other => panic!("Expected lambda parameter var to be Unwrap or Answer, got {other:?}"), + } + } + + pub fn restore_unwrap_var(&self, var: Var, state: UnwrapVarState) { + let variables = self.variables.lock(); + let mut entry = variables.get_mut(var); + *entry = match state { + UnwrapVarState::Unwrap => Variable::Unwrap, + UnwrapVarState::Answer(ty) => Variable::Answer(ty), + }; + } + /// Force all non-recursive Vars in `vars`. /// /// TODO: deduplicate Variable-to-gradual-type logic with `force_var`. diff --git a/pyrefly/lib/test/contextual.rs b/pyrefly/lib/test/contextual.rs index 63f8683d2..101287b8a 100644 --- a/pyrefly/lib/test/contextual.rs +++ b/pyrefly/lib/test/contextual.rs @@ -102,7 +102,6 @@ kwarg(xs=[B()], ys=[B()]) ); testcase!( - bug = "Both assignments should be allowed. When decomposing the contextual hint, we eagerly resolve vars to the 'first' branch of the union. Note: due to the union's sorted representation, the first branch is not necessarily the first in source order.", test_contextual_typing_against_unions, r#" class A: ... @@ -110,7 +109,7 @@ class B: ... class B2(B): ... class C: ... -x: list[A] | list[B] = [B2()] # E: `list[B2]` is not assignable to `list[A] | list[B]` +x: list[A] | list[B] = [B2()] y: list[B] | list[C] = [B2()] "#, ); @@ -266,7 +265,6 @@ x2: list[A] = True and [B()] ); testcase!( - bug = "x or y or ... fails due to union hints, see test_contextual_typing_against_unions", test_context_boolop_soft, r#" from typing import TypedDict, assert_type @@ -280,7 +278,7 @@ def test(x: list[A] | None, y: list[C] | None, z: TD | None) -> None: assert_type(x or [B()], list[A]) assert_type(x or [0], list[A] | list[int]) assert_type(x or y or [B()], list[A] | list[C]) - assert_type(x or y or [D()], list[A] | list[C]) # TODO # E: assert_type(list[A] | list[C] | list[D], list[A] | list[C]) failed + assert_type(x or y or [D()], list[A] | list[C]) assert_type(z or {"x": 0}, TD) assert_type(z or {"x": ""}, TD | dict[str, str]) "#, @@ -309,6 +307,37 @@ f: Callable[[], list[A]] = lambda: [B()] "#, ); +testcase!( + test_context_lambda_callable_union, + r#" +from typing import Callable +f: Callable[[int], int] | Callable[[str], str] = lambda x: x + "1" +"#, +); + +testcase!( + test_context_dict_set_union, + r#" +from typing import assert_type +xs: dict[int, int] | dict[str, str] = {} +assert_type(xs, dict[int, int] | dict[str, str]) +ys: set[int] | set[str] = {1} +assert_type(ys, set[int] | set[str]) +"#, +); + +testcase!( + test_context_typed_dict_union_list, + r#" +from typing import TypedDict +class A(TypedDict): + x: int +class B(TypedDict): + x: str +xs: list[A] | list[B] = [{"x": "foo"}] +"#, +); + // We want to contextually type lambda params even when there is an arity mismatch. testcase!( test_context_lambda_arity, @@ -389,7 +418,6 @@ f(g(0)) # OK ); testcase!( - bug = "Propagating the hint should still allow for a narrower inferred type", test_context_return_narrow, r#" from typing import assert_type @@ -399,7 +427,7 @@ def f[T](x: T) -> T: def test(x: int | str): x = f(0) - assert_type(x, int) # E: assert_type(int | str, int) failed + assert_type(x, int) "#, );