Skip to content

Commit b890da1

Browse files
committed
fix #598
1 parent de59988 commit b890da1

File tree

3 files changed

+228
-4
lines changed

3 files changed

+228
-4
lines changed

crates/emmylua_code_analysis/src/compilation/test/flow.rs

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -945,4 +945,60 @@ end
945945
let a_expected = ws.ty("string");
946946
assert_eq!(a, a_expected);
947947
}
948+
949+
#[test]
950+
fn test_issue_598() {
951+
let mut ws = VirtualWorkspace::new_with_init_std_lib();
952+
ws.def(
953+
r#"
954+
---@class A<T>
955+
A = {}
956+
---@class IDisposable
957+
---@class B<T>: IDisposable
958+
959+
---@class AnonymousObserver<T>: IDisposable
960+
961+
---@generic T
962+
---@return AnonymousObserver<T>
963+
function createAnonymousObserver()
964+
end
965+
"#,
966+
);
967+
assert!(ws.check_code_for(
968+
DiagnosticCode::ReturnTypeMismatch,
969+
r#"
970+
---@param observer fun(value: T) | B<T>
971+
---@return IDisposable
972+
function A:subscribe(observer)
973+
local typ = type(observer)
974+
if typ == 'function' then
975+
---@cast observer fun(value: T)
976+
observer = createAnonymousObserver()
977+
elseif typ == 'table' then
978+
---@cast observer -function
979+
observer = createAnonymousObserver()
980+
end
981+
982+
return observer
983+
end
984+
"#,
985+
));
986+
987+
assert!(!ws.check_code_for(
988+
DiagnosticCode::ReturnTypeMismatch,
989+
r#"
990+
---@param observer fun(value: T) | B<T>
991+
---@return IDisposable
992+
function A:test2(observer)
993+
local typ = type(observer)
994+
if typ == 'table' then
995+
---@cast observer -function
996+
observer = createAnonymousObserver()
997+
end
998+
999+
return observer
1000+
end
1001+
"#,
1002+
));
1003+
}
9481004
}

crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_flow.rs

Lines changed: 171 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,13 @@ pub fn get_type_at_flow(
3333
}
3434
}
3535

36-
let mut result_type = LuaType::Unknown;
36+
let result_type;
3737
let mut antecedent_flow_id = flow_id;
3838
loop {
3939
let flow_node = tree
4040
.get_flow_node(antecedent_flow_id)
4141
.ok_or(InferFailReason::None)?;
42+
4243
match &flow_node.kind {
4344
FlowNodeKind::Start | FlowNodeKind::Unreachable => {
4445
result_type = get_var_ref_type(db, cache, var_ref_id)?;
@@ -49,10 +50,42 @@ pub fn get_type_at_flow(
4950
}
5051
FlowNodeKind::BranchLabel | FlowNodeKind::NamedLabel(_) => {
5152
let multi_antecedents = get_multi_antecedents(tree, flow_node)?;
52-
for flow_id in multi_antecedents {
53+
54+
// 在分支前获取原始类型
55+
let original_type = if let Some(antecedent) = &flow_node.antecedent {
56+
match antecedent {
57+
crate::FlowAntecedent::Single(single_id) => {
58+
get_type_at_flow(db, tree, cache, root, var_ref_id, *single_id)?
59+
}
60+
crate::FlowAntecedent::Multiple(_) => {
61+
// 在 BranchLabel 中,多个 antecedent 需要获取共同的祖先
62+
get_var_ref_type(db, cache, var_ref_id)?
63+
}
64+
}
65+
} else {
66+
get_var_ref_type(db, cache, var_ref_id)?
67+
};
68+
69+
let mut branch_types = Vec::new();
70+
71+
for &flow_id in &multi_antecedents {
5372
let branch_type = get_type_at_flow(db, tree, cache, root, var_ref_id, flow_id)?;
54-
result_type = TypeOps::Union.apply(db, &result_type, &branch_type);
73+
branch_types.push(branch_type);
5574
}
75+
76+
// 分析类型覆盖
77+
let result_type_analysis = analyze_branch_coverage(
78+
&original_type,
79+
&branch_types,
80+
db,
81+
tree,
82+
cache,
83+
root,
84+
var_ref_id,
85+
&multi_antecedents,
86+
)?;
87+
88+
result_type = result_type_analysis;
5689
break;
5790
}
5891
FlowNodeKind::DeclPosition(position) => {
@@ -237,3 +270,138 @@ fn get_type_at_assign_stat(
237270

238271
Ok(ResultTypeOrContinue::Continue)
239272
}
273+
274+
// 分析分支覆盖率, 确定原始类型中哪些部分被赋值覆盖
275+
fn analyze_branch_coverage(
276+
original_type: &LuaType,
277+
branch_types: &[LuaType],
278+
db: &DbIndex,
279+
tree: &FlowTree,
280+
cache: &mut LuaInferCache,
281+
root: &LuaChunk,
282+
var_ref_id: &VarRefId,
283+
flow_ids: &[FlowId],
284+
) -> Result<LuaType, InferFailReason> {
285+
if branch_types.is_empty() {
286+
return Ok(original_type.clone());
287+
}
288+
289+
// 检查哪些分支实际上有赋值
290+
let mut assignment_branches = Vec::new();
291+
let mut non_assignment_branches = Vec::new();
292+
293+
for (i, &flow_id) in flow_ids.iter().enumerate() {
294+
let branch_type = &branch_types[i];
295+
296+
// 检查这个分支是否有赋值, 通过检查类型是否显著变化
297+
let has_assignment = !types_are_equivalent(branch_type, original_type)
298+
&& has_assignment_in_branch(db, tree, cache, root, var_ref_id, flow_id)?;
299+
300+
if has_assignment {
301+
assignment_branches.push(branch_type.clone());
302+
} else {
303+
non_assignment_branches.push(branch_type.clone());
304+
}
305+
}
306+
307+
if !assignment_branches.is_empty() {
308+
// 检查所有赋值分支是否都是相同的类型
309+
let first_assignment_type = &assignment_branches[0];
310+
let all_assignments_same = assignment_branches
311+
.iter()
312+
.all(|t| types_are_equivalent(t, first_assignment_type));
313+
314+
if all_assignments_same && assignment_branches.len() >= 2 {
315+
// 多个分支具有相同的赋值类型, 表明完全覆盖
316+
return Ok(first_assignment_type.clone());
317+
}
318+
}
319+
320+
// 回退到原始行为: 合并所有分支类型
321+
let mut result_type = LuaType::Unknown;
322+
let mut has_any_type = false;
323+
324+
for branch_type in branch_types {
325+
if !has_any_type {
326+
result_type = branch_type.clone();
327+
has_any_type = true;
328+
} else {
329+
result_type = TypeOps::Union.apply(db, &result_type, branch_type);
330+
}
331+
}
332+
333+
Ok(result_type)
334+
}
335+
336+
// 检查分支是否包含对目标变量的赋值
337+
fn has_assignment_in_branch(
338+
db: &DbIndex,
339+
tree: &FlowTree,
340+
cache: &mut LuaInferCache,
341+
root: &LuaChunk,
342+
var_ref_id: &VarRefId,
343+
start_flow_id: FlowId,
344+
) -> Result<bool, InferFailReason> {
345+
let mut current_flow_id = start_flow_id;
346+
347+
// 遍历流向后看是否在这个分支中有赋值
348+
loop {
349+
let flow_node = tree
350+
.get_flow_node(current_flow_id)
351+
.ok_or(InferFailReason::None)?;
352+
353+
match &flow_node.kind {
354+
FlowNodeKind::Assignment(assign_ptr) => {
355+
let assign_stat = assign_ptr.to_node(root).ok_or(InferFailReason::None)?;
356+
let result_or_continue = get_type_at_assign_stat(
357+
db,
358+
tree,
359+
cache,
360+
root,
361+
var_ref_id,
362+
flow_node,
363+
assign_stat,
364+
)?;
365+
366+
if let ResultTypeOrContinue::Result(_) = result_or_continue {
367+
return Ok(true);
368+
}
369+
370+
// 继续检查 antecedents
371+
current_flow_id = get_single_antecedent(tree, flow_node)?;
372+
}
373+
FlowNodeKind::TrueCondition(_) | FlowNodeKind::FalseCondition(_) => {
374+
// 继续通过条件节点
375+
current_flow_id = get_single_antecedent(tree, flow_node)?;
376+
}
377+
FlowNodeKind::BranchLabel | FlowNodeKind::NamedLabel(_) => {
378+
// 到达另一个分支点, 停止这里
379+
return Ok(false);
380+
}
381+
FlowNodeKind::Start | FlowNodeKind::Unreachable => {
382+
// 到达开始没有找到赋值
383+
return Ok(false);
384+
}
385+
FlowNodeKind::DeclPosition(_) => {
386+
// 到达声明, 停止这里
387+
return Ok(false);
388+
}
389+
_ => {
390+
// 继续检查 antecedents 对于其他 flow node 类型
391+
current_flow_id = get_single_antecedent(tree, flow_node)?;
392+
}
393+
}
394+
}
395+
}
396+
397+
// 检查两个类型是否等价
398+
fn types_are_equivalent(a: &LuaType, b: &LuaType) -> bool {
399+
match (a, b) {
400+
(LuaType::Union(a_union), LuaType::Union(b_union)) => {
401+
let a_types = a_union.into_vec();
402+
let b_types = b_union.into_vec();
403+
a_types.len() == b_types.len() && a_types.iter().all(|t| b_types.contains(t))
404+
}
405+
_ => a == b,
406+
}
407+
}

crates/emmylua_ls/src/handlers/test/hover_function_test.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,7 @@ mod tests {
417417
);
418418
assert!(ws.check_hover(
419419
r#"
420-
source:sel<??>ect(function(value)
420+
source:<??>select(function(value)
421421
return value
422422
end)
423423
"#,

0 commit comments

Comments
 (0)