Skip to content

Commit bc6a5c2

Browse files
committed
fix UnResolveCallClosureParams & UnResolveClosureReturn
1 parent a3df30f commit bc6a5c2

File tree

5 files changed

+55
-26
lines changed

5 files changed

+55
-26
lines changed

crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ use resolve::{
1919
try_resolve_module_ref, try_resolve_return_point, try_resolve_table_field,
2020
};
2121
use resolve_closure::{
22-
try_resolve_closure_params, try_resolve_closure_parent_params, try_resolve_closure_return,
22+
try_resolve_call_closure_params, try_resolve_closure_parent_params, try_resolve_closure_return,
2323
};
2424

2525
use super::{infer_manager::InferCacheManager, lua::LuaReturnPoint, AnalyzeContext};
@@ -179,7 +179,7 @@ fn try_resolve(
179179
try_resolve_return_point(db, cache, un_resolve_return)
180180
}
181181
UnResolve::ClosureParams(un_resolve_closure_params) => {
182-
try_resolve_closure_params(db, cache, un_resolve_closure_params)
182+
try_resolve_call_closure_params(db, cache, un_resolve_closure_params)
183183
}
184184
UnResolve::ClosureReturn(un_resolve_closure_return) => {
185185
try_resolve_closure_return(db, cache, un_resolve_closure_return)

crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/resolve_closure.rs

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@ use std::sync::Arc;
33
use emmylua_parser::{LuaAstNode, LuaIndexMemberExpr, LuaTableExpr, LuaVarExpr};
44

55
use crate::{
6-
infer_call_expr_func, infer_expr, infer_table_should_be, DbIndex, InferFailReason, InferGuard,
7-
LuaDocParamInfo, LuaDocReturnInfo, LuaFunctionType, LuaInferCache, LuaSignature, LuaType,
8-
LuaUnionType, SignatureReturnStatus, TypeOps,
6+
get_real_type, infer_call_expr_func, infer_expr, infer_table_should_be, DbIndex,
7+
InferFailReason, InferGuard, LuaDocParamInfo, LuaDocReturnInfo, LuaFunctionType, LuaInferCache,
8+
LuaSignature, LuaType, LuaUnionType, SignatureReturnStatus, TypeOps,
99
};
1010

1111
use super::{
@@ -14,7 +14,7 @@ use super::{
1414
UnResolveParentClosureParams, UnResolveReturn,
1515
};
1616

17-
pub fn try_resolve_closure_params(
17+
pub fn try_resolve_call_closure_params(
1818
db: &mut DbIndex,
1919
cache: &mut LuaInferCache,
2020
closure_params: &mut UnResolveCallClosureParams,
@@ -50,21 +50,22 @@ pub fn try_resolve_closure_params(
5050
_ => {}
5151
}
5252

53-
let is_async;
54-
let expr_closure_params = if let Some(param_type) = call_doc_func.get_params().get(param_idx) {
55-
match &param_type.1 {
56-
Some(LuaType::DocFunction(func)) => {
57-
is_async = func.is_async();
58-
func.get_params()
59-
}
60-
Some(LuaType::Union(union_types)) => {
53+
let (is_async, params_to_insert) = if let Some(param_type) =
54+
call_doc_func.get_params().get(param_idx)
55+
{
56+
let Some(param_type) = get_real_type(db, &param_type.1.as_ref().unwrap_or(&LuaType::Any))
57+
else {
58+
return Ok(());
59+
};
60+
match param_type {
61+
LuaType::DocFunction(func) => (func.is_async(), func.get_params().to_vec()),
62+
LuaType::Union(union_types) => {
6163
if let Some(LuaType::DocFunction(func)) = union_types
6264
.get_types()
6365
.iter()
6466
.find(|typ| matches!(typ, LuaType::DocFunction(_)))
6567
{
66-
is_async = func.is_async();
67-
func.get_params()
68+
(func.is_async(), func.get_params().to_vec())
6869
} else {
6970
return Ok(());
7071
}
@@ -81,7 +82,7 @@ pub fn try_resolve_closure_params(
8182
.ok_or(InferFailReason::None)?;
8283

8384
let signature_params = &mut signature.param_docs;
84-
for (idx, (name, type_ref)) in expr_closure_params.iter().enumerate() {
85+
for (idx, (name, type_ref)) in params_to_insert.iter().enumerate() {
8586
if signature_params.contains_key(&idx) {
8687
continue;
8788
}
@@ -136,8 +137,12 @@ pub fn try_resolve_closure_return(
136137
}
137138

138139
let ret_type = if let Some(param_type) = call_doc_func.get_params().get(param_idx) {
139-
if let Some(LuaType::DocFunction(func)) = &param_type.1 {
140-
func.get_ret()
140+
let Some(param_type) = get_real_type(db, &param_type.1.as_ref().unwrap_or(&LuaType::Any))
141+
else {
142+
return Ok(());
143+
};
144+
if let LuaType::DocFunction(func) = param_type {
145+
func.get_ret().clone()
141146
} else {
142147
return Ok(());
143148
}

crates/emmylua_code_analysis/src/compilation/test/closure_param_infer_test.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -463,4 +463,27 @@ mod test {
463463
let expected = ws.ty("Trigger");
464464
assert_eq!(ws.humanize_type(ty), ws.humanize_type(expected));
465465
}
466+
467+
#[test]
468+
fn test_param_function_is_alias() {
469+
let mut ws = VirtualWorkspace::new();
470+
ws.def(
471+
r#"
472+
---@class LocalTimer
473+
---@alias LocalTimer.OnTimer fun(timer: LocalTimer, count: integer, ...: any)
474+
475+
---@param on_timer LocalTimer.OnTimer
476+
---@return LocalTimer
477+
function loop_count(on_timer)
478+
end
479+
480+
loop_count(function(timer, count)
481+
A = timer
482+
end)
483+
"#,
484+
);
485+
let ty = ws.expr_ty("A");
486+
let expected = ws.ty("LocalTimer");
487+
assert_eq!(ws.humanize_type(ty), ws.humanize_type(expected));
488+
}
466489
}

crates/emmylua_code_analysis/src/db_index/type/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ pub use type_decl::{
1515
pub use type_ops::TypeOps;
1616
pub use type_owner::{LuaTypeCache, LuaTypeOwner};
1717
pub use types::*;
18+
pub use type_ops::get_real_type;
1819

1920
#[derive(Debug)]
2021
pub struct LuaTypeIndex {

crates/emmylua_code_analysis/src/db_index/type/type_ops/mod.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,29 +48,29 @@ impl TypeOps {
4848
}
4949
}
5050

51-
pub fn get_real_type<'a>(db: &'a DbIndex, compact_type: &'a LuaType) -> Option<&'a LuaType> {
52-
get_real_type_with_depth(db, compact_type, 0)
51+
pub fn get_real_type<'a>(db: &'a DbIndex, typ: &'a LuaType) -> Option<&'a LuaType> {
52+
get_real_type_with_depth(db, typ, 0)
5353
}
5454

5555
fn get_real_type_with_depth<'a>(
5656
db: &'a DbIndex,
57-
compact_type: &'a LuaType,
57+
typ: &'a LuaType,
5858
depth: u32,
5959
) -> Option<&'a LuaType> {
6060
const MAX_RECURSION_DEPTH: u32 = 100;
6161

6262
if depth >= MAX_RECURSION_DEPTH {
63-
return Some(compact_type);
63+
return Some(typ);
6464
}
6565

66-
match compact_type {
66+
match typ {
6767
LuaType::Ref(type_decl_id) => {
6868
let type_decl = db.get_type_index().get_type_decl(type_decl_id)?;
6969
if type_decl.is_alias() {
7070
return get_real_type_with_depth(db, type_decl.get_alias_ref()?, depth + 1);
7171
}
72-
Some(compact_type)
72+
Some(typ)
7373
}
74-
_ => Some(compact_type),
74+
_ => Some(typ),
7575
}
7676
}

0 commit comments

Comments
 (0)