Skip to content

Commit bc68004

Browse files
committed
completion: fix function overload
1 parent 10980e8 commit bc68004

File tree

2 files changed

+222
-46
lines changed

2 files changed

+222
-46
lines changed

crates/emmylua_ls/src/handlers/completion/providers/type_special_provider.rs

Lines changed: 91 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use emmylua_code_analysis::{
2-
InferGuard, LuaDeclLocation, LuaFunctionType, LuaMemberKey, LuaMemberOwner, LuaMultiLineUnion,
3-
LuaPropertyOwnerId, LuaType, LuaTypeDeclId, LuaUnionType, RenderLevel,
2+
InferGuard, LuaDeclLocation, LuaFunctionType, LuaMember, LuaMemberKey, LuaMemberOwner,
3+
LuaMultiLineUnion, LuaPropertyOwnerId, LuaType, LuaTypeDeclId, LuaUnionType, RenderLevel,
44
};
55
use emmylua_parser::{
66
LuaAst, LuaAstNode, LuaAstToken, LuaCallArgList, LuaCallExpr, LuaComment, LuaNameToken,
@@ -198,83 +198,130 @@ fn infer_call_arg_list(
198198
let typ = call_expr_func.get_params().get(param_idx)?.1.clone()?;
199199
let mut types = Vec::new();
200200
types.push(typ);
201-
infer_call_arg_list_overload(builder, &call_expr, &call_expr_func, param_idx, &mut types);
201+
push_function_overloads_param(
202+
builder,
203+
&call_expr,
204+
call_expr_func.get_params(),
205+
param_idx,
206+
&mut types,
207+
);
202208
Some(types.into_iter().unique().collect()) // 需要去重
203209
}
204210

205-
fn infer_call_arg_list_overload(
211+
fn push_function_overloads_param(
206212
builder: &mut CompletionBuilder,
207213
call_expr: &LuaCallExpr,
208-
call_expr_func: &LuaFunctionType,
214+
call_params: &[(String, Option<LuaType>)],
209215
param_idx: usize,
210216
types: &mut Vec<LuaType>,
211217
) -> Option<()> {
218+
let member_index = builder.semantic_model.get_db().get_member_index();
212219
let prefix_expr = call_expr.get_prefix_expr()?;
213220
let property_owner_id = builder
214221
.semantic_model
215222
.get_property_owner_id(prefix_expr.syntax().clone().into())?;
216223

217-
let signature_id = match property_owner_id {
224+
// 收集函数类型
225+
let functions = match property_owner_id {
218226
LuaPropertyOwnerId::Member(member_id) => {
219-
let member = builder
220-
.semantic_model
221-
.get_db()
222-
.get_member_index()
223-
.get_member(&member_id)?;
224-
if let LuaType::Signature(signature_id) = member.get_decl_type() {
225-
Some(signature_id)
226-
} else {
227-
None
228-
}
227+
let member = member_index.get_member(&member_id)?;
228+
let key = member.get_key().to_path();
229+
let members = member_index.get_members(&member.get_owner())?;
230+
let functions = filter_function_members(members, key);
231+
Some(functions)
229232
}
230233
LuaPropertyOwnerId::LuaDecl(decl_id) => {
231234
let decl = builder
232235
.semantic_model
233236
.get_db()
234237
.get_decl_index()
235238
.get_decl(&decl_id)?;
236-
if let LuaType::Signature(signature_id) = &decl.get_type()? {
237-
Some(signature_id.clone())
238-
} else {
239-
None
239+
240+
let typ = decl.get_type()?;
241+
match typ {
242+
LuaType::Signature(_) | LuaType::DocFunction(_) => Some(vec![typ.clone()]),
243+
_ => {
244+
let key = decl.get_name();
245+
let type_id = LuaTypeDeclId::new(decl.get_name());
246+
let members = member_index.get_members(&LuaMemberOwner::Type(type_id))?;
247+
let functions = filter_function_members(members, key.to_string());
248+
Some(functions)
249+
}
240250
}
241251
}
242252
_ => None,
243253
}?;
244254

245-
let signature = builder
246-
.semantic_model
247-
.get_db()
248-
.get_signature_index()
249-
.get(&signature_id)?;
250-
251-
let call_params = call_expr_func.get_params();
252-
for overload in signature.overloads.iter() {
253-
let overload_param = overload.get_params();
254-
// 前面的参数必须相同
255-
let mut is_match = true;
256-
if param_idx != 0 {
257-
for (i, param) in call_params.iter().enumerate() {
258-
if i < param_idx {
259-
if let Some(overload_param_type) = overload_param.get(i) {
260-
if param.1 != overload_param_type.1 {
261-
is_match = false;
262-
break;
263-
}
264-
}
265-
} else {
266-
break;
255+
// 获取重载函数列表
256+
let signature_index = builder.semantic_model.get_db().get_signature_index();
257+
let mut overloads = Vec::new();
258+
for function in functions {
259+
match function {
260+
LuaType::Signature(signature_id) => {
261+
if let Some(signature) = signature_index.get(&signature_id) {
262+
overloads.extend(signature.overloads.iter().cloned());
267263
}
268264
}
265+
LuaType::DocFunction(doc_function) => {
266+
overloads.push(doc_function);
267+
}
268+
_ => {}
269269
}
270-
if !is_match {
270+
}
271+
272+
// 筛选匹配的参数类型并添加到结果中
273+
for overload in overloads.iter() {
274+
let overload_params = overload.get_params();
275+
276+
// 检查前面的参数是否匹配
277+
if !params_match_prefix(call_params, overload_params, param_idx) {
271278
continue;
272279
}
273-
if let Some(param_type) = overload.get_params().get(param_idx)?.1.clone() {
280+
281+
// 添加匹配的参数类型
282+
if let Some(param_type) = overload_params.get(param_idx).and_then(|p| p.1.clone()) {
274283
types.push(param_type);
275284
}
276285
}
277286

287+
/// 过滤出函数类型的成员
288+
fn filter_function_members(members: Vec<&LuaMember>, key: String) -> Vec<LuaType> {
289+
members
290+
.into_iter()
291+
.filter(|it| {
292+
it.get_key().to_path() == key
293+
&& matches!(
294+
it.get_decl_type(),
295+
LuaType::Signature(_) | LuaType::DocFunction(_)
296+
)
297+
})
298+
.map(|it| it.get_decl_type())
299+
.collect()
300+
}
301+
302+
/// 判断前面的参数是否匹配
303+
fn params_match_prefix(
304+
call_params: &[(String, Option<LuaType>)],
305+
overload_params: &[(String, Option<LuaType>)],
306+
param_idx: usize,
307+
) -> bool {
308+
if param_idx == 0 {
309+
return true;
310+
}
311+
312+
for i in 0..param_idx {
313+
if let (Some(call_param), Some(overload_param)) =
314+
(call_params.get(i), overload_params.get(i))
315+
{
316+
if call_param.1 != overload_param.1 {
317+
return false;
318+
}
319+
}
320+
}
321+
322+
true
323+
}
324+
278325
Some(())
279326
}
280327

crates/emmylua_ls/src/handlers/completion/test/completion_test.rs

Lines changed: 131 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
#[cfg(test)]
22
mod tests {
33

4-
use lsp_types::CompletionItemKind;
4+
use lsp_types::{CompletionItemKind, CompletionTriggerKind};
55

66
use crate::handlers::completion::test::{CompletionVirtualWorkspace, VirtualCompletionItem};
77

88
#[test]
9-
fn test_basic() {
9+
fn test_1() {
1010
let mut ws = CompletionVirtualWorkspace::new();
1111

1212
assert!(ws.check_completion(
@@ -20,4 +20,133 @@ mod tests {
2020
}],
2121
));
2222
}
23+
24+
#[test]
25+
fn test_2() {
26+
let mut ws = CompletionVirtualWorkspace::new();
27+
assert!(ws.check_completion(
28+
r#"
29+
---@overload fun(event: "AAA", callback: fun(trg: string, data: number)): number
30+
---@overload fun(event: "BBB", callback: fun(trg: string, data: string)): string
31+
local function test(event, callback)
32+
end
33+
34+
test("AAA", function(trg, data)
35+
<??>
36+
end)
37+
"#,
38+
vec![
39+
VirtualCompletionItem {
40+
label: "data".to_string(),
41+
kind: CompletionItemKind::VARIABLE,
42+
},
43+
VirtualCompletionItem {
44+
label: "trg".to_string(),
45+
kind: CompletionItemKind::VARIABLE,
46+
},
47+
VirtualCompletionItem {
48+
label: "test".to_string(),
49+
kind: CompletionItemKind::FUNCTION,
50+
},
51+
],
52+
));
53+
54+
// 主动触发补全
55+
assert!(ws.check_completion(
56+
r#"
57+
---@overload fun(event: "AAA", callback: fun(trg: string, data: number)): number
58+
---@overload fun(event: "BBB", callback: fun(trg: string, data: string)): string
59+
local function test(event, callback)
60+
end
61+
test(<??>)
62+
"#,
63+
vec![
64+
VirtualCompletionItem {
65+
label: "\"AAA\"".to_string(),
66+
kind: CompletionItemKind::ENUM_MEMBER,
67+
},
68+
VirtualCompletionItem {
69+
label: "\"BBB\"".to_string(),
70+
kind: CompletionItemKind::ENUM_MEMBER,
71+
},
72+
VirtualCompletionItem {
73+
label: "test".to_string(),
74+
kind: CompletionItemKind::FUNCTION,
75+
},
76+
],
77+
));
78+
79+
// 被动触发补全
80+
assert!(ws.check_completion_with_kind(
81+
r#"
82+
---@overload fun(event: "AAA", callback: fun(trg: string, data: number)): number
83+
---@overload fun(event: "BBB", callback: fun(trg: string, data: string)): string
84+
local function test(event, callback)
85+
end
86+
test(<??>)
87+
"#,
88+
vec![
89+
VirtualCompletionItem {
90+
label: "\"AAA\"".to_string(),
91+
kind: CompletionItemKind::ENUM_MEMBER,
92+
},
93+
VirtualCompletionItem {
94+
label: "\"BBB\"".to_string(),
95+
kind: CompletionItemKind::ENUM_MEMBER,
96+
},
97+
],
98+
CompletionTriggerKind::TRIGGER_CHARACTER,
99+
));
100+
}
101+
102+
#[test]
103+
fn test_3() {
104+
let mut ws = CompletionVirtualWorkspace::new();
105+
// 被动触发补全
106+
assert!(ws.check_completion_with_kind(
107+
r#"
108+
---@class Completion.Test4
109+
---@field event fun(a: "A", b: number)
110+
---@field event fun(a: "B", b: string)
111+
local Test4 = {}
112+
Test4.event(<??>)
113+
"#,
114+
vec![
115+
VirtualCompletionItem {
116+
label: "\"A\"".to_string(),
117+
kind: CompletionItemKind::ENUM_MEMBER,
118+
},
119+
VirtualCompletionItem {
120+
label: "\"B\"".to_string(),
121+
kind: CompletionItemKind::ENUM_MEMBER,
122+
},
123+
],
124+
CompletionTriggerKind::TRIGGER_CHARACTER,
125+
));
126+
127+
// 主动触发补全
128+
assert!(ws.check_completion(
129+
r#"
130+
---@class Completion.Test4
131+
---@field event fun(a: "A", b: number)
132+
---@field event fun(a: "B", b: string)
133+
local Test4 = {}
134+
Test4.event(<??>)
135+
"#,
136+
vec![
137+
VirtualCompletionItem {
138+
label: "\"A\"".to_string(),
139+
kind: CompletionItemKind::ENUM_MEMBER,
140+
},
141+
VirtualCompletionItem {
142+
label: "\"B\"".to_string(),
143+
kind: CompletionItemKind::ENUM_MEMBER,
144+
},
145+
VirtualCompletionItem {
146+
label: "Test4".to_string(),
147+
kind: CompletionItemKind::CLASS,
148+
},
149+
],
150+
));
151+
}
23152
}

0 commit comments

Comments
 (0)