Skip to content

Commit 2984f1f

Browse files
committed
diagnostic UndefinedField 如果位于判断语句条件中则跳过
1 parent 0f81997 commit 2984f1f

File tree

2 files changed

+110
-1
lines changed

2 files changed

+110
-1
lines changed

crates/emmylua_code_analysis/src/diagnostic/checker/check_field.rs

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
use std::collections::HashSet;
22

3-
use emmylua_parser::{LuaAst, LuaAstNode, LuaIndexExpr, LuaIndexKey, LuaVarExpr};
3+
use emmylua_parser::{
4+
LuaAst, LuaAstNode, LuaElseIfClauseStat, LuaForRangeStat, LuaForStat, LuaIfStat, LuaIndexExpr,
5+
LuaIndexKey, LuaRepeatStat, LuaSyntaxKind, LuaVarExpr, LuaWhileStat,
6+
};
47

58
use crate::{DiagnosticCode, InferFailReason, LuaMemberKey, LuaType, SemanticModel};
69

@@ -64,6 +67,17 @@ fn check_index_expr(
6467

6568
let index_key = index_expr.get_index_key()?;
6669

70+
// 检查是否为判断语句
71+
if matches!(code, DiagnosticCode::UndefinedField) {
72+
if is_in_conditional_statement(index_expr) {
73+
return Some(());
74+
}
75+
}
76+
77+
if is_in_conditional_statement(index_expr) {
78+
return Some(());
79+
}
80+
6781
if is_valid_member(semantic_model, &prefix_typ, index_expr, &index_key, code).is_some() {
6882
return Some(());
6983
}
@@ -324,3 +338,84 @@ fn get_key_types(typ: &LuaType) -> HashSet<LuaType> {
324338
}
325339
type_set
326340
}
341+
342+
/// 判断给定的AST节点是否位于判断语句的条件表达式中
343+
///
344+
/// 该函数检查节点是否位于以下语句的条件部分:
345+
/// - if语句的条件表达式
346+
/// - while循环的条件表达式
347+
/// - for循环的迭代表达式
348+
/// - repeat循环的条件表达式
349+
/// - elseif子句的条件表达式
350+
///
351+
/// # 参数
352+
/// * `node` - 要检查的AST节点
353+
///
354+
/// # 返回值
355+
/// * `true` - 如果节点位于判断语句的条件表达式中
356+
/// * `false` - 如果节点不在判断语句的条件表达式中
357+
fn is_in_conditional_statement<T: LuaAstNode>(node: &T) -> bool {
358+
let node_range = node.get_range();
359+
360+
// 遍历所有祖先节点,查找条件语句
361+
for ancestor in node.syntax().ancestors() {
362+
match ancestor.kind().into() {
363+
LuaSyntaxKind::IfStat => {
364+
if let Some(if_stat) = LuaIfStat::cast(ancestor) {
365+
if let Some(condition_expr) = if_stat.get_condition_expr() {
366+
if condition_expr.get_range().contains_range(node_range) {
367+
return true;
368+
}
369+
}
370+
}
371+
}
372+
LuaSyntaxKind::WhileStat => {
373+
if let Some(while_stat) = LuaWhileStat::cast(ancestor) {
374+
if let Some(condition_expr) = while_stat.get_condition_expr() {
375+
if condition_expr.get_range().contains_range(node_range) {
376+
return true;
377+
}
378+
}
379+
}
380+
}
381+
LuaSyntaxKind::ForStat => {
382+
if let Some(for_stat) = LuaForStat::cast(ancestor) {
383+
for iter_expr in for_stat.get_iter_expr() {
384+
if iter_expr.get_range().contains_range(node_range) {
385+
return true;
386+
}
387+
}
388+
}
389+
}
390+
LuaSyntaxKind::ForRangeStat => {
391+
if let Some(for_range_stat) = LuaForRangeStat::cast(ancestor) {
392+
for expr in for_range_stat.get_expr_list() {
393+
if expr.get_range().contains_range(node_range) {
394+
return true;
395+
}
396+
}
397+
}
398+
}
399+
LuaSyntaxKind::RepeatStat => {
400+
if let Some(repeat_stat) = LuaRepeatStat::cast(ancestor) {
401+
if let Some(condition_expr) = repeat_stat.get_condition_expr() {
402+
if condition_expr.get_range().contains_range(node_range) {
403+
return true;
404+
}
405+
}
406+
}
407+
}
408+
LuaSyntaxKind::ElseIfClauseStat => {
409+
if let Some(elseif_clause) = LuaElseIfClauseStat::cast(ancestor) {
410+
if let Some(condition_expr) = elseif_clause.get_condition_expr() {
411+
if condition_expr.get_range().contains_range(node_range) {
412+
return true;
413+
}
414+
}
415+
}
416+
}
417+
_ => {}
418+
}
419+
}
420+
false
421+
}

crates/emmylua_code_analysis/src/diagnostic/test/undefined_field_test.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -598,4 +598,18 @@ mod test {
598598
"#
599599
));
600600
}
601+
602+
#[test]
603+
fn test_if_1() {
604+
let mut ws = VirtualWorkspace::new();
605+
assert!(ws.check_code_for(
606+
DiagnosticCode::UndefinedField,
607+
r#"
608+
---@type table<int, string>
609+
local arg = {}
610+
if arg['test'] == 'true' then
611+
end
612+
"#
613+
));
614+
}
601615
}

0 commit comments

Comments
 (0)