Skip to content

Commit 25f61a4

Browse files
committed
fix goto_function
1 parent 7c56f6d commit 25f61a4

File tree

2 files changed

+171
-80
lines changed

2 files changed

+171
-80
lines changed

crates/emmylua_ls/src/handlers/definition/goto_function.rs

Lines changed: 128 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
use emmylua_code_analysis::{
2-
LuaCompilation, LuaDeclId, LuaFunctionType, LuaMemberId, LuaSemanticDeclId, LuaSignature,
3-
LuaSignatureId, LuaType, SemanticDeclLevel, SemanticModel, instantiate_func_generic,
2+
LuaCompilation, LuaDeclId, LuaFunctionType, LuaSemanticDeclId, LuaSignature, LuaSignatureId,
3+
LuaType, SemanticDeclLevel, SemanticModel, instantiate_func_generic,
4+
};
5+
use emmylua_parser::{
6+
LuaAstNode, LuaCallExpr, LuaExpr, LuaLiteralToken, LuaSyntaxToken, LuaTokenKind,
47
};
5-
use emmylua_parser::{LuaAstNode, LuaCallExpr, LuaSyntaxToken, LuaTokenKind};
68
use rowan::{NodeOrToken, TokenAtOffset};
79
use std::sync::Arc;
810

@@ -27,17 +29,44 @@ pub fn find_matching_function_definitions(
2729
let mut has_match = false;
2830

2931
for (decl, member_id) in member_decls {
30-
if process_member_function_type(
31-
semantic_model,
32-
compilation,
33-
&call_function,
34-
&call_expr,
35-
decl,
36-
member_id,
37-
&mut result,
38-
&mut has_match,
39-
) {
40-
continue;
32+
let typ = semantic_model.get_type(member_id.clone().into());
33+
match typ {
34+
LuaType::DocFunction(func) => {
35+
if compare_function_types(semantic_model, &call_function, &func, &call_expr)
36+
.unwrap_or(false)
37+
{
38+
result.push(decl.clone());
39+
has_match = true;
40+
}
41+
}
42+
LuaType::Signature(signature_id) => {
43+
let signature = match semantic_model
44+
.get_db()
45+
.get_signature_index()
46+
.get(&signature_id)
47+
{
48+
Some(sig) => sig,
49+
None => continue,
50+
};
51+
let functions = get_signature_functions(signature);
52+
53+
if functions.iter().any(|func| {
54+
compare_function_types(semantic_model, &call_function, func, &call_expr)
55+
.unwrap_or(false)
56+
}) {
57+
has_match = true;
58+
}
59+
60+
// 无论是否匹配, 都需要将真实的定义添加到结果中
61+
// 如果存在原始定义, 则优先使用原始定义
62+
let origin = extract_semantic_decl_from_signature(compilation, &signature_id);
63+
if let Some(origin) = origin {
64+
result.insert(0, origin);
65+
} else {
66+
result.insert(0, decl.clone());
67+
}
68+
}
69+
_ => continue,
4170
}
4271
}
4372

@@ -48,58 +77,6 @@ pub fn find_matching_function_definitions(
4877
}
4978
}
5079

51-
fn process_member_function_type(
52-
semantic_model: &SemanticModel,
53-
compilation: &LuaCompilation,
54-
call_function: &Arc<LuaFunctionType>,
55-
call_expr: &LuaCallExpr,
56-
decl: &LuaSemanticDeclId,
57-
member_id: &LuaMemberId,
58-
result: &mut Vec<LuaSemanticDeclId>,
59-
has_match: &mut bool,
60-
) -> bool {
61-
let typ = semantic_model.get_type(member_id.clone().into());
62-
match typ {
63-
LuaType::DocFunction(func) => {
64-
if compare_function_types(semantic_model, call_function, &func, call_expr)
65-
.unwrap_or(false)
66-
{
67-
result.push(decl.clone());
68-
*has_match = true;
69-
}
70-
}
71-
LuaType::Signature(signature_id) => {
72-
let signature = match semantic_model
73-
.get_db()
74-
.get_signature_index()
75-
.get(&signature_id)
76-
{
77-
Some(sig) => sig,
78-
None => return false,
79-
};
80-
let functions = get_signature_functions(signature);
81-
82-
if functions.iter().any(|func| {
83-
compare_function_types(semantic_model, call_function, func, call_expr)
84-
.unwrap_or(false)
85-
}) {
86-
*has_match = true;
87-
}
88-
89-
// 无论是否匹配,都需要将真实的定义添加到结果中
90-
// 如果存在原始定义,则优先使用原始定义
91-
let origin = extract_semantic_decl_from_signature(compilation, &signature_id);
92-
if let Some(origin) = origin {
93-
result.insert(0, origin);
94-
} else {
95-
result.insert(0, decl.clone());
96-
}
97-
}
98-
_ => return false,
99-
}
100-
true
101-
}
102-
10380
pub fn find_function_call_origin(
10481
semantic_model: &SemanticModel,
10582
compilation: &LuaCompilation,
@@ -205,28 +182,99 @@ fn get_call_function(
205182
) -> Option<Arc<LuaFunctionType>> {
206183
let func = semantic_model.infer_call_expr_func(call_expr.clone(), None);
207184
if let Some(func) = func {
208-
let call_expr_args_count = call_expr.get_args_count();
209-
if let Some(mut call_expr_args_count) = call_expr_args_count {
210-
let mut func_params_count = func.get_params().len();
211-
if !func.is_colon_define() && call_expr.is_colon_call() {
212-
// 不是冒号定义的函数, 但是是冒号调用
213-
call_expr_args_count += 1;
185+
if check_params_count_is_match(semantic_model, &func, call_expr.clone()).unwrap_or(false) {
186+
return Some(func);
187+
}
188+
}
189+
None
190+
}
191+
192+
fn check_params_count_is_match(
193+
semantic_model: &SemanticModel,
194+
call_function: &LuaFunctionType,
195+
call_expr: LuaCallExpr,
196+
) -> Option<bool> {
197+
let mut fake_params = call_function.get_params().to_vec();
198+
let call_args = call_expr.get_args_list()?.get_args().collect::<Vec<_>>();
199+
let mut call_args_count = call_args.len();
200+
let colon_call = call_expr.is_colon_call();
201+
let colon_define = call_function.is_colon_define();
202+
match (colon_call, colon_define) {
203+
(true, true) | (false, false) => {}
204+
(false, true) => {
205+
fake_params.insert(0, ("self".to_string(), Some(LuaType::SelfInfer)));
206+
}
207+
(true, false) => {
208+
call_args_count += 1;
209+
}
210+
}
211+
if call_args_count < fake_params.len() {
212+
// 调用参数包含 `...`
213+
for arg in call_args.iter() {
214+
if let LuaExpr::LiteralExpr(literal_expr) = arg {
215+
if let Some(literal_token) = literal_expr.get_literal() {
216+
if let LuaLiteralToken::Dots(_) = literal_token {
217+
return Some(true);
218+
}
219+
}
214220
}
215-
// 如果参数有可空参数, 则需要减去
216-
for (_, param_type) in func.get_params().iter() {
217-
if let Some(param_type) = param_type {
218-
if param_type.is_optional() {
219-
func_params_count -= 1;
221+
}
222+
// 对调用参数的最后一个参数进行特殊处理
223+
if let Some(last_arg) = call_args.last() {
224+
if let Ok(LuaType::Variadic(variadic)) = semantic_model.infer_expr(last_arg.clone()) {
225+
let len = match variadic.get_max_len() {
226+
Some(len) => len,
227+
None => {
228+
return Some(true);
220229
}
230+
};
231+
call_args_count = call_args_count + len as usize - 1;
232+
if call_args_count >= fake_params.len() {
233+
return Some(true);
234+
}
235+
}
236+
}
237+
238+
for i in call_args_count..fake_params.len() {
239+
let param_info = fake_params.get(i)?;
240+
if param_info.0 == "..." {
241+
return Some(true);
242+
}
243+
244+
let typ = param_info.1.clone();
245+
if let Some(typ) = typ {
246+
if !typ.is_optional() {
247+
return Some(false);
221248
}
222249
}
250+
}
251+
} else if call_args_count > fake_params.len() {
252+
// 参数定义中最后一个参数是 `...`
253+
if fake_params.last().map_or(false, |(name, typ)| {
254+
name == "..."
255+
|| if let Some(typ) = typ {
256+
typ.is_variadic()
257+
} else {
258+
false
259+
}
260+
}) {
261+
return Some(true);
262+
}
223263

224-
if call_expr_args_count == func_params_count {
225-
return Some(func);
264+
let mut adjusted_index = 0;
265+
if colon_call != colon_define {
266+
adjusted_index = if colon_define && !colon_call { -1 } else { 1 };
267+
}
268+
269+
for (i, _) in call_args.iter().enumerate() {
270+
let param_index = i as isize + adjusted_index;
271+
if param_index < 0 || param_index < fake_params.len() as isize {
272+
continue;
226273
}
274+
return Some(false);
227275
}
228276
}
229-
None
277+
Some(true)
230278
}
231279

232280
fn get_signature_functions(signature: &LuaSignature) -> Vec<Arc<LuaFunctionType>> {

crates/emmylua_ls/src/handlers/test/definition_test.rs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,4 +430,47 @@ mod tests {
430430

431431
Ok(())
432432
}
433+
434+
#[test]
435+
fn test_goto_variable_param() {
436+
let mut ws = ProviderVirtualWorkspace::new();
437+
ws.def_file(
438+
"a.lua",
439+
r#"
440+
---@class Observable<T>
441+
442+
---test
443+
local function zipLatest(...)
444+
end
445+
return zipLatest
446+
"#,
447+
);
448+
ws.def_file(
449+
"b.lua",
450+
r#"
451+
local export = {}
452+
local zipLatest = require('a')
453+
export.zipLatest = zipLatest
454+
return export
455+
"#,
456+
);
457+
let result = ws
458+
.check_definition(
459+
r#"
460+
local zipLatest = require('b').zipLatest
461+
zipLatest<??>()
462+
"#,
463+
)
464+
.unwrap();
465+
match result {
466+
GotoDefinitionResponse::Array(array) => {
467+
assert_eq!(array.len(), 1);
468+
let location = &array[0];
469+
assert_eq!(location.uri.path().as_str().ends_with("a.lua"), true);
470+
}
471+
_ => {
472+
panic!("expect array");
473+
}
474+
}
475+
}
433476
}

0 commit comments

Comments
 (0)