Skip to content

Commit 5f6351c

Browse files
committed
support inherit flow infer from local const
1 parent da17507 commit 5f6351c

File tree

6 files changed

+162
-18
lines changed

6 files changed

+162
-18
lines changed

crates/emmylua_code_analysis/src/compilation/analyzer/flow/bind_analyze/stats.rs

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use emmylua_parser::{
2-
LuaAssignStat, LuaAst, LuaAstNode, LuaBlock, LuaBreakStat, LuaCallArgList, LuaCallExprStat,
3-
LuaDoStat, LuaForRangeStat, LuaForStat, LuaFuncStat, LuaGotoStat, LuaIfStat, LuaLabelStat,
4-
LuaLocalStat, LuaRepeatStat, LuaReturnStat, LuaWhileStat,
2+
BinaryOperator, LuaAssignStat, LuaAst, LuaAstNode, LuaBlock, LuaBreakStat, LuaCallArgList,
3+
LuaCallExprStat, LuaDoStat, LuaExpr, LuaForRangeStat, LuaForStat, LuaFuncStat, LuaGotoStat,
4+
LuaIfStat, LuaLabelStat, LuaLocalStat, LuaRepeatStat, LuaReturnStat, LuaWhileStat,
55
};
66

77
use crate::{
@@ -28,8 +28,9 @@ pub fn bind_local_stat(
2828
let name = &local_names[i];
2929
let value = &values[i];
3030
let decl_id = LuaDeclId::new(binder.file_id, name.get_position());
31-
let flow_id = bind_expr(binder, value.clone(), current);
32-
binder.decl_bind_flow_ref.insert(decl_id, flow_id);
31+
if check_local_immutable(binder, decl_id) && check_value_expr_is_check_expr(value.clone()) {
32+
binder.decl_bind_expr_ref.insert(decl_id, value.to_ptr());
33+
}
3334
}
3435

3536
for value in values {
@@ -42,6 +43,41 @@ pub fn bind_local_stat(
4243
local_flow_id
4344
}
4445

46+
fn check_local_immutable(binder: &mut FlowBinder, decl_id: LuaDeclId) -> bool {
47+
let Some(refs) = binder
48+
.db
49+
.get_reference_index()
50+
.get_decl_references(&binder.file_id, &decl_id)
51+
else {
52+
return true;
53+
};
54+
55+
for r in refs {
56+
if r.is_write {
57+
return false;
58+
}
59+
}
60+
61+
true
62+
}
63+
64+
fn check_value_expr_is_check_expr(value_expr: LuaExpr) -> bool {
65+
match value_expr {
66+
LuaExpr::BinaryExpr(binary_expr) => {
67+
let Some(op) = binary_expr.get_op_token() else {
68+
return false;
69+
};
70+
71+
match op.get_op() {
72+
BinaryOperator::OpEq | BinaryOperator::OpNe => true,
73+
_ => false,
74+
}
75+
}
76+
LuaExpr::CallExpr(_) => true,
77+
_ => false, // Other expressions can be checked
78+
}
79+
}
80+
4581
pub fn bind_assign_stat(
4682
binder: &mut FlowBinder,
4783
assign_stat: LuaAssignStat,

crates/emmylua_code_analysis/src/compilation/analyzer/flow/binder.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::collections::HashMap;
22

3-
use emmylua_parser::LuaSyntaxId;
3+
use emmylua_parser::{LuaAstPtr, LuaExpr, LuaSyntaxId};
44
use internment::ArcIntern;
55
use rowan::TextSize;
66
use smol_str::SmolStr;
@@ -14,7 +14,7 @@ use crate::{
1414
pub struct FlowBinder<'a> {
1515
pub db: &'a mut DbIndex,
1616
pub file_id: FileId,
17-
pub decl_bind_flow_ref: HashMap<LuaDeclId, FlowId>,
17+
pub decl_bind_expr_ref: HashMap<LuaDeclId, LuaAstPtr<LuaExpr>>,
1818
pub start: FlowId,
1919
pub unreachable: FlowId,
2020
pub loop_label: FlowId,
@@ -35,7 +35,7 @@ impl<'a> FlowBinder<'a> {
3535
file_id,
3636
flow_nodes: Vec::new(),
3737
multiple_antecedents: Vec::new(),
38-
decl_bind_flow_ref: HashMap::new(),
38+
decl_bind_expr_ref: HashMap::new(),
3939
labels: HashMap::new(),
4040
start: FlowId::default(),
4141
unreachable: FlowId::default(),
@@ -171,7 +171,7 @@ impl<'a> FlowBinder<'a> {
171171

172172
pub fn finish(self) -> FlowTree {
173173
FlowTree::new(
174-
self.decl_bind_flow_ref,
174+
self.decl_bind_expr_ref,
175175
self.flow_nodes,
176176
self.multiple_antecedents,
177177
// self.labels,

crates/emmylua_code_analysis/src/compilation/test/flow.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -891,4 +891,32 @@ end
891891
"#,
892892
);
893893
}
894+
895+
#[test]
896+
fn test_feature_inherit_flow_from_const_local() {
897+
let mut ws = VirtualWorkspace::new_with_init_std_lib();
898+
899+
ws.def(
900+
r#"
901+
local ret --- @type string | nil
902+
903+
local h = type(ret) == "string"
904+
if h then
905+
a = ret
906+
end
907+
908+
local e = type(ret)
909+
if e == "string" then
910+
b = ret
911+
end
912+
"#,
913+
);
914+
915+
let a = ws.expr_ty("a");
916+
let a_expected = ws.ty("string");
917+
assert_eq!(a, a_expected);
918+
let b = ws.expr_ty("b");
919+
let b_expected = ws.ty("string");
920+
assert_eq!(b, b_expected);
921+
}
894922
}

crates/emmylua_code_analysis/src/db_index/flow/flow_tree.rs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
use std::collections::HashMap;
22

3-
use emmylua_parser::LuaSyntaxId;
3+
use emmylua_parser::{LuaAstPtr, LuaExpr, LuaSyntaxId};
44

55
use crate::{FlowId, FlowNode, LuaDeclId};
66

77
#[derive(Debug)]
88
pub struct FlowTree {
9-
#[allow(unused)]
10-
decl_bind_flow_ref: HashMap<LuaDeclId, FlowId>,
9+
decl_bind_expr_ref: HashMap<LuaDeclId, LuaAstPtr<LuaExpr>>,
1110
flow_nodes: Vec<FlowNode>,
1211
multiple_antecedents: Vec<Vec<FlowId>>,
1312
// labels: HashMap<LuaClosureId, HashMap<SmolStr, FlowId>>,
@@ -16,17 +15,16 @@ pub struct FlowTree {
1615

1716
impl FlowTree {
1817
pub fn new(
19-
decl_bind_flow_ref: HashMap<LuaDeclId, FlowId>,
18+
decl_bind_expr_ref: HashMap<LuaDeclId, LuaAstPtr<LuaExpr>>,
2019
flow_nodes: Vec<FlowNode>,
2120
multiple_antecedents: Vec<Vec<FlowId>>,
2221
// labels: HashMap<LuaClosureId, HashMap<SmolStr, FlowId>>,
2322
bindings: HashMap<LuaSyntaxId, FlowId>,
2423
) -> Self {
2524
Self {
26-
decl_bind_flow_ref,
25+
decl_bind_expr_ref,
2726
flow_nodes,
2827
multiple_antecedents,
29-
// labels,
3028
bindings,
3129
}
3230
}
@@ -44,4 +42,8 @@ impl FlowTree {
4442
.get(id as usize)
4543
.map(|v| v.as_slice())
4644
}
45+
46+
pub fn get_decl_ref_expr(&self, decl_id: &LuaDeclId) -> Option<LuaAstPtr<LuaExpr>> {
47+
self.decl_bind_expr_ref.get(decl_id).cloned()
48+
}
4749
}

crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/binary_flow.rs

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use emmylua_parser::{
2-
BinaryOperator, LuaBinaryExpr, LuaCallExpr, LuaChunk, LuaExpr, LuaIndexMemberExpr,
2+
BinaryOperator, LuaAstNode, LuaBinaryExpr, LuaCallExpr, LuaChunk, LuaExpr, LuaIndexMemberExpr,
33
LuaLiteralToken,
44
};
55

@@ -128,6 +128,38 @@ fn maybe_type_guard_binary(
128128
}
129129
}
130130
}
131+
// may ref a type value
132+
} else if let LuaExpr::NameExpr(name_expr) = left_expr {
133+
if let LuaExpr::LiteralExpr(literal_expr) = right_expr {
134+
let Some(decl_id) = db
135+
.get_reference_index()
136+
.get_var_reference_decl(&cache.get_file_id(), name_expr.get_range())
137+
else {
138+
return Ok(ResultTypeOrContinue::Continue);
139+
};
140+
141+
let Some(expr_ptr) = tree.get_decl_ref_expr(&decl_id) else {
142+
return Ok(ResultTypeOrContinue::Continue);
143+
};
144+
145+
let Some(expr) = expr_ptr.to_node(root) else {
146+
return Ok(ResultTypeOrContinue::Continue);
147+
};
148+
149+
if let LuaExpr::CallExpr(call_expr) = expr {
150+
if call_expr.is_type() {
151+
type_guard_expr = Some(call_expr);
152+
match literal_expr.get_literal() {
153+
Some(LuaLiteralToken::String(s)) => {
154+
literal_string = s.get_value();
155+
}
156+
_ => return Ok(ResultTypeOrContinue::Continue),
157+
}
158+
}
159+
} else {
160+
return Ok(ResultTypeOrContinue::Continue);
161+
}
162+
}
131163
}
132164

133165
if type_guard_expr.is_none() || literal_string.is_empty() {

crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/mod.rs

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ mod binary_flow;
22
mod call_flow;
33
mod index_flow;
44

5-
use emmylua_parser::{LuaChunk, LuaExpr, LuaNameExpr, LuaUnaryExpr, UnaryOperator};
5+
use emmylua_parser::{LuaAstNode, LuaChunk, LuaExpr, LuaNameExpr, LuaUnaryExpr, UnaryOperator};
66

77
use crate::{
88
semantic::infer::{
@@ -146,7 +146,16 @@ fn get_type_at_name_expr(
146146
};
147147

148148
if name_var_ref_id != *var_ref_id {
149-
return Ok(ResultTypeOrContinue::Continue);
149+
return get_type_at_name_ref(
150+
db,
151+
tree,
152+
cache,
153+
root,
154+
var_ref_id,
155+
flow_node,
156+
name_expr,
157+
condition_flow,
158+
);
150159
}
151160

152161
let antecedent_flow_id = get_single_antecedent(tree, flow_node)?;
@@ -160,6 +169,43 @@ fn get_type_at_name_expr(
160169
Ok(ResultTypeOrContinue::Result(result_type))
161170
}
162171

172+
fn get_type_at_name_ref(
173+
db: &DbIndex,
174+
tree: &FlowTree,
175+
cache: &mut LuaInferCache,
176+
root: &LuaChunk,
177+
var_ref_id: &VarRefId,
178+
flow_node: &FlowNode,
179+
name_expr: LuaNameExpr,
180+
condition_flow: InferConditionFlow,
181+
) -> Result<ResultTypeOrContinue, InferFailReason> {
182+
let Some(decl_id) = db
183+
.get_reference_index()
184+
.get_var_reference_decl(&cache.get_file_id(), name_expr.get_range())
185+
else {
186+
return Ok(ResultTypeOrContinue::Continue);
187+
};
188+
189+
let Some(expr_ptr) = tree.get_decl_ref_expr(&decl_id) else {
190+
return Ok(ResultTypeOrContinue::Continue);
191+
};
192+
193+
let Some(expr) = expr_ptr.to_node(root) else {
194+
return Ok(ResultTypeOrContinue::Continue);
195+
};
196+
197+
get_type_at_condition_flow(
198+
db,
199+
tree,
200+
cache,
201+
root,
202+
var_ref_id,
203+
flow_node,
204+
expr,
205+
condition_flow,
206+
)
207+
}
208+
163209
fn get_type_at_unary_flow(
164210
db: &DbIndex,
165211
tree: &FlowTree,

0 commit comments

Comments
 (0)