Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use emmylua_parser::{LuaAstNode, LuaExpr, LuaVarExpr};
use emmylua_parser::{LuaAstNode, LuaTableExpr, LuaVarExpr};

use crate::{
infer_call_expr_func, infer_expr, infer_table_field_value_should_be, DbIndex, InferFailReason,
InferGuard, LuaDocParamInfo, LuaDocReturnInfo, LuaInferCache, LuaType, SignatureReturnStatus,
infer_call_expr_func, infer_expr, infer_member_map, infer_table_should_be, DbIndex,
InferFailReason, InferGuard, LuaDocParamInfo, LuaDocReturnInfo, LuaFunctionType, LuaInferCache,
LuaMemberInfo, LuaSemanticDeclId, LuaSignatureId, LuaType, LuaTypeDeclId,
SignatureReturnStatus,
};

use super::{
Expand Down Expand Up @@ -177,19 +179,28 @@ pub fn try_resolve_closure_parent_params(
if !signature.param_docs.is_empty() {
return Some(true);
}

let self_type;
let member_type = match &closure_params.parent_ast {
UnResolveParentAst::LuaFuncStat(func_stat) => {
let func_name = func_stat.get_func_name()?;
match func_name {
LuaVarExpr::IndexExpr(index_expr) => {
infer_expr(db, cache, LuaExpr::IndexExpr(index_expr)).ok()?
let typ = infer_expr(db, cache, index_expr.get_prefix_expr()?).ok()?;
self_type = Some(typ.clone());

find_best_function_type(db, cache, &typ, &closure_params.signature_id)
}
LuaVarExpr::NameExpr(_) => return Some(true),
_ => return Some(true),
}
}
UnResolveParentAst::LuaTableField(table_field) => {
infer_table_field_value_should_be(db, cache, table_field.clone()).ok()?
let parnet_table_expr = table_field
.get_parent::<LuaTableExpr>()
.ok_or(InferFailReason::None)
.ok()?;
let typ = infer_table_should_be(db, cache, parnet_table_expr).ok()?;
self_type = Some(typ.clone());
find_best_function_type(db, cache, &typ, &closure_params.signature_id)
}
UnResolveParentAst::LuaAssignStat(assign) => {
let (vars, exprs) = assign.get_var_and_expr_list();
Expand All @@ -198,20 +209,53 @@ pub fn try_resolve_closure_parent_params(
.iter()
.position(|expr| expr.get_position() == position)?;
let var = vars.get(idx)?;

match var {
LuaVarExpr::IndexExpr(index_expr) => {
infer_expr(db, cache, LuaExpr::IndexExpr(index_expr.clone())).ok()?
let typ = infer_expr(db, cache, index_expr.get_prefix_expr()?).ok()?;
self_type = Some(typ.clone());
find_best_function_type(db, cache, &typ, &closure_params.signature_id)
}
LuaVarExpr::NameExpr(_) => return Some(true),
_ => return Some(true),
}
}
};

let LuaType::DocFunction(doc_func) = member_type else {
let Some(member_type) = member_type else {
return Some(true);
};

match &member_type {
LuaType::DocFunction(doc_func) => {
resolve_doc_function(db, closure_params, doc_func, self_type)
}
LuaType::Signature(id) => {
if id == &closure_params.signature_id {
return Some(true);
}
let signature = db.get_signature_index().get(id);

if let Some(signature) = signature {
let fake_doc_function = LuaFunctionType::new(
signature.is_async,
signature.is_colon_define,
signature.get_type_params(),
signature.get_return_types(),
);
resolve_doc_function(db, closure_params, &fake_doc_function, self_type)
} else {
Some(true)
}
}
_ => Some(true),
}
}

fn resolve_doc_function(
db: &mut DbIndex,
closure_params: &UnResolveParentClosureParams,
doc_func: &LuaFunctionType,
self_type: Option<LuaType>,
) -> Option<bool> {
let signature = db
.get_signature_index_mut()
.get_mut(&closure_params.signature_id)?;
Expand All @@ -220,17 +264,30 @@ pub fn try_resolve_closure_parent_params(
signature.is_async = true;
}

let colon_define = signature.is_colon_define;
let mut params = doc_func.get_params();
if colon_define {
if params.len() > 1 {
params = &params[1..];
} else {
params = &[];
let mut doc_params = doc_func.get_params().to_vec();
// doc_func 是往上追溯的有效签名, signature 是未解析的签名
match (doc_func.is_colon_define(), signature.is_colon_define) {
(true, true) | (false, false) => {}
(true, false) => {
// 原始签名是冒号定义, 但未解析的签名不是冒号定义, 即要插入第一个参数
doc_params.insert(0, ("self".to_string(), Some(LuaType::SelfInfer)));
}
(false, true) => {
// 原始签名不是冒号定义, 但未解析的签名是冒号定义, 即要删除第一个参数
doc_params.remove(0);
}
}
// 如果第一个参数是 self, 则需要将 self 的类型设置为 self_type
if doc_params.get(0).map_or(false, |(_, typ)| match typ {
Some(LuaType::SelfInfer) => true,
_ => false,
}) {
if let Some(self_type) = self_type {
doc_params[0].1 = Some(self_type);
}
}

for (index, param) in params.iter().enumerate() {
for (index, param) in doc_params.iter().enumerate() {
let name = signature.params.get(index).unwrap_or(&param.0);
signature.param_docs.insert(
index,
Expand Down Expand Up @@ -259,3 +316,94 @@ pub fn try_resolve_closure_parent_params(

Some(true)
}

fn get_owner_type_id(db: &DbIndex, info: &LuaMemberInfo) -> Option<LuaTypeDeclId> {
match &info.property_owner_id {
Some(LuaSemanticDeclId::Member(member_id)) => {
if let Some(owner) = db.get_member_index().get_current_owner(member_id) {
return owner.get_type_id().cloned();
}
None
}
_ => None,
}
}

fn find_best_function_type(
db: &DbIndex,
cache: &mut LuaInferCache,
prefix_type: &LuaType,
signature_id: &LuaSignatureId,
) -> Option<LuaType> {
let member_info_map = infer_member_map(db, &prefix_type)?;
let mut current_type_id = None;
// 如果找不到证明是重定义
let target_infos = member_info_map.into_values().find(|infos| {
infos.iter().any(|info| match &info.typ {
LuaType::Signature(id) => {
if id == signature_id {
current_type_id = get_owner_type_id(db, info);
return true;
}
false
}
_ => false,
})
})?;
// 找到第一个具有实际参数类型的签名
target_infos.iter().find_map(|info| {
// 所有者类型一致, 但我们找的是父类型
if get_owner_type_id(db, info) == current_type_id {
return None;
}
let function_type =
get_final_function_type(db, cache, &info.typ).unwrap_or(info.typ.clone());
let param_type_len = match &function_type {
LuaType::Signature(id) => db
.get_signature_index()
.get(&id)
.map(|sig| sig.param_docs.len())
.unwrap_or(0),
LuaType::DocFunction(doc_func) => doc_func
.get_params()
.iter()
.filter(|(_, typ)| typ.is_some())
.count(),
_ => 0, // 跳过其他类型
};
if param_type_len > 0 {
return Some(function_type.clone());
}
None
})
}

fn get_final_function_type(
db: &DbIndex,
cache: &mut LuaInferCache,
origin: &LuaType,
) -> Option<LuaType> {
match origin {
LuaType::Signature(_) => Some(origin.clone()),
LuaType::DocFunction(_) => Some(origin.clone()),
LuaType::Ref(decl_id) => {
let decl = db.get_type_index().get_type_decl(decl_id)?;
if decl.is_alias() {
let origin_type = decl.get_alias_origin(db, None)?;
get_final_function_type(db, cache, &origin_type)
} else {
Some(origin.clone())
}
}
LuaType::Union(union_types) => {
for typ in union_types.get_types() {
let final_type = get_final_function_type(db, cache, typ);
if final_type.is_some() {
return final_type;
}
}
None
}
_ => None,
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,6 @@ mod test {

let c = ws.expr_ty("c");
let c_desc = ws.humanize_type(c);
assert_eq!(c_desc, "(string|nil)");
assert_eq!(c_desc, "string?");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,140 @@ mod test {
let expected = ws.ty("Outfit_t");
assert_eq!(ty, expected);
}

#[test]
fn test_table_field_function_param() {
let mut ws = VirtualWorkspace::new();
ws.def(
r#"
---@alias ProxyHandler.Getter fun(self: self, raw: any, key: any, receiver: table): any

---@class ProxyHandler
---@field get ProxyHandler.Getter
"#,
);

ws.def(
r#"

---@class A: ProxyHandler
local A

function A:get(target, key, receiver, name)
a = self
end
"#,
);
let ty = ws.expr_ty("a");
let expected = ws.ty("A");
assert_eq!(ws.humanize_type(ty), ws.humanize_type(expected));

ws.def(
r#"

---@class B: ProxyHandler
local B

B.get = function(self, target, key, receiver, name)
b = self
end
"#,
);
let ty = ws.expr_ty("b");
let expected = ws.ty("B");
assert_eq!(ws.humanize_type(ty), ws.humanize_type(expected));

ws.def(
r#"
---@class C: ProxyHandler
local C = {
get = function(self, target, key, receiver, name)
c = self
end,
}
"#,
);
let ty = ws.expr_ty("c");
let expected = ws.ty("C");
assert_eq!(ws.humanize_type(ty), ws.humanize_type(expected));
}

#[test]
fn test_table_field_function_param_2() {
let mut ws = VirtualWorkspace::new();

ws.def(
r#"
---@class ProxyHandler
local P

---@param raw any
---@param key any
---@param receiver table
---@return any
function P:get(raw, key, receiver) end
"#,
);

ws.def(
r#"
---@class A: ProxyHandler
local A

function A:get(raw, key, receiver)
a = receiver
end
"#,
);
let ty = ws.expr_ty("a");
let expected = ws.ty("table");
assert_eq!(ws.humanize_type(ty), ws.humanize_type(expected));
}

#[test]
fn test_table_field_function_param_3() {
let mut ws = VirtualWorkspace::new();

ws.def(
r#"
---@class SimpleClass.Meta
---@field __defineSet fun(self: self, key: string, f: fun(self: self, value: any))

---@class Dep: SimpleClass.Meta
local Dep
Dep:__defineSet('subs', function(self, value)
a = self
end)
"#,
);
let ty = ws.expr_ty("a");
let expected = ws.ty("Dep");
assert_eq!(ws.humanize_type(ty), ws.humanize_type(expected));
}

#[test]
fn test_table_field_function_param_4() {
let mut ws = VirtualWorkspace::new();
ws.def(r#"
---@alias ProxyHandler.Getter fun(self: self, raw: any, key: any, receiver: table): any

---@class ProxyHandler
---@field get? ProxyHandler.Getter
"#
);

ws.def(
r#"
---@class ShallowUnwrapHandlers: ProxyHandler
local ShallowUnwrapHandlers = {
get = function(self, target, key, receiver)
a = self
end,
}
"#,
);
let ty = ws.expr_ty("a");
let expected = ws.ty("ShallowUnwrapHandlers");
assert_eq!(ws.humanize_type(ty), ws.humanize_type(expected));
}
}
Loading