Skip to content

Commit 7c6bd04

Browse files
committed
feature: signature 参数没有doc声明时自动推断为同名docfunction联合声明
1 parent 6596371 commit 7c6bd04

File tree

3 files changed

+258
-38
lines changed

3 files changed

+258
-38
lines changed

crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/resolve_closure.rs

Lines changed: 119 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
use std::sync::Arc;
2+
13
use emmylua_parser::{LuaAstNode, LuaTableExpr, LuaVarExpr};
24

35
use crate::{
46
infer_call_expr_func, infer_expr, infer_member_map, infer_table_should_be, DbIndex,
57
InferFailReason, InferGuard, LuaDocParamInfo, LuaDocReturnInfo, LuaFunctionType, LuaInferCache,
6-
LuaMemberInfo, LuaSemanticDeclId, LuaSignatureId, LuaType, LuaTypeDeclId,
7-
SignatureReturnStatus,
8+
LuaMemberInfo, LuaSemanticDeclId, LuaSignatureId, LuaType, LuaTypeDeclId, LuaUnionType,
9+
SignatureReturnStatus, TypeOps,
810
};
911

1012
use super::{
@@ -232,7 +234,6 @@ pub fn try_resolve_closure_parent_params(
232234
}
233235
};
234236
self_type = Some(typ.clone());
235-
236237
find_best_function_type(db, cache, &typ, &closure_params.signature_id)
237238
}
238239
_ => return Some(true),
@@ -277,8 +278,9 @@ pub fn try_resolve_closure_parent_params(
277278
}
278279
};
279280

280-
let Some(member_type) = member_type else {
281-
return Some(true);
281+
let member_type = match member_type {
282+
Some(member_type) => member_type,
283+
None => return Some(true),
282284
};
283285

284286
match &member_type {
@@ -298,6 +300,58 @@ pub fn try_resolve_closure_parent_params(
298300
Some(true)
299301
}
300302
}
303+
LuaType::Union(union_types) => {
304+
let mut final_params = signature.get_type_params().to_vec();
305+
for typ in union_types.get_types() {
306+
let LuaType::DocFunction(doc_func) = typ else {
307+
continue;
308+
};
309+
let mut doc_params = doc_func.get_params().to_vec();
310+
match (doc_func.is_colon_define(), signature.is_colon_define) {
311+
(true, true) | (false, false) => {}
312+
(true, false) => {
313+
// 原始签名是冒号定义, 但未解析的签名不是冒号定义, 即要插入第一个参数
314+
doc_params.insert(0, ("self".to_string(), Some(LuaType::SelfInfer)));
315+
}
316+
(false, true) => {
317+
// 原始签名不是冒号定义, 但未解析的签名是冒号定义, 即要删除第一个参数
318+
doc_params.remove(0);
319+
}
320+
}
321+
// 如果第一个参数是 self, 则需要将 self 的类型设置为 self_type
322+
if doc_params.get(0).map_or(false, |(_, typ)| match typ {
323+
Some(LuaType::SelfInfer) => true,
324+
_ => false,
325+
}) {
326+
if let Some(self_type) = &self_type {
327+
doc_params[0].1 = Some(self_type.clone());
328+
}
329+
}
330+
for (idx, param) in doc_params.iter().enumerate() {
331+
if let Some(final_param) = final_params.get(idx) {
332+
if final_param.0 == "..." {
333+
continue;
334+
}
335+
let new_type = TypeOps::Union.apply(
336+
final_param.1.as_ref().unwrap_or(&LuaType::Unknown),
337+
param.1.as_ref().unwrap_or(&LuaType::Unknown),
338+
);
339+
final_params[idx] = (final_param.0.clone(), Some(new_type));
340+
}
341+
}
342+
}
343+
resolve_doc_function(
344+
db,
345+
closure_params,
346+
&LuaFunctionType::new(
347+
signature.is_async,
348+
signature.is_colon_define,
349+
final_params,
350+
signature.get_return_types(),
351+
),
352+
self_type,
353+
)
354+
}
301355
_ => Some(true),
302356
}
303357
}
@@ -363,7 +417,6 @@ fn resolve_doc_function(
363417
description: None,
364418
});
365419
}
366-
367420
Some(true)
368421
}
369422

@@ -385,47 +438,80 @@ fn find_best_function_type(
385438
prefix_type: &LuaType,
386439
signature_id: &LuaSignatureId,
387440
) -> Option<LuaType> {
388-
let member_info_map = infer_member_map(db, &prefix_type)?;
441+
let member_info_map = infer_member_map(db, prefix_type)?;
389442
let mut current_type_id = None;
390-
// 如果找不到证明是重定义
443+
391444
let target_infos = member_info_map.into_values().find(|infos| {
392-
infos.iter().any(|info| match &info.typ {
393-
LuaType::Signature(id) => {
445+
infos.iter().any(|info| {
446+
if let LuaType::Signature(id) = &info.typ {
394447
if id == signature_id {
395448
current_type_id = get_owner_type_id(db, info);
396449
return true;
397450
}
398-
false
399451
}
400-
_ => false,
452+
false
401453
})
402454
})?;
403-
// 找到第一个具有实际参数类型的签名
404-
target_infos.iter().find_map(|info| {
405-
// 所有者类型一致, 但我们找的是父类型
406-
if get_owner_type_id(db, info) == current_type_id {
407-
return None;
408-
}
455+
456+
let mut current_function_types = Vec::with_capacity(target_infos.len());
457+
// 父类或许也应该返回联合类型
458+
let mut parent_function_type = None;
459+
460+
for info in target_infos {
409461
let function_type =
410-
get_final_function_type(db, cache, &info.typ).unwrap_or(info.typ.clone());
411-
let param_type_len = match &function_type {
462+
get_final_function_type(db, cache, &info.typ).unwrap_or_else(|| info.typ.clone());
463+
464+
// 所有者类型一致, 不是父类
465+
if get_owner_type_id(db, &info) == current_type_id {
466+
match &function_type {
467+
LuaType::Signature(id) => {
468+
if let Some(cur_signature) = db.get_signature_index().get(id) {
469+
// 只需要重载声明
470+
if cur_signature.param_docs.is_empty() {
471+
current_function_types.extend(cur_signature.overloads.iter().cloned());
472+
}
473+
}
474+
}
475+
LuaType::DocFunction(doc_func) => {
476+
// 使用迭代器优化参数计数
477+
if doc_func.get_params().iter().any(|(_, typ)| typ.is_some()) {
478+
current_function_types.push(doc_func.clone());
479+
}
480+
}
481+
_ => {}
482+
}
483+
continue;
484+
}
485+
486+
// 父类处理
487+
let has_params = match &function_type {
412488
LuaType::Signature(id) => db
413489
.get_signature_index()
414-
.get(&id)
415-
.map(|sig| sig.param_docs.len())
416-
.unwrap_or(0),
417-
LuaType::DocFunction(doc_func) => doc_func
418-
.get_params()
419-
.iter()
420-
.filter(|(_, typ)| typ.is_some())
421-
.count(),
422-
_ => 0, // 跳过其他类型
490+
.get(id)
491+
.map_or(false, |sig| !sig.param_docs.is_empty()),
492+
LuaType::DocFunction(doc_func) => {
493+
doc_func.get_params().iter().any(|(_, typ)| typ.is_some())
494+
}
495+
_ => false,
423496
};
424-
if param_type_len > 0 {
425-
return Some(function_type.clone());
497+
498+
if has_params {
499+
parent_function_type = Some(function_type);
426500
}
427-
None
428-
})
501+
}
502+
match current_function_types.len() {
503+
0 => parent_function_type,
504+
1 => current_function_types
505+
.into_iter()
506+
.next()
507+
.map(LuaType::DocFunction),
508+
_ => Some(LuaType::Union(Arc::new(LuaUnionType::new(
509+
current_function_types
510+
.into_iter()
511+
.map(LuaType::DocFunction)
512+
.collect(),
513+
)))),
514+
}
429515
}
430516

431517
fn get_final_function_type(

crates/emmylua_code_analysis/src/compilation/test/closure_param_infer_test.rs

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,4 +209,69 @@ mod test {
209209
let expected = ws.ty("string[]");
210210
assert_eq!(ws.humanize_type(ty), ws.humanize_type(expected));
211211
}
212+
213+
#[test]
214+
fn test_field_doc_function() {
215+
let mut ws = VirtualWorkspace::new();
216+
217+
ws.def(
218+
r#"
219+
---@class ClosureTest
220+
---@field e fun(a: string, b: number)
221+
---@field e fun(a: number, b: number)
222+
local Test
223+
224+
function Test.e(a, b)
225+
A = a
226+
end
227+
"#,
228+
);
229+
let ty = ws.expr_ty("A");
230+
let expected = ws.ty("string|number");
231+
assert_eq!(ws.humanize_type(ty), ws.humanize_type(expected));
232+
}
233+
234+
#[test]
235+
fn test_field_doc_function_2() {
236+
let mut ws = VirtualWorkspace::new();
237+
238+
ws.def(
239+
r#"
240+
---@class ClosureTest
241+
---@field e fun(a: string, b: number)
242+
---@field e fun(a: number, b: number)
243+
local Test
244+
245+
---@overload fun(a: string, b: number)
246+
---@overload fun(a: number, b: number)
247+
function Test.e(a, b)
248+
d = b
249+
end
250+
"#,
251+
);
252+
let ty = ws.expr_ty("d");
253+
let expected = ws.ty("number");
254+
assert_eq!(ws.humanize_type(ty), ws.humanize_type(expected));
255+
}
256+
257+
#[test]
258+
fn test_field_doc_function_3() {
259+
let mut ws = VirtualWorkspace::new();
260+
261+
ws.def(
262+
r#"
263+
---@class ClosureTest
264+
---@field e fun(a: string, b: number) -- 不在 overload 时必须声明 self 才被视为方法
265+
---@field e fun(a: number, b: number)
266+
local Test
267+
268+
function Test:e(a, b) -- `:`声明
269+
A = a
270+
end
271+
"#,
272+
);
273+
let ty = ws.expr_ty("A");
274+
let expected = ws.ty("number");
275+
assert_eq!(ws.humanize_type(ty), ws.humanize_type(expected));
276+
}
212277
}

crates/emmylua_code_analysis/src/semantic/infer/infer_name.rs

Lines changed: 74 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -192,19 +192,88 @@ fn find_param_type_from_type(
192192
return typ.clone();
193193
}
194194
}
195+
LuaType::Union(_) => {
196+
return find_param_type_from_union(db, source_type, param_idx, current_colon_define)
197+
}
198+
_ => {}
199+
}
200+
201+
None
202+
}
203+
204+
fn find_param_type_from_union(
205+
db: &DbIndex,
206+
source_type: LuaType,
207+
param_idx: usize,
208+
origin_colon_define: bool,
209+
) -> Option<LuaType> {
210+
match source_type {
211+
LuaType::Signature(signature_id) => {
212+
let signature = db.get_signature_index().get(&signature_id)?;
213+
if !signature.param_docs.is_empty() {
214+
return None;
215+
}
216+
let mut final_type = None;
217+
for overload in &signature.overloads {
218+
let mut param_idx = param_idx;
219+
match (origin_colon_define, overload.is_colon_define()) {
220+
(true, false) => {
221+
param_idx += 1;
222+
}
223+
(false, true) => {
224+
if param_idx > 0 {
225+
param_idx -= 1;
226+
}
227+
}
228+
_ => {}
229+
}
230+
231+
if let Some((_, typ)) = overload.get_params().get(param_idx) {
232+
if let Some(typ) = typ {
233+
final_type = match final_type {
234+
Some(existing) => Some(TypeOps::Union.apply(&existing, typ)),
235+
None => Some(typ.clone()),
236+
};
237+
}
238+
}
239+
}
240+
final_type
241+
}
242+
LuaType::DocFunction(f) => {
243+
let mut param_idx = param_idx;
244+
match (origin_colon_define, f.is_colon_define()) {
245+
(true, false) => {
246+
param_idx += 1;
247+
}
248+
(false, true) => {
249+
if param_idx > 0 {
250+
param_idx -= 1;
251+
}
252+
}
253+
_ => {}
254+
}
255+
256+
if let Some((_, typ)) = f.get_params().get(param_idx) {
257+
return typ.clone();
258+
}
259+
None
260+
}
195261
LuaType::Union(union_types) => {
262+
let mut final_type = None;
196263
for ty in union_types.get_types() {
197264
if let Some(ty) =
198-
find_param_type_from_type(db, ty.clone(), param_idx, current_colon_define)
265+
find_param_type_from_union(db, ty.clone(), param_idx, origin_colon_define)
199266
{
200-
return Some(ty);
267+
final_type = match final_type {
268+
Some(existing) => Some(TypeOps::Union.apply(&existing, &ty)),
269+
None => Some(ty),
270+
};
201271
}
202272
}
273+
final_type
203274
}
204-
_ => {}
275+
_ => None,
205276
}
206-
207-
None
208277
}
209278

210279
pub fn infer_global_type(db: &DbIndex, name: &str) -> InferResult {

0 commit comments

Comments
 (0)