Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions crates/pyrefly_types/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1443,6 +1443,23 @@ impl Type {
self
}

/// Promote all literals (both implicit and explicit) to their base types.
/// This is used for comprehension element inference, where we don't want to
/// preserve `LiteralString` or `Literal[...]` types even if they came from
/// explicit type annotations.
pub fn promote_all_literals(mut self, stdlib: &Stdlib) -> Type {
fn g(ty: &mut Type, f: &mut dyn FnMut(&mut Type)) {
ty.recurse_mut(&mut |ty| g(ty, f));
f(ty);
}
g(&mut self, &mut |ty| match &ty {
Type::Literal(lit) => *ty = lit.value.general_class_type(stdlib).clone().to_type(),
Type::LiteralString(_) => *ty = stdlib.str().clone().to_type(),
_ => {}
});
self
}

// Attempt at a function that will convert @ to Any for now.
pub fn clean_var(self) -> Type {
self.transform(&mut |ty| match &ty {
Expand Down
11 changes: 11 additions & 0 deletions pyrefly/lib/alt/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
let elem_ty = self.expr_infer_with_hint_promote(
&x.elt,
elem_hint.as_ref().map(|hint| hint.as_ref()),
true,
errors,
);
self.heap.mk_class_type(self.stdlib.list(elem_ty))
Expand All @@ -504,6 +505,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
let elem_ty = self.expr_infer_with_hint_promote(
&x.elt,
elem_hint.as_ref().map(|hint| hint.as_ref()),
true,
errors,
);
self.heap.mk_class_type(self.stdlib.set(elem_ty))
Expand All @@ -515,11 +517,13 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
let key_ty = self.expr_infer_with_hint_promote(
&x.key,
key_hint.as_ref().map(|hint| hint.as_ref()),
true,
errors,
);
let value_ty = self.expr_infer_with_hint_promote(
&x.value,
value_hint.as_ref().map(|hint| hint.as_ref()),
true,
errors,
);
self.heap.mk_class_type(self.stdlib.dict(key_ty, value_ty))
Expand Down Expand Up @@ -648,13 +652,16 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
&self,
x: &Expr,
hint: Option<HintRef>,
promote_all_literals: bool,
errors: &ErrorCollector,
) -> Type {
let ty = self.expr_infer_with_hint(x, hint, errors);
if let Some(want) = hint
&& self.is_subset_eq(&ty, want.ty())
{
want.ty().clone()
} else if promote_all_literals {
ty.promote_all_literals(self.stdlib)
} else {
ty.promote_implicit_literals(self.stdlib)
}
Expand Down Expand Up @@ -973,11 +980,13 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
let key_t = self.expr_infer_with_hint_promote(
key,
key_hint.as_ref().map(|hint| hint.as_ref()),
false,
errors,
);
let value_t = self.expr_infer_with_hint_promote(
&x.value,
value_hint.as_ref().map(|hint| hint.as_ref()),
false,
errors,
);
if !key_t.is_error() {
Expand Down Expand Up @@ -1793,6 +1802,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
let unpacked_ty = self.expr_infer_with_hint_promote(
value,
star_hint.as_ref().map(|hint| hint.as_ref()),
false,
errors,
);
if let Some(iterable_ty) = self.unwrap_iterable(&unpacked_ty) {
Expand All @@ -1812,6 +1822,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
_ => self.expr_infer_with_hint_promote(
x,
elt_hint.as_ref().map(|hint| hint.as_ref()),
false,
errors,
),
})
Expand Down
58 changes: 50 additions & 8 deletions pyrefly/lib/alt/narrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use num_traits::ToPrimitive;
use pyrefly_config::error_kind::ErrorKind;
use pyrefly_graph::index::Idx;
use pyrefly_python::ast::Ast;
use pyrefly_python::dunder;
use pyrefly_types::class::Class;
use pyrefly_types::display::TypeDisplayContext;
use pyrefly_types::facet::FacetChain;
Expand Down Expand Up @@ -108,8 +109,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
}
}

fn is_final(&self, cls: &ClassType) -> bool {
let class = cls.class_object();
fn is_final(&self, class: &Class) -> bool {
self.get_metadata_for_class(class).is_final()
|| (self.get_enum_from_class(class).is_some()
&& !self.get_enum_members(class).is_empty())
Expand Down Expand Up @@ -145,7 +145,8 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
fallback
} else if let Type::ClassType(left_cls) = left
&& let Type::ClassType(right_cls) = right
&& (self.is_final(left_cls) || self.is_final(right_cls))
&& (self.is_final(left_cls.class_object())
|| self.is_final(right_cls.class_object()))
{
// The only way for `left & right` to exist is if it is an instance of a class that
// multiply inherits from both `left` and `right`'s classes. But at least one of
Expand Down Expand Up @@ -300,6 +301,31 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
self.intersects(&res)
}

/// Narrow `type(X) != Y`. We can only do negative narrowing if Y is final,
/// because otherwise X could still be a subclass of Y.
fn narrow_type_not_eq(&self, left: &Type, right_expr: &Expr, errors: &ErrorCollector) -> Type {
let right = self.expr_infer(right_expr, errors);
// Only narrow if the RHS is a final class type (e.g., `type(x) != bool`)
if let Type::ClassDef(cls) = &right
&& self.is_final(cls)
{
self.distribute_over_union(left, |l| {
if let Some((tparams, unwrapped)) = self.unwrap_class_object_silently(&right) {
let (vs, unwrapped) =
self.solver()
.fresh_quantified(&tparams, unwrapped, self.uniques);
let result = self.subtract(l, &unwrapped);
let _specialization_errors = self.solver().finish_quantified(vs, false);
result
} else {
l.clone()
}
})
} else {
left.clone()
}
}

/// Turn an expression into a list of (type, allows_negative_narrow) pairs.
/// allows_negative_narrow means that we can do `not isinstance`/`not issubclass` narrowing
/// with the type. We allow negative narrowing as long as it is not definitely unsafe - that
Expand Down Expand Up @@ -331,7 +357,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
if let Type::Type(box Type::ClassType(cls)) = &t {
// If `C` is not final, `type[C]` may be a subclass of `C`,
// making negative narrowing unsafe.
let allows_negative_narrow = me.is_final(cls);
let allows_negative_narrow = me.is_final(cls.class_object());
res.push((t, allows_negative_narrow));
} else {
for t in me.as_class_info(t) {
Expand Down Expand Up @@ -506,6 +532,21 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
range: TextRange,
errors: &ErrorCollector,
) -> Option<Type> {
// We narrow `X.__class__ == Y` the same way as `type(X) == Y`
if let FacetKind::Attribute(attr) = facet
&& *attr == dunder::CLASS
{
match op {
AtomicNarrowOp::Is(v) | AtomicNarrowOp::Eq(v) => {
let right = self.expr_infer(v, errors);
return Some(self.narrow_isinstance(base, &right));
}
AtomicNarrowOp::IsNot(v) | AtomicNarrowOp::NotEq(v) => {
return Some(self.narrow_type_not_eq(base, v, errors));
}
_ => {}
}
}
match op {
AtomicNarrowOp::Is(v) => {
let right = self.expr_infer(v, errors);
Expand Down Expand Up @@ -1013,13 +1054,14 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
}
AtomicNarrowOp::IsNotInstance(v, _source) => self.narrow_is_not_instance(ty, v, errors),
AtomicNarrowOp::TypeEq(v) => {
// If type(X) == Y then X can't be a subclass of Y
// If type(X) == Y then X has to be exactly Y, not a subclass of Y
// We can't model that, so we narrow it exactly like isinstance(X, Y)
let right = self.expr_infer(v, errors);
self.narrow_isinstance(ty, &right)
}
// Even if type(X) != Y, X can still be a subclass of Y so we can't do any negative refinement
AtomicNarrowOp::TypeNotEq(_) => ty.clone(),
// If type(X) != Y, X can still be a subclass of Y so we can't do negative refinement
// unless Y is final, in which case X cannot be a subclass of Y
AtomicNarrowOp::TypeNotEq(v) => self.narrow_type_not_eq(ty, v, errors),
AtomicNarrowOp::IsSubclass(v) => {
let right = self.expr_infer(v, errors);
self.narrow_issubclass(ty, &right, v.range(), errors)
Expand Down Expand Up @@ -1604,7 +1646,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
Type::ClassType(cls) | Type::SelfType(cls) => {
// Final classes can't have subclasses, so they are exhaustible, with the exception
// of Flag enums, whose members can be combined into new members via bitwise ops
!self.is_flag_enum(cls) && self.is_final(cls)
!self.is_flag_enum(cls) && self.is_final(cls.class_object())
// bool is effectively Literal[True] | Literal[False]
|| cls.is_builtin("bool")
}
Expand Down
22 changes: 22 additions & 0 deletions pyrefly/lib/test/attribute_narrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,28 @@ def f(foo: Foo):
"#,
);

testcase!(
test_dunder_class_attribute_narrow,
r#"
from typing import assert_type
def f(x: int | str):
if x.__class__ is int:
assert_type(x, int)
else:
assert_type(x, int | str)
if x.__class__ == int:
assert_type(x, int)
def g(x: int | str):
assert x.__class__ is int
assert_type(x, int)
def h(x: bool | str):
if x.__class__ is bool:
assert_type(x, bool)
else:
assert_type(x, str)
"#,
);

// The expected behavior when narrowing an invalid attribute chain is to produce
// type errors at the narrow site, but apply the narrowing downstream
// (motivation: being noisy downstream could be quite frustrating for gradually
Expand Down
42 changes: 42 additions & 0 deletions pyrefly/lib/test/literal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,48 @@ assert_type(x, list[LiteralString])
"#,
);

testcase!(
test_promote_literal_in_dict_comprehension,
r#"
from typing import assert_type
from string import ascii_uppercase

# LiteralString from iterator should be promoted to str in comprehensions
letter_to_index = {char: i for i, char in enumerate(ascii_uppercase)}
assert_type(letter_to_index, dict[str, int])

def encode(message: str) -> list[int]:
result = []
for letter in message:
result.append(letter_to_index[letter])
return result
"#,
);

testcase!(
test_promote_literal_in_list_comprehension,
r#"
from typing import assert_type
from string import ascii_lowercase

# LiteralString from iterator should be promoted to str in comprehensions
chars = [c for c in ascii_lowercase]
assert_type(chars, list[str])
"#,
);

testcase!(
test_promote_literal_in_set_comprehension,
r#"
from typing import assert_type
from string import ascii_lowercase

# LiteralString from iterator should be promoted to str in comprehensions
chars = {c for c in ascii_lowercase}
assert_type(chars, set[str])
"#,
);

testcase!(
test_literal_string_format,
r#"
Expand Down
18 changes: 18 additions & 0 deletions pyrefly/lib/test/narrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,24 @@ def foo(x: int | None) -> None:
"#,
);

testcase!(
test_type_not_eq_final,
r#"
from typing import assert_type
def f(x: str | int | bool):
# bool is final, so we can narrow it away
if type(x) != bool:
assert_type(x, str | int)
else:
assert_type(x, bool)
# str is not final, so we can't narrow it away (subclasses of str are possible)
if type(x) != str:
assert_type(x, str | int | bool)
else:
assert_type(x, str)
"#,
);

testcase!(
test_isinstance_union,
r#"
Expand Down