Skip to content

Commit 55d3c93

Browse files
committed
update flow
1 parent aef7e02 commit 55d3c93

File tree

3 files changed

+93
-9
lines changed

3 files changed

+93
-9
lines changed

crates/emmylua_code_analysis/resources/std/builtin.lua

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,8 @@
149149
--- Get the parameters of a constructor as a tuple
150150
---@alias ConstructorParameters<T> T extends new (fun(...: infer P): any) and P or never
151151

152+
---@alias ReturnType<T extends function> T extends (fun(...: any): infer R) and R or any
153+
152154
--- Make all properties in T optional
153155
---@alias Partial<T> { [P in keyof T]?: T[P]; }
154156

crates/emmylua_code_analysis/src/compilation/test/generic_test.rs

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -726,18 +726,48 @@ mod test {
726726
ws.def(
727727
r#"
728728
---@alias ConstructorParameters<T> T extends new (fun(...: infer P): any) and P or never
729+
729730
---@alias Parameters<T extends function> T extends (fun(...: infer P): any) and P or never
730731
732+
---@alias ReturnType<T extends function> T extends (fun(...: any): infer R) and R or any
733+
731734
---@alias Procedure fun(...: any[]): any
732735
733-
---@alias MockParameters<T> T extends table and ConstructorParameters<T> or T extends Procedure and Parameters<T> or never
736+
---@alias MockParameters<T> T extends Procedure and Parameters<T> or never
737+
738+
---@alias MockReturnType<T> T extends Procedure and ReturnType<T> or never
734739
735740
---@class Mock<T>
736741
---@field calls MockParameters<T>[]
742+
---@overload fun(...: MockParameters<T>...): MockReturnType<T>
737743
"#,
738744
);
739-
ws.def(
740-
r#"
745+
{
746+
ws.def(
747+
r#"
748+
---@generic T: Procedure
749+
---@param a T
750+
---@return Mock<T>
751+
local function fn(a)
752+
end
753+
754+
local sum = fn(function(a, b)
755+
return a + b
756+
end)
757+
A = sum
758+
"#,
759+
);
760+
761+
let result_ty = ws.expr_ty("A");
762+
assert_eq!(
763+
ws.humanize_type_detailed(result_ty),
764+
"Mock<fun(a, b) -> any>"
765+
);
766+
}
767+
768+
{
769+
ws.def(
770+
r#"
741771
---@generic T: Procedure
742772
---@param a T?
743773
---@return Mock<T>
@@ -746,9 +776,10 @@ mod test {
746776
747777
result = fn().calls
748778
"#,
749-
);
779+
);
750780

751-
let result_ty = ws.expr_ty("result");
752-
assert_eq!(ws.humanize_type(result_ty), "any[][]");
781+
let result_ty = ws.expr_ty("result");
782+
assert_eq!(ws.humanize_type(result_ty), "any[][]");
783+
}
753784
}
754785
}

crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_flow.rs

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use emmylua_parser::{LuaAssignStat, LuaAstNode, LuaChunk, LuaVarExpr};
1+
use emmylua_parser::{LuaAssignStat, LuaAstNode, LuaChunk, LuaExpr, LuaVarExpr};
22

33
use crate::{
44
CacheEntry, DbIndex, FlowId, FlowNode, FlowNodeKind, FlowTree, InferFailReason, LuaDeclId,
@@ -61,8 +61,23 @@ pub fn get_type_at_flow(
6161
}
6262
FlowNodeKind::DeclPosition(position) => {
6363
if *position <= var_ref_id.get_position() {
64-
result_type = get_var_ref_type(db, cache, var_ref_id)?;
65-
break;
64+
match get_var_ref_type(db, cache, var_ref_id) {
65+
Ok(var_type) => {
66+
result_type = var_type;
67+
break;
68+
}
69+
Err(err) => {
70+
// 尝试推断声明位置的类型, 如果发生错误则返回初始错误, 否则返回当前推断错误
71+
if let Some(init_type) =
72+
try_infer_decl_initializer_type(db, cache, root, var_ref_id)?
73+
{
74+
result_type = init_type;
75+
break;
76+
}
77+
78+
return Err(err);
79+
}
80+
}
6681
} else {
6782
antecedent_flow_id = get_single_antecedent(tree, flow_node)?;
6883
}
@@ -258,3 +273,39 @@ fn get_type_at_assign_stat(
258273

259274
Ok(ResultTypeOrContinue::Continue)
260275
}
276+
277+
fn try_infer_decl_initializer_type(
278+
db: &DbIndex,
279+
cache: &mut LuaInferCache,
280+
root: &LuaChunk,
281+
var_ref_id: &VarRefId,
282+
) -> Result<Option<LuaType>, InferFailReason> {
283+
let Some(decl_id) = var_ref_id.get_decl_id_ref() else {
284+
return Ok(None);
285+
};
286+
287+
let decl = db
288+
.get_decl_index()
289+
.get_decl(&decl_id)
290+
.ok_or(InferFailReason::None)?;
291+
292+
let Some(value_syntax_id) = decl.get_value_syntax_id() else {
293+
return Ok(None);
294+
};
295+
296+
let Some(node) = value_syntax_id.to_node_from_root(root.syntax()) else {
297+
return Ok(None);
298+
};
299+
300+
let Some(expr) = LuaExpr::cast(node) else {
301+
return Ok(None);
302+
};
303+
304+
let expr_type = infer_expr(db, cache, expr.clone())?;
305+
let init_type = match expr_type {
306+
LuaType::Variadic(variadic) => variadic.get_type(0).cloned(),
307+
ty => Some(ty),
308+
};
309+
310+
Ok(init_type)
311+
}

0 commit comments

Comments
 (0)