Skip to content

Commit ffde20a

Browse files
committed
Fix condition narrow
1 parent 162c6fe commit ffde20a

File tree

3 files changed

+160
-7
lines changed

3 files changed

+160
-7
lines changed
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
use std::ops::Deref;
2+
3+
use crate::{DbIndex, LuaType, get_real_type};
4+
5+
pub fn intersect_type(db: &DbIndex, source: LuaType, target: LuaType) -> LuaType {
6+
let real_type = get_real_type(db, &source).unwrap_or(&source);
7+
8+
match (&real_type, &target) {
9+
// ANY & T = T
10+
(LuaType::Any, _) => target.clone(),
11+
(_, LuaType::Any) => real_type.clone(),
12+
(LuaType::Never, _) => LuaType::Never,
13+
(_, LuaType::Never) => LuaType::Never,
14+
(LuaType::Unknown, _) => target,
15+
(_, LuaType::Unknown) => source,
16+
// int | int const
17+
(LuaType::Integer, LuaType::IntegerConst(i) | LuaType::DocIntegerConst(i)) => {
18+
LuaType::IntegerConst(*i)
19+
}
20+
(LuaType::IntegerConst(i) | LuaType::DocIntegerConst(i), LuaType::Integer) => {
21+
LuaType::IntegerConst(*i)
22+
}
23+
// float | float const
24+
(LuaType::Number, right) if right.is_number() => LuaType::Number,
25+
(left, LuaType::Number) if left.is_number() => LuaType::Number,
26+
// string | string const
27+
(LuaType::String, LuaType::StringConst(s) | LuaType::DocStringConst(s)) => {
28+
LuaType::StringConst(s.clone())
29+
}
30+
(LuaType::StringConst(s) | LuaType::DocStringConst(s), LuaType::String) => {
31+
LuaType::StringConst(s.clone())
32+
}
33+
// boolean | boolean const
34+
(LuaType::Boolean, LuaType::BooleanConst(b)) => LuaType::BooleanConst(*b),
35+
(LuaType::BooleanConst(b), LuaType::Boolean) => LuaType::BooleanConst(*b),
36+
(LuaType::BooleanConst(left), LuaType::BooleanConst(right)) => {
37+
if left == right {
38+
LuaType::BooleanConst(*left)
39+
} else {
40+
LuaType::Never
41+
}
42+
}
43+
// table | table const
44+
(LuaType::Table, LuaType::TableConst(t)) => LuaType::TableConst(t.clone()),
45+
(LuaType::TableConst(t), LuaType::Table) => LuaType::TableConst(t.clone()),
46+
// function | function const
47+
(LuaType::Function, LuaType::DocFunction(_) | LuaType::Signature(_)) => target.clone(),
48+
(LuaType::DocFunction(_) | LuaType::Signature(_), LuaType::Function) => real_type.clone(),
49+
// class references
50+
(LuaType::Ref(id1), LuaType::Ref(id2)) => {
51+
if id1 == id2 {
52+
source.clone()
53+
} else {
54+
LuaType::Never
55+
}
56+
}
57+
(LuaType::MultiLineUnion(left), right) => {
58+
let include = match right {
59+
LuaType::StringConst(v) => {
60+
left.get_unions().iter().any(|(t, _)| match (t, right) {
61+
(LuaType::DocStringConst(a), _) => a == v,
62+
_ => false,
63+
})
64+
}
65+
LuaType::IntegerConst(v) => {
66+
left.get_unions().iter().any(|(t, _)| match (t, right) {
67+
(LuaType::DocIntegerConst(a), _) => a == v,
68+
_ => false,
69+
})
70+
}
71+
_ => false,
72+
};
73+
74+
if include {
75+
return source;
76+
}
77+
LuaType::from_vec(vec![source, target])
78+
}
79+
// union ∩ non-union: (A | B) ∩ C = (A ∩ C) | (B ∩ C)
80+
(LuaType::Union(left), right) if !right.is_union() => {
81+
let left_types = left.deref().clone().into_vec();
82+
let mut result_types = Vec::new();
83+
84+
for left_type in left_types {
85+
let intersected = intersect_type(db, left_type, right.clone());
86+
if !matches!(intersected, LuaType::Never) {
87+
result_types.push(intersected);
88+
}
89+
}
90+
91+
if result_types.is_empty() {
92+
LuaType::Never
93+
} else {
94+
LuaType::from_vec(result_types)
95+
}
96+
}
97+
// non-union ∩ union: A ∩ (B | C) = (A ∩ B) | (A ∩ C)
98+
(left, LuaType::Union(right)) if !left.is_union() => {
99+
let right_types = right.deref().clone().into_vec();
100+
let mut result_types = Vec::new();
101+
102+
for right_type in right_types {
103+
let intersected = intersect_type(db, real_type.clone(), right_type);
104+
if !matches!(intersected, LuaType::Never) {
105+
result_types.push(intersected);
106+
}
107+
}
108+
109+
if result_types.is_empty() {
110+
LuaType::Never
111+
} else {
112+
LuaType::from_vec(result_types)
113+
}
114+
}
115+
// union ∩ union: (A | B) ∩ (C | D) = (A ∩ C) | (A ∩ D) | (B ∩ C) | (B ∩ D)
116+
(LuaType::Union(left), LuaType::Union(right)) => {
117+
let left_types = left.deref().clone().into_vec();
118+
let right_types = right.deref().clone().into_vec();
119+
let mut result_types = Vec::new();
120+
121+
for left_type in left_types {
122+
for right_type in &right_types {
123+
let intersected = intersect_type(db, left_type.clone(), right_type.clone());
124+
if !matches!(intersected, LuaType::Never) {
125+
result_types.push(intersected);
126+
}
127+
}
128+
}
129+
130+
if result_types.is_empty() {
131+
LuaType::Never
132+
} else {
133+
LuaType::from_vec(result_types)
134+
}
135+
}
136+
137+
// same type
138+
(left, right) if *left == right => source.clone(),
139+
_ => LuaType::Never,
140+
}
141+
}

crates/emmylua_code_analysis/src/db_index/type/type_ops/mod.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
1+
mod intersect_type;
12
mod remove_type;
23
mod test;
34
mod union_type;
45

5-
use crate::DbIndex;
6-
76
use super::LuaType;
7+
use crate::DbIndex;
88

99
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
1010
pub enum TypeOps {
1111
/// Add a type to the source type
1212
Union,
13+
/// Intersect a type with the source type
14+
Intersect,
1315
/// Remove a type from the source type
1416
Remove,
1517
}
@@ -18,6 +20,9 @@ impl TypeOps {
1820
pub fn apply(&self, db: &DbIndex, source: &LuaType, target: &LuaType) -> LuaType {
1921
match self {
2022
TypeOps::Union => union_type::union_type(db, source.clone(), target.clone()),
23+
TypeOps::Intersect => {
24+
intersect_type::intersect_type(db, source.clone(), target.clone())
25+
}
2126
TypeOps::Remove => {
2227
let result = remove_type::remove_type(db, source.clone(), target.clone());
2328
if let Some(result) = result {

crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/binary_flow.rs

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,9 @@ fn maybe_var_eq_narrow(
362362
return Ok(ResultTypeOrContinue::Continue);
363363
}
364364

365+
let antecedent_flow_id = get_single_antecedent(tree, flow_node)?;
366+
let left_type =
367+
get_type_at_flow(db, tree, cache, root, var_ref_id, antecedent_flow_id)?;
365368
let right_expr_type = infer_expr(db, cache, right_expr)?;
366369

367370
let result_type = match condition_flow {
@@ -370,14 +373,18 @@ fn maybe_var_eq_narrow(
370373
if var_ref_id.is_self_ref() && !right_expr_type.is_nil() {
371374
TypeOps::Remove.apply(db, &right_expr_type, &LuaType::Nil)
372375
} else {
373-
right_expr_type
376+
let left_maybe_type =
377+
TypeOps::Intersect.apply(db, &left_type, &right_expr_type);
378+
379+
if left_maybe_type.is_never() {
380+
left_type
381+
} else {
382+
left_maybe_type
383+
}
374384
}
375385
}
376386
InferConditionFlow::FalseCondition => {
377-
let antecedent_flow_id = get_single_antecedent(tree, flow_node)?;
378-
let antecedent_type =
379-
get_type_at_flow(db, tree, cache, root, var_ref_id, antecedent_flow_id)?;
380-
TypeOps::Remove.apply(db, &antecedent_type, &right_expr_type)
387+
TypeOps::Remove.apply(db, &left_type, &right_expr_type)
381388
}
382389
};
383390
Ok(ResultTypeOrContinue::Result(result_type))

0 commit comments

Comments
 (0)