Skip to content

Commit 2eb92e8

Browse files
committed
Support complex assert narrow
Fix #583
1 parent 5e4163e commit 2eb92e8

File tree

4 files changed

+47
-65
lines changed

4 files changed

+47
-65
lines changed

crates/emmylua_code_analysis/src/compilation/analyzer/flow/bind_analyze/stats.rs

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use emmylua_parser::{
2-
LuaAssignStat, LuaAst, LuaAstNode, LuaBlock, LuaBreakStat, LuaCallExprStat, LuaDoStat,
3-
LuaForRangeStat, LuaForStat, LuaFuncStat, LuaGotoStat, LuaIfStat, LuaLabelStat, LuaLocalStat,
4-
LuaRepeatStat, LuaReturnStat, LuaWhileStat,
2+
LuaAssignStat, LuaAst, LuaAstNode, LuaBlock, LuaBreakStat, LuaCallArgList, LuaCallExprStat,
3+
LuaDoStat, LuaForRangeStat, LuaForStat, LuaFuncStat, LuaGotoStat, LuaIfStat, LuaLabelStat,
4+
LuaLocalStat, LuaRepeatStat, LuaReturnStat, LuaWhileStat,
55
};
66

77
use crate::{
@@ -78,23 +78,37 @@ pub fn bind_call_expr_stat(
7878
None => return current, // If there's no call expression, just return the current flow
7979
};
8080

81-
if let Some(ast) = LuaAst::cast(call_expr.syntax().clone()) {
82-
bind_each_child(binder, ast, current);
83-
}
84-
8581
if call_expr.is_assert() {
86-
let assert_flow_id = binder.create_node(FlowNodeKind::AssertCall(call_expr.to_ptr()));
87-
binder.add_antecedent(assert_flow_id, current);
88-
assert_flow_id
82+
let Some(arg_list) = call_expr.get_args_list() else {
83+
return current; // If there's no argument list, just return the current flow
84+
};
85+
86+
bind_assert_stat(binder, arg_list, current)
8987
} else if call_expr.is_error() {
9088
let return_flow_id = binder.create_return();
9189
binder.add_antecedent(return_flow_id, current);
9290
return_flow_id
9391
} else {
92+
if let Some(ast) = LuaAst::cast(call_expr.syntax().clone()) {
93+
bind_each_child(binder, ast, current);
94+
}
9495
current
9596
}
9697
}
9798

99+
fn bind_assert_stat(binder: &mut FlowBinder, arg_list: LuaCallArgList, current: FlowId) -> FlowId {
100+
let false_target = binder.unreachable;
101+
102+
let mut pre_arg = current;
103+
for arg in arg_list.get_args() {
104+
let pre_next_arg = binder.create_branch_label();
105+
bind_condition_expr(binder, arg, pre_arg, pre_next_arg, false_target);
106+
pre_arg = finish_flow_label(binder, pre_next_arg, pre_arg);
107+
}
108+
109+
pre_arg
110+
}
111+
98112
pub fn bind_label_stat(
99113
binder: &mut FlowBinder,
100114
label_stat: LuaLabelStat,

crates/emmylua_code_analysis/src/compilation/test/flow.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -850,4 +850,23 @@ end
850850
let b_expected = ws.ty("B");
851851
assert_eq!(b, b_expected);
852852
}
853+
854+
#[test]
855+
fn test_issue_583() {
856+
let mut ws = VirtualWorkspace::new_with_init_std_lib();
857+
858+
ws.check_code_for(
859+
DiagnosticCode::AssignTypeMismatch,
860+
r#"
861+
--- @param sha string
862+
local function get_hash_color(sha)
863+
local r, g, b = sha:match('(%x)%x(%x)%x(%x)')
864+
assert(r and g and b, 'Invalid hash color')
865+
local _ = r --- @type string
866+
local _ = g --- @type string
867+
local _ = b --- @type string
868+
end
869+
"#,
870+
);
871+
}
853872
}

crates/emmylua_code_analysis/src/db_index/flow/flow_node.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use emmylua_parser::{
2-
LuaAssignStat, LuaAstNode, LuaAstPtr, LuaCallExpr, LuaChunk, LuaClosureExpr, LuaDocTagCast,
3-
LuaExpr, LuaForStat, LuaSyntaxKind, LuaSyntaxNode,
2+
LuaAssignStat, LuaAstNode, LuaAstPtr, LuaChunk, LuaClosureExpr, LuaDocTagCast, LuaExpr,
3+
LuaForStat, LuaSyntaxKind, LuaSyntaxNode,
44
};
55
use internment::ArcIntern;
66
use rowan::{TextRange, TextSize};
@@ -52,8 +52,6 @@ pub enum FlowNodeKind {
5252
ForIStat(LuaAstPtr<LuaForStat>),
5353
/// Tag cast comment
5454
TagCast(LuaAstPtr<LuaDocTagCast>),
55-
/// Assert call
56-
AssertCall(LuaAstPtr<LuaCallExpr>),
5755
/// Break statement
5856
Break,
5957
/// Return statement

crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_flow.rs

Lines changed: 2 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use emmylua_parser::{LuaAssignStat, LuaAstNode, LuaCallExpr, LuaChunk, LuaVarExpr};
1+
use emmylua_parser::{LuaAssignStat, LuaAstNode, LuaChunk, LuaVarExpr};
22

33
use crate::{
44
infer_expr,
@@ -8,7 +8,7 @@ use crate::{
88
get_multi_antecedents, get_single_antecedent,
99
get_type_at_cast_flow::get_type_at_cast_flow,
1010
get_var_ref_type,
11-
narrow_type::{narrow_down_type, remove_false_or_nil},
11+
narrow_type::narrow_down_type,
1212
var_ref_id::get_var_expr_var_ref_id,
1313
ResultTypeOrContinue,
1414
},
@@ -138,25 +138,6 @@ pub fn get_type_at_flow(
138138
antecedent_flow_id = get_single_antecedent(tree, flow_node)?;
139139
}
140140
}
141-
FlowNodeKind::AssertCall(lua_ast_ptr) => {
142-
let assert_call = lua_ast_ptr.to_node(root).ok_or(InferFailReason::None)?;
143-
let result_or_continue = get_type_at_assert_call(
144-
db,
145-
tree,
146-
cache,
147-
root,
148-
var_ref_id,
149-
flow_node,
150-
assert_call,
151-
)?;
152-
153-
if let ResultTypeOrContinue::Result(assert_type) = result_or_continue {
154-
result_type = assert_type;
155-
break;
156-
} else {
157-
antecedent_flow_id = get_single_antecedent(tree, flow_node)?;
158-
}
159-
}
160141
}
161142
}
162143

@@ -256,33 +237,3 @@ fn get_type_at_assign_stat(
256237

257238
Ok(ResultTypeOrContinue::Continue)
258239
}
259-
260-
fn get_type_at_assert_call(
261-
db: &DbIndex,
262-
tree: &FlowTree,
263-
cache: &mut LuaInferCache,
264-
root: &LuaChunk,
265-
var_ref_id: &VarRefId,
266-
flow_node: &FlowNode,
267-
assert_call: LuaCallExpr,
268-
) -> Result<ResultTypeOrContinue, InferFailReason> {
269-
let call_arg_list = match assert_call.get_args_list() {
270-
Some(args) => args,
271-
None => return Ok(ResultTypeOrContinue::Continue),
272-
};
273-
274-
for arg in call_arg_list.get_args() {
275-
if let Some(ref_decl_id) = get_var_expr_var_ref_id(db, cache, arg.clone()) {
276-
if ref_decl_id == *var_ref_id {
277-
let antecedent_flow_id = get_single_antecedent(tree, flow_node)?;
278-
let antecedent_type =
279-
get_type_at_flow(db, tree, cache, root, var_ref_id, antecedent_flow_id)?;
280-
let result_type = remove_false_or_nil(antecedent_type);
281-
282-
return Ok(ResultTypeOrContinue::Result(result_type));
283-
}
284-
}
285-
}
286-
287-
Ok(ResultTypeOrContinue::Continue)
288-
}

0 commit comments

Comments
 (0)