Skip to content

Commit 21f53c3

Browse files
asukaminato0721meta-codesync[bot]
authored andcommitted
fix Using a Literal key to update values in a TypeDict #329 (#1848)
Summary: Fixes #329 now threads a BindingsBuilder through NarrowOps, so expr_to_subjects can resolve data[key] by inspecting the binding that defines key. We added helper logic that looks up the original Binding (either a NameAssign or FunctionParameter) to pull its Literal annotation via the new BindingsBuilder::get_annotation_by_idx helper. The existing literal-subscript walkers were generalized to accept resolver callbacks, and the resolver backed by annotations is used everywhere in the binder. gained literal_string_key_from_expr, letting expression inference recover literal keys from either syntax or inferred type. subscript_infer now calls this helper so TypeInfo::at_facet is entered even when the index expression is a Literal-typed variable. switches Binding::AssignToSubscript over to the same resolver (via the newly exported _with_resolver helpers) so assignments such as data[key] = [] update the correct TypedDict facet even when key is a literal-typed variable. other assertion helpers were updated to pass the builder into from_single_narrow_op, ensuring unit-test helpers keep benefiting from the new literal detection. Pull Request resolved: #1848 Test Plan: add tests Reviewed By: rchen152 Differential Revision: D89935230 Pulled By: yangdanny97 fbshipit-source-id: 5eaa0070c4b6608d25f88cfa2db2e5df7f8486c9
1 parent 631f21e commit 21f53c3

File tree

7 files changed

+318
-112
lines changed

7 files changed

+318
-112
lines changed

crates/pyrefly_types/src/facet.rs

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use std::fmt;
99
use std::fmt::Display;
1010

1111
use pyrefly_derive::TypeEq;
12+
use ruff_python_ast::ExprName;
1213
use ruff_python_ast::name::Name;
1314
use vec1::Vec1;
1415

@@ -72,3 +73,51 @@ impl Display for FacetChain {
7273
Ok(())
7374
}
7475
}
76+
77+
// This is like `FacetKind`, but it can also represent subscripts that are arbitrary names with unknown types
78+
// `VariableSubscript` may resolve to a `FacetKind::Index`, `FacetKind::Key`, or nothing at all
79+
// depending on the type of the variable it contains
80+
#[derive(Debug, Clone, PartialEq)]
81+
pub enum UnresolvedFacetKind {
82+
Attribute(Name),
83+
Index(usize),
84+
Key(String),
85+
VariableSubscript(ExprName),
86+
}
87+
88+
impl Display for UnresolvedFacetKind {
89+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
90+
match self {
91+
Self::Attribute(name) => write!(f, ".{name}"),
92+
Self::Index(idx) => write!(f, "[{idx}]"),
93+
Self::Key(key) => write!(f, "[\"{key}\"]"),
94+
Self::VariableSubscript(name) => write!(f, "[{}]", name.id),
95+
}
96+
}
97+
}
98+
99+
// This is like `FacetChain`, but it can also represent subscripts that are arbitrary names with unknown types
100+
// It gets resolved to `FacetChain` if all names in the chain resolve to literal int or string types
101+
#[derive(Clone, Debug)]
102+
pub struct UnresolvedFacetChain(pub Box<Vec1<UnresolvedFacetKind>>);
103+
104+
impl UnresolvedFacetChain {
105+
pub fn new(chain: Vec1<UnresolvedFacetKind>) -> Self {
106+
Self(Box::new(chain))
107+
}
108+
109+
pub fn facets(&self) -> &Vec1<UnresolvedFacetKind> {
110+
match self {
111+
Self(chain) => chain,
112+
}
113+
}
114+
}
115+
116+
impl Display for UnresolvedFacetChain {
117+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
118+
for facet in self.0.iter() {
119+
write!(f, "{facet}")?;
120+
}
121+
Ok(())
122+
}
123+
}

pyrefly/lib/alt/expr.rs

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1245,21 +1245,31 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
12451245
range: TextRange,
12461246
errors: &ErrorCollector,
12471247
) -> TypeInfo {
1248-
match slice {
1249-
Expr::NumberLiteral(ExprNumberLiteral {
1250-
value: Number::Int(idx),
1251-
..
1252-
}) if let Some(idx) = idx.as_usize() => {
1253-
TypeInfo::at_facet(base, &FacetKind::Index(idx), || {
1254-
self.subscript_infer_for_type(base.ty(), slice, range, errors)
1255-
})
1256-
}
1257-
Expr::StringLiteral(ExprStringLiteral { value: key, .. }) => {
1258-
TypeInfo::at_facet(base, &FacetKind::Key(key.to_string()), || {
1259-
self.subscript_infer_for_type(base.ty(), slice, range, errors)
1260-
})
1248+
if let Expr::NumberLiteral(ExprNumberLiteral {
1249+
value: Number::Int(idx),
1250+
..
1251+
}) = slice
1252+
&& let Some(idx) = idx.as_usize()
1253+
{
1254+
TypeInfo::at_facet(base, &FacetKind::Index(idx), || {
1255+
self.subscript_infer_for_type(base.ty(), slice, range, errors)
1256+
})
1257+
} else if let Expr::StringLiteral(ExprStringLiteral { value, .. }) = slice {
1258+
TypeInfo::at_facet(base, &FacetKind::Key(value.to_string()), || {
1259+
self.subscript_infer_for_type(base.ty(), slice, range, errors)
1260+
})
1261+
} else {
1262+
let swallower = self.error_swallower();
1263+
match self.expr_infer(slice, &swallower) {
1264+
Type::Literal(Lit::Str(value)) => {
1265+
TypeInfo::at_facet(base, &FacetKind::Key(value.to_string()), || {
1266+
self.subscript_infer_for_type(base.ty(), slice, range, errors)
1267+
})
1268+
}
1269+
_ => {
1270+
TypeInfo::of_ty(self.subscript_infer_for_type(base.ty(), slice, range, errors))
1271+
}
12611272
}
1262-
_ => TypeInfo::of_ty(self.subscript_infer_for_type(base.ty(), slice, range, errors)),
12631273
}
12641274
}
12651275

pyrefly/lib/alt/narrow.rs

Lines changed: 106 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@ use pyrefly_graph::index::Idx;
1111
use pyrefly_python::ast::Ast;
1212
use pyrefly_types::class::Class;
1313
use pyrefly_types::display::TypeDisplayContext;
14+
use pyrefly_types::facet::FacetChain;
15+
use pyrefly_types::facet::FacetKind;
16+
use pyrefly_types::facet::UnresolvedFacetChain;
17+
use pyrefly_types::facet::UnresolvedFacetKind;
1418
use pyrefly_types::simplify::intersect;
1519
use pyrefly_types::type_info::JoinStyle;
1620
use pyrefly_util::prelude::SliceExt;
@@ -43,8 +47,6 @@ use crate::error::context::ErrorInfo;
4347
use crate::error::style::ErrorStyle;
4448
use crate::types::callable::FunctionKind;
4549
use crate::types::class::ClassType;
46-
use crate::types::facet::FacetChain;
47-
use crate::types::facet::FacetKind;
4850
use crate::types::lit_int::LitInt;
4951
use crate::types::literal::Lit;
5052
use crate::types::tuple::Tuple;
@@ -945,17 +947,19 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
945947
) -> TypeInfo {
946948
match op {
947949
NarrowOp::Atomic(subject, AtomicNarrowOp::HasKey(key)) => {
948-
let base_ty = match subject {
949-
Some(facet_subject) => {
950-
self.get_facet_chain_type(type_info, &facet_subject.chain, range)
951-
}
952-
None => self.force_for_narrowing(type_info.ty(), range, errors),
950+
let resolved_chain = subject
951+
.as_ref()
952+
.and_then(|s| self.resolve_facet_chain(s.chain.clone()));
953+
let base_ty = match (&subject, &resolved_chain) {
954+
(Some(_), Some(chain)) => self.get_facet_chain_type(type_info, chain, range),
955+
(Some(_), None) => return type_info.clone(),
956+
(None, _) => self.force_for_narrowing(type_info.ty(), range, errors),
953957
};
954958
if matches!(base_ty, Type::TypedDict(_)) {
955959
let key_facet = FacetKind::Key(key.to_string());
956-
let facets = match subject {
957-
Some(facet_subject) => {
958-
let mut new_facets = facet_subject.chain.facets().clone();
960+
let facets = match resolved_chain {
961+
Some(chain) => {
962+
let mut new_facets = chain.facets().clone();
959963
new_facets.push(key_facet);
960964
new_facets
961965
}
@@ -971,17 +975,19 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
971975
}
972976
}
973977
NarrowOp::Atomic(subject, AtomicNarrowOp::NotHasKey(key)) => {
974-
let base_ty = match subject {
975-
Some(facet_subject) => {
976-
self.get_facet_chain_type(type_info, &facet_subject.chain, range)
977-
}
978-
None => self.force_for_narrowing(type_info.ty(), range, errors),
978+
let resolved_chain = subject
979+
.as_ref()
980+
.and_then(|s| self.resolve_facet_chain(s.chain.clone()));
981+
let base_ty = match (&subject, &resolved_chain) {
982+
(Some(_), Some(chain)) => self.get_facet_chain_type(type_info, chain, range),
983+
(Some(_), None) => return type_info.clone(),
984+
(None, _) => self.force_for_narrowing(type_info.ty(), range, errors),
979985
};
980986
if matches!(base_ty, Type::TypedDict(_)) {
981987
let key_facet = FacetKind::Key(key.to_string());
982-
let facets = match subject {
983-
Some(facet_subject) => {
984-
let mut new_facets = facet_subject.chain.facets().clone();
988+
let facets = match resolved_chain {
989+
Some(chain) => {
990+
let mut new_facets = chain.facets().clone();
985991
new_facets.push(key_facet);
986992
new_facets
987993
}
@@ -996,18 +1002,20 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
9961002
}
9971003
}
9981004
NarrowOp::Atomic(subject, AtomicNarrowOp::HasAttr(attr)) => {
999-
let base_ty = match subject {
1000-
Some(facet_subject) => {
1001-
self.get_facet_chain_type(type_info, &facet_subject.chain, range)
1002-
}
1003-
None => self.force_for_narrowing(type_info.ty(), range, errors),
1005+
let resolved_chain = subject
1006+
.as_ref()
1007+
.and_then(|s| self.resolve_facet_chain(s.chain.clone()));
1008+
let base_ty = match (&subject, &resolved_chain) {
1009+
(Some(_), Some(chain)) => self.get_facet_chain_type(type_info, chain, range),
1010+
(Some(_), None) => return type_info.clone(),
1011+
(None, _) => self.force_for_narrowing(type_info.ty(), range, errors),
10041012
};
10051013
// We only narrow the attribute to `Any` if the attribute does not exist
10061014
if !self.has_attr(&base_ty, attr) {
10071015
let attr_facet = FacetKind::Attribute(attr.clone());
1008-
let facets = match subject {
1009-
Some(facet_subject) => {
1010-
let mut new_facets = facet_subject.chain.facets().clone();
1016+
let facets = match resolved_chain {
1017+
Some(chain) => {
1018+
let mut new_facets = chain.facets().clone();
10111019
new_facets.push(attr_facet);
10121020
new_facets
10131021
}
@@ -1028,18 +1036,20 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
10281036
if self.as_bool(&default_ty, range, &suppress_errors) != Some(false) {
10291037
return type_info.clone();
10301038
}
1031-
let base_ty = match subject {
1032-
Some(facet_subject) => {
1033-
self.get_facet_chain_type(type_info, &facet_subject.chain, range)
1034-
}
1035-
None => self.force_for_narrowing(type_info.ty(), range, errors),
1039+
let resolved_chain = subject
1040+
.as_ref()
1041+
.and_then(|s| self.resolve_facet_chain(s.chain.clone()));
1042+
let base_ty = match (&subject, &resolved_chain) {
1043+
(Some(_), Some(chain)) => self.get_facet_chain_type(type_info, chain, range),
1044+
(Some(_), None) => return type_info.clone(),
1045+
(None, _) => self.force_for_narrowing(type_info.ty(), range, errors),
10361046
};
10371047
let attr_ty =
10381048
self.attr_infer_for_type(&base_ty, attr, range, &suppress_errors, None);
10391049
let attr_facet = FacetKind::Attribute(attr.clone());
1040-
let facets = match subject {
1041-
Some(facet_subject) => {
1042-
let mut new_facets = facet_subject.chain.facets().clone();
1050+
let facets = match resolved_chain {
1051+
Some(chain) => {
1052+
let mut new_facets = chain.facets().clone();
10431053
new_facets.push(attr_facet);
10441054
new_facets
10451055
}
@@ -1070,22 +1080,26 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
10701080
type_info.clone().with_ty(ty)
10711081
}
10721082
NarrowOp::Atomic(Some(facet_subject), op) => {
1083+
let Some(resolved_chain) = self.resolve_facet_chain(facet_subject.chain.clone())
1084+
else {
1085+
return type_info.clone();
1086+
};
10731087
if facet_subject.origin == FacetOrigin::GetMethod
10741088
&& !self.supports_dict_get_subject(type_info, facet_subject, range)
10751089
{
10761090
return type_info.clone();
10771091
}
10781092
let ty = self.atomic_narrow(
1079-
&self.get_facet_chain_type(type_info, &facet_subject.chain, range),
1093+
&self.get_facet_chain_type(type_info, &resolved_chain, range),
10801094
op,
10811095
range,
10821096
errors,
10831097
);
1084-
let mut narrowed = type_info.with_narrow(facet_subject.chain.facets(), ty);
1098+
let mut narrowed = type_info.with_narrow(resolved_chain.facets(), ty);
10851099
// For certain types of narrows, we can also narrow the parent of the current subject
10861100
// If `.get()` on a dict or TypedDict is falsy, the key may not be present at all
10871101
// We should invalidate any existing narrows
1088-
if let Some((last, prefix)) = facet_subject.chain.facets().split_last() {
1102+
if let Some((last, prefix)) = resolved_chain.facets().split_last() {
10891103
match Vec1::try_from(prefix) {
10901104
Ok(prefix_facets) => {
10911105
let prefix_chain = FacetChain::new(prefix_facets);
@@ -1094,7 +1108,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
10941108
let dict_get_key_falsy = matches!(op, AtomicNarrowOp::IsFalsy)
10951109
&& matches!(last, FacetKind::Key(_));
10961110
if dict_get_key_falsy {
1097-
narrowed.update_for_assignment(facet_subject.chain.facets(), None);
1111+
narrowed.update_for_assignment(resolved_chain.facets(), None);
10981112
} else if let Some(narrowed_ty) =
10991113
self.atomic_narrow_for_facet(&base_ty, last, op, range, errors)
11001114
&& narrowed_ty != base_ty
@@ -1107,7 +1121,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
11071121
let dict_get_key_falsy = matches!(op, AtomicNarrowOp::IsFalsy)
11081122
&& matches!(last, FacetKind::Key(_));
11091123
if dict_get_key_falsy {
1110-
narrowed.update_for_assignment(facet_subject.chain.facets(), None);
1124+
narrowed.update_for_assignment(resolved_chain.facets(), None);
11111125
} else if let Some(narrowed_ty) =
11121126
self.atomic_narrow_for_facet(base_ty, last, op, range, errors)
11131127
&& narrowed_ty != *base_ty
@@ -1147,14 +1161,16 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
11471161
subject: &FacetSubject,
11481162
range: TextRange,
11491163
) -> bool {
1150-
let base_ty = if subject.chain.facets().len() == 1 {
1164+
let Some(resolved_chain) = self.resolve_facet_chain(subject.chain.clone()) else {
1165+
return false;
1166+
};
1167+
let base_ty = if resolved_chain.facets().len() == 1 {
11511168
type_info.ty().clone()
11521169
} else {
1153-
let prefix: Vec<_> = subject
1154-
.chain
1170+
let prefix: Vec<_> = resolved_chain
11551171
.facets()
11561172
.iter()
1157-
.take(subject.chain.facets().len() - 1)
1173+
.take(resolved_chain.facets().len() - 1)
11581174
.cloned()
11591175
.collect();
11601176
match Vec1::try_from_vec(prefix) {
@@ -1232,26 +1248,28 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
12321248
return;
12331249
}
12341250
let ignore_errors = self.error_swallower();
1235-
let narrowing_subject_info = match narrowing_subject {
1236-
NarrowingSubject::Name(_) => &subject_info,
1251+
// Get the narrowed type of the match subject when none of the cases match
1252+
let mut remaining_ty = match narrowing_subject {
1253+
NarrowingSubject::Name(_) => self
1254+
.narrow(&subject_info, op.as_ref(), *narrow_range, &ignore_errors)
1255+
.ty()
1256+
.clone(),
12371257
NarrowingSubject::Facets(_, facets) => {
1258+
let Some(resolved_chain) = self.resolve_facet_chain(facets.chain.clone()) else {
1259+
return;
1260+
};
12381261
// If the narrowing subject is the facet of some variable like `x.foo`,
12391262
// We need to make a `TypeInfo` rooted at `x` using the type of `x.foo`
12401263
let type_info = TypeInfo::of_ty(Type::any_implicit());
1241-
&type_info.with_narrow(facets.chain.facets(), subject_ty.clone())
1242-
}
1243-
};
1244-
// Get the narrowed type of the match subject when none of the cases match
1245-
let narrowed = self.narrow(
1246-
narrowing_subject_info,
1247-
op.as_ref(),
1248-
*narrow_range,
1249-
&ignore_errors,
1250-
);
1251-
let mut remaining_ty = match narrowing_subject {
1252-
NarrowingSubject::Name(_) => narrowed.ty().clone(),
1253-
NarrowingSubject::Facets(_, facets) => {
1254-
self.get_facet_chain_type(&narrowed, &facets.chain, *subject_range)
1264+
let narrowing_subject_info =
1265+
type_info.with_narrow(resolved_chain.facets(), subject_ty.clone());
1266+
let narrowed = self.narrow(
1267+
&narrowing_subject_info,
1268+
op.as_ref(),
1269+
*narrow_range,
1270+
&ignore_errors,
1271+
);
1272+
self.get_facet_chain_type(&narrowed, &resolved_chain, *subject_range)
12551273
}
12561274
};
12571275
self.expand_vars_mut(&mut remaining_ty);
@@ -1275,4 +1293,33 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
12751293
msg,
12761294
);
12771295
}
1296+
1297+
pub fn resolve_facet_chain(&self, unresolved: UnresolvedFacetChain) -> Option<FacetChain> {
1298+
let resolved: Option<Vec<FacetKind>> = unresolved
1299+
.facets()
1300+
.iter()
1301+
.map(|kind| self.resolve_facet_kind(kind.clone()))
1302+
.collect();
1303+
resolved.map(|facets| FacetChain::new(Vec1::try_from_vec(facets).unwrap()))
1304+
}
1305+
1306+
pub fn resolve_facet_kind(&self, unresolved: UnresolvedFacetKind) -> Option<FacetKind> {
1307+
match unresolved {
1308+
UnresolvedFacetKind::Attribute(name) => Some(FacetKind::Attribute(name)),
1309+
UnresolvedFacetKind::Index(idx) => Some(FacetKind::Index(idx)),
1310+
UnresolvedFacetKind::Key(key) => Some(FacetKind::Key(key)),
1311+
UnresolvedFacetKind::VariableSubscript(expr_name) => {
1312+
let suppress_errors = self.error_swallower();
1313+
let ty = self.expr_infer(&Expr::Name(expr_name), &suppress_errors);
1314+
match ty {
1315+
Type::Literal(Lit::Int(lit_int)) => lit_int
1316+
.as_i64()
1317+
.and_then(|i| i.to_usize())
1318+
.map(FacetKind::Index),
1319+
Type::Literal(Lit::Str(s)) => Some(FacetKind::Key(s.to_string())),
1320+
_ => None,
1321+
}
1322+
}
1323+
}
1324+
}
12781325
}

0 commit comments

Comments
 (0)