Skip to content

Commit 654f651

Browse files
committed
fix infer_call_expr_func && fix resolve_signature
1 parent e9c2464 commit 654f651

File tree

3 files changed

+160
-6
lines changed

3 files changed

+160
-6
lines changed

crates/emmylua_code_analysis/src/diagnostic/test/param_type_check_test.rs

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -776,4 +776,78 @@ mod test {
776776
"#
777777
));
778778
}
779+
780+
#[test]
781+
fn test_function_union() {
782+
let mut ws = VirtualWorkspace::new();
783+
assert!(!ws.check_code_for(
784+
DiagnosticCode::ParamTypeNotMatch,
785+
r#"
786+
---@class (partial) D21.A
787+
local M
788+
789+
---@alias EventType
790+
---| GlobalEventType
791+
---| UIEventType
792+
793+
---@enum UIEventType
794+
local UIEventType = {
795+
['UI_CREATE'] = "ET_UI_PREFAB_CREATE_EVENT",
796+
}
797+
---@enum GlobalEventType
798+
local GlobalEventType = {
799+
['GAME_INIT'] = "ET_GAME_INIT",
800+
}
801+
802+
---@param event_type EventType
803+
function M:event(event_type)
804+
end
805+
806+
---@class (partial) D21.A
807+
---@field event fun(self: self, event: "游戏-初始化")
808+
809+
---@param p string
810+
local function test(p)
811+
M:event(p)
812+
end
813+
"#
814+
));
815+
}
816+
817+
#[test]
818+
fn test_function_union_2() {
819+
let mut ws = VirtualWorkspace::new();
820+
assert!(ws.check_code_for(
821+
DiagnosticCode::ParamTypeNotMatch,
822+
r#"
823+
---@class (partial) D21.A
824+
local M
825+
826+
---@alias EventType
827+
---| GlobalEventType
828+
---| UIEventType
829+
830+
---@enum UIEventType
831+
local UIEventType = {
832+
['UI_CREATE'] = "ET_UI_PREFAB_CREATE_EVENT",
833+
}
834+
---@enum GlobalEventType
835+
local GlobalEventType = {
836+
['GAME_INIT'] = "ET_GAME_INIT",
837+
}
838+
839+
---@param event_type EventType
840+
function M:event(event_type)
841+
end
842+
843+
---@class (partial) D21.A
844+
---@field event fun(self: self, event: "游戏-初始化")
845+
846+
---@param p EventType
847+
local function test(p)
848+
M:event(p)
849+
end
850+
"#
851+
));
852+
}
779853
}

crates/emmylua_code_analysis/src/semantic/infer/infer_call_func.rs

Lines changed: 81 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,15 +76,14 @@ pub fn infer_call_expr_func(
7676
LuaType::TableConst(meta_table) => infer_table_type_doc_function(db, meta_table),
7777
LuaType::Union(union) => {
7878
// 此时我们将其视为泛型实例化联合体
79-
if union.get_types().len() > 1
80-
&& union
81-
.get_types()
82-
.iter()
83-
.all(|t| matches!(t, LuaType::DocFunction(_)))
79+
if union
80+
.get_types()
81+
.iter()
82+
.all(|t| matches!(t, LuaType::DocFunction(_)))
8483
{
8584
infer_generic_doc_function_union(db, cache, &union, call_expr, args_count)
8685
} else {
87-
Err(InferFailReason::None)
86+
infer_union(db, cache, &union, call_expr, args_count)
8887
}
8988
}
9089
_ => return Err(InferFailReason::None),
@@ -384,3 +383,79 @@ fn infer_table_type_doc_function(db: &DbIndex, table: InFiled<TextRange>) -> Inf
384383

385384
Err(InferFailReason::None)
386385
}
386+
387+
fn infer_union(
388+
db: &DbIndex,
389+
cache: &mut LuaInferCache,
390+
union: &LuaUnionType,
391+
call_expr: LuaCallExpr,
392+
args_count: Option<usize>,
393+
) -> InferCallFuncResult {
394+
// 此时一般是 signature + doc_function 的联合体
395+
let mut all_overloads = Vec::new();
396+
let mut base_signatures = Vec::new();
397+
398+
for ty in union.get_types() {
399+
match ty {
400+
LuaType::Signature(signature_id) => {
401+
if let Some(signature) = db.get_signature_index().get(signature_id) {
402+
// 处理 overloads
403+
let overloads = if signature.is_generic() {
404+
signature
405+
.overloads
406+
.iter()
407+
.map(|func| {
408+
Ok(Arc::new(instantiate_func_generic(
409+
db,
410+
cache,
411+
func,
412+
call_expr.clone(),
413+
)?))
414+
})
415+
.collect::<Result<Vec<_>, _>>()?
416+
} else {
417+
signature.overloads.clone()
418+
};
419+
all_overloads.extend(overloads);
420+
421+
// 处理 signature 本身的函数类型
422+
let mut fake_doc_function = LuaFunctionType::new(
423+
signature.is_async,
424+
signature.is_colon_define,
425+
signature.get_type_params(),
426+
signature.get_return_types(),
427+
);
428+
if signature.is_generic() {
429+
fake_doc_function = instantiate_func_generic(
430+
db,
431+
cache,
432+
&fake_doc_function,
433+
call_expr.clone(),
434+
)?;
435+
}
436+
base_signatures.push(Arc::new(fake_doc_function));
437+
}
438+
}
439+
LuaType::DocFunction(func) => {
440+
let func_to_push = if func.contain_tpl() {
441+
Arc::new(instantiate_func_generic(
442+
db,
443+
cache,
444+
func,
445+
call_expr.clone(),
446+
)?)
447+
} else {
448+
func.clone()
449+
};
450+
base_signatures.push(func_to_push);
451+
}
452+
_ => {}
453+
}
454+
}
455+
456+
all_overloads.extend(base_signatures);
457+
if all_overloads.is_empty() {
458+
return Err(InferFailReason::None);
459+
}
460+
resolve_signature(db, cache, all_overloads, call_expr, false, args_count)
461+
}

crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,11 @@ fn resolve_signature_by_args(
8686
jump_param = 0;
8787
};
8888

89+
// 冒号定义且冒号调用
90+
if is_colon_call && func.is_colon_define() {
91+
total_weight += 100;
92+
}
93+
8994
// 检查每个参数的匹配情况
9095
for (i, param) in params.iter().enumerate() {
9196
if i == 0 && jump_param > 0 {

0 commit comments

Comments
 (0)