Skip to content

Commit 61b277c

Browse files
committed
refactor function return infer
Fix #643 Fix #659
1 parent 07fbea4 commit 61b277c

File tree

2 files changed

+144
-4
lines changed

2 files changed

+144
-4
lines changed

crates/emmylua_code_analysis/src/compilation/analyzer/lua/closure.rs

Lines changed: 94 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::ops::Deref;
2+
13
use emmylua_parser::{
24
LuaAst, LuaAstNode, LuaCallArgList, LuaCallExpr, LuaClosureExpr, LuaFuncStat, LuaVarExpr,
35
};
@@ -215,7 +217,7 @@ pub fn analyze_return_point(
215217
match point {
216218
LuaReturnPoint::Expr(expr) => {
217219
let expr_type = infer_expr(db, cache, expr.clone())?;
218-
return_type = TypeOps::Union.apply(db, &return_type, &expr_type);
220+
return_type = union_return_expr(db, return_type, expr_type);
219221
}
220222
LuaReturnPoint::MuliExpr(exprs) => {
221223
let mut multi_return = vec![];
@@ -224,10 +226,10 @@ pub fn analyze_return_point(
224226
multi_return.push(expr_type);
225227
}
226228
let typ = LuaType::Variadic(VariadicType::Multi(multi_return).into());
227-
return_type = TypeOps::Union.apply(db, &return_type, &typ);
229+
return_type = union_return_expr(db, return_type, typ);
228230
}
229231
LuaReturnPoint::Nil => {
230-
return_type = TypeOps::Union.apply(db, &return_type, &LuaType::Nil);
232+
return_type = union_return_expr(db, return_type, LuaType::Nil);
231233
}
232234
_ => {}
233235
}
@@ -239,3 +241,92 @@ pub fn analyze_return_point(
239241
name: None,
240242
}])
241243
}
244+
245+
fn union_return_expr(db: &DbIndex, left: LuaType, right: LuaType) -> LuaType {
246+
if left == LuaType::Unknown {
247+
return right;
248+
}
249+
250+
match (&left, &right) {
251+
(LuaType::Variadic(left_variadic), LuaType::Variadic(right_variadic)) => {
252+
match (&left_variadic.deref(), &right_variadic.deref()) {
253+
(VariadicType::Base(left_base), VariadicType::Base(right_base)) => {
254+
let union_base = TypeOps::Union.apply(db, left_base, right_base);
255+
LuaType::Variadic(VariadicType::Base(union_base).into())
256+
}
257+
(VariadicType::Multi(left_multi), VariadicType::Multi(right_multi)) => {
258+
let mut new_multi = vec![];
259+
let max_len = left_multi.len().max(right_multi.len());
260+
for i in 0..max_len {
261+
let left_type = left_multi.get(i).cloned().unwrap_or(LuaType::Nil);
262+
let right_type = right_multi.get(i).cloned().unwrap_or(LuaType::Nil);
263+
new_multi.push(TypeOps::Union.apply(db, &left_type, &right_type));
264+
}
265+
LuaType::Variadic(VariadicType::Multi(new_multi).into())
266+
}
267+
// difficult to merge the type, use let
268+
_ => left.clone(),
269+
}
270+
}
271+
(LuaType::Variadic(variadic), _) => {
272+
let first_type = variadic.get_type(0).cloned().unwrap_or(LuaType::Unknown);
273+
let first_union_type = TypeOps::Union.apply(db, &first_type, &right);
274+
275+
match variadic.deref() {
276+
VariadicType::Base(base) => {
277+
let union_base = TypeOps::Union.apply(db, base, &LuaType::Nil);
278+
LuaType::Variadic(
279+
VariadicType::Multi(vec![
280+
first_union_type,
281+
LuaType::Variadic(VariadicType::Base(union_base).into()),
282+
])
283+
.into(),
284+
)
285+
}
286+
VariadicType::Multi(multi) => {
287+
let mut new_multi = multi.clone();
288+
if new_multi.len() > 0 {
289+
new_multi[0] = first_union_type;
290+
for i in 1..new_multi.len() {
291+
new_multi[i] = TypeOps::Union.apply(db, &new_multi[i], &LuaType::Nil);
292+
}
293+
} else {
294+
new_multi.push(first_union_type);
295+
}
296+
297+
LuaType::Variadic(VariadicType::Multi(new_multi).into())
298+
}
299+
}
300+
}
301+
(_, LuaType::Variadic(variadic)) => {
302+
let first_type = variadic.get_type(0).cloned().unwrap_or(LuaType::Unknown);
303+
let first_union_type = TypeOps::Union.apply(db, &left, &first_type);
304+
match variadic.deref() {
305+
VariadicType::Base(base) => {
306+
let union_base = TypeOps::Union.apply(db, base, &LuaType::Nil);
307+
LuaType::Variadic(
308+
VariadicType::Multi(vec![
309+
first_union_type,
310+
LuaType::Variadic(VariadicType::Base(union_base).into()),
311+
])
312+
.into(),
313+
)
314+
}
315+
VariadicType::Multi(multi) => {
316+
let mut new_multi = multi.clone();
317+
if new_multi.len() > 0 {
318+
new_multi[0] = first_union_type;
319+
for i in 1..new_multi.len() {
320+
new_multi[i] = TypeOps::Union.apply(db, &new_multi[i], &LuaType::Nil);
321+
}
322+
} else {
323+
new_multi.push(first_union_type);
324+
}
325+
326+
LuaType::Variadic(VariadicType::Multi(new_multi).into())
327+
}
328+
}
329+
}
330+
_ => TypeOps::Union.apply(db, &left, &right),
331+
}
332+
}

crates/emmylua_code_analysis/src/compilation/test/return_unwrap_test.rs

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#[cfg(test)]
22
mod test {
3-
use crate::{LuaType, VirtualWorkspace};
3+
use crate::{DiagnosticCode, LuaType, VirtualWorkspace};
44

55
#[test]
66
fn test_issue_376() {
@@ -41,4 +41,53 @@ mod test {
4141
"#,
4242
));
4343
}
44+
45+
#[test]
46+
fn test_issue_659() {
47+
let mut ws = VirtualWorkspace::new_with_init_std_lib();
48+
49+
assert!(ws.check_code_for(
50+
DiagnosticCode::ReturnTypeMismatch,
51+
r#"
52+
--- @async
53+
--- @generic R
54+
--- @param fn fun(): R...
55+
--- @return R...
56+
function wrap(fn) end
57+
58+
---@async
59+
--- @param a {}?
60+
--- @return {}?
61+
--- @return string? err
62+
function get(a)
63+
return wrap(function()
64+
if not a then
65+
return nil, 'err'
66+
end
67+
68+
return a
69+
end)
70+
end
71+
"#,
72+
));
73+
}
74+
75+
#[test]
76+
fn test_issue_643() {
77+
let mut ws = VirtualWorkspace::new_with_init_std_lib();
78+
79+
assert!(ws.check_code_for(
80+
DiagnosticCode::AssignTypeMismatch,
81+
r#"
82+
local function foo(b)
83+
if not b then
84+
return
85+
end
86+
return 'a', 1
87+
end
88+
--- @type 'a'?
89+
local _ = foo()
90+
"#,
91+
));
92+
}
4493
}

0 commit comments

Comments
 (0)