Skip to content

Commit 70e21b7

Browse files
committed
update definition: goto function
1 parent 05a4bfc commit 70e21b7

File tree

3 files changed

+238
-30
lines changed

3 files changed

+238
-30
lines changed

crates/emmylua_ls/src/handlers/definition/goto_def_definition.rs

Lines changed: 59 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@ use itertools::Itertools;
1212
use lsp_types::{GotoDefinitionResponse, Location, Position, Range, Uri};
1313

1414
use crate::handlers::{
15-
definition::goto_function::find_call_match_function,
15+
definition::goto_function::{
16+
find_call_expr_origin_for_decl, find_call_match_function_for_member,
17+
},
1618
hover::{find_all_same_named_members, find_member_origin_owner},
1719
};
1820

@@ -35,6 +37,23 @@ pub fn goto_def_definition(
3537
}
3638
match property_owner {
3739
LuaSemanticDeclId::LuaDecl(decl_id) => {
40+
if let Some(match_semantic_decl) = find_call_expr_origin_for_decl(
41+
semantic_model,
42+
compilation,
43+
trigger_token,
44+
&property_owner,
45+
) {
46+
match match_semantic_decl {
47+
LuaSemanticDeclId::LuaDecl(decl_id) => {
48+
return Some(GotoDefinitionResponse::Scalar(get_decl_location(
49+
semantic_model,
50+
&decl_id,
51+
)?));
52+
}
53+
_ => {}
54+
};
55+
}
56+
3857
let location = get_decl_location(semantic_model, &decl_id)?;
3958
return Some(GotoDefinitionResponse::Scalar(location));
4059
}
@@ -46,36 +65,52 @@ pub fn goto_def_definition(
4665

4766
let mut locations: Vec<Location> = Vec::new();
4867
// 如果是函数调用, 则尝试寻找最匹配的定义
49-
if let Some(match_members) =
50-
find_call_match_function(semantic_model, trigger_token, &same_named_members)
51-
{
68+
if let Some(match_members) = find_call_match_function_for_member(
69+
semantic_model,
70+
compilation,
71+
trigger_token,
72+
&same_named_members,
73+
) {
5274
for member in match_members {
53-
if let LuaSemanticDeclId::Member(member_id) = member {
54-
if let Some(true) = should_trace_member(semantic_model, &member_id) {
55-
// 尝试搜索这个成员最原始的定义
56-
match find_member_origin_owner(compilation, semantic_model, member_id) {
57-
Some(LuaSemanticDeclId::Member(member_id)) => {
58-
if let Some(location) =
59-
get_member_location(semantic_model, &member_id)
60-
{
61-
locations.push(location);
62-
continue;
75+
match member {
76+
LuaSemanticDeclId::Member(member_id) => {
77+
if should_trace_member(semantic_model, &member_id).unwrap_or(false) {
78+
// 尝试搜索这个成员最原始的定义
79+
match find_member_origin_owner(
80+
compilation,
81+
semantic_model,
82+
member_id,
83+
) {
84+
Some(LuaSemanticDeclId::Member(member_id)) => {
85+
if let Some(location) =
86+
get_member_location(semantic_model, &member_id)
87+
{
88+
locations.push(location);
89+
continue;
90+
}
6391
}
64-
}
65-
Some(LuaSemanticDeclId::LuaDecl(decl_id)) => {
66-
if let Some(location) =
67-
get_decl_location(semantic_model, &decl_id)
68-
{
69-
locations.push(location);
70-
continue;
92+
Some(LuaSemanticDeclId::LuaDecl(decl_id)) => {
93+
if let Some(location) =
94+
get_decl_location(semantic_model, &decl_id)
95+
{
96+
locations.push(location);
97+
continue;
98+
}
7199
}
100+
_ => {}
72101
}
73-
_ => {}
102+
}
103+
if let Some(location) = get_member_location(semantic_model, &member_id)
104+
{
105+
locations.push(location);
74106
}
75107
}
76-
if let Some(location) = get_member_location(semantic_model, &member_id) {
77-
locations.push(location);
108+
LuaSemanticDeclId::LuaDecl(decl_id) => {
109+
if let Some(location) = get_decl_location(semantic_model, &decl_id) {
110+
locations.push(location);
111+
}
78112
}
113+
_ => {}
79114
}
80115
}
81116
if !locations.is_empty() {

crates/emmylua_ls/src/handlers/definition/goto_function.rs

Lines changed: 104 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
use emmylua_code_analysis::{
2-
instantiate_func_generic, LuaFunctionType, LuaSemanticDeclId, LuaSignature, LuaType,
3-
SemanticModel,
2+
instantiate_func_generic, LuaCompilation, LuaFunctionType, LuaSemanticDeclId, LuaSignature,
3+
LuaSignatureId, LuaType, SemanticDeclLevel, SemanticModel,
44
};
5-
use emmylua_parser::{LuaAstNode, LuaCallExpr, LuaSyntaxToken};
5+
use emmylua_parser::{LuaAstNode, LuaCallExpr, LuaSyntaxToken, LuaTokenKind};
6+
use rowan::{NodeOrToken, TokenAtOffset};
67
use std::sync::Arc;
78

8-
pub fn find_call_match_function(
9+
pub fn find_call_match_function_for_member(
910
semantic_model: &SemanticModel,
11+
compilation: &LuaCompilation,
1012
trigger_token: &LuaSyntaxToken,
1113
semantic_decls: &Vec<LuaSemanticDeclId>,
1214
) -> Option<Vec<LuaSemanticDeclId>> {
@@ -44,8 +46,14 @@ pub fn find_call_match_function(
4446
}) {
4547
has_match = true;
4648
}
47-
// 此处为降低优先级, 因为如果返回多个选项, 那么 vscode 会默认指向最后的选项
48-
result.insert(0, decl.clone());
49+
// 无论是否匹配, 都需要将真实的定义添加到结果中
50+
// 如果存在原始定义, 则优先使用原始定义
51+
let origin = get_signature_origin(compilation, &signature_id);
52+
if let Some(origin) = origin {
53+
result.insert(0, origin);
54+
} else {
55+
result.insert(0, decl.clone());
56+
}
4957
}
5058
_ => continue,
5159
}
@@ -61,6 +69,96 @@ pub fn find_call_match_function(
6169
}
6270
}
6371

72+
fn get_signature_origin(
73+
compilation: &LuaCompilation,
74+
signature_id: &LuaSignatureId,
75+
) -> Option<LuaSemanticDeclId> {
76+
let semantic_model = compilation.get_semantic_model(signature_id.get_file_id())?;
77+
let root = semantic_model.get_root_by_file_id(signature_id.get_file_id())?;
78+
let token = match root.syntax().token_at_offset(signature_id.get_position()) {
79+
TokenAtOffset::Single(token) => token,
80+
TokenAtOffset::Between(left, right) => {
81+
if left.kind() == LuaTokenKind::TkName.into() {
82+
left
83+
} else if left.kind() == LuaTokenKind::TkLeftBracket.into()
84+
&& right.kind() == LuaTokenKind::TkInt.into()
85+
{
86+
left
87+
} else {
88+
right
89+
}
90+
}
91+
TokenAtOffset::None => {
92+
return None;
93+
}
94+
};
95+
let semantic_info =
96+
semantic_model.find_decl(NodeOrToken::Token(token), SemanticDeclLevel::default());
97+
semantic_info
98+
}
99+
100+
pub fn find_call_expr_origin_for_decl(
101+
semantic_model: &SemanticModel,
102+
compilation: &LuaCompilation,
103+
trigger_token: &LuaSyntaxToken,
104+
semantic_decl: &LuaSemanticDeclId,
105+
) -> Option<LuaSemanticDeclId> {
106+
let call_expr = LuaCallExpr::cast(trigger_token.parent()?.parent()?)?;
107+
let call_function = get_call_function(semantic_model, &call_expr)?;
108+
let decl_id = match semantic_decl {
109+
LuaSemanticDeclId::LuaDecl(decl_id) => decl_id,
110+
_ => return None,
111+
};
112+
113+
let typ = semantic_model.get_type(decl_id.clone().into());
114+
match typ {
115+
LuaType::DocFunction(func) => {
116+
if compare_function_types(semantic_model, &call_function, &func, &call_expr)? {
117+
return Some(decl_id.clone().into());
118+
}
119+
}
120+
LuaType::Signature(signature_id) => {
121+
let signature = semantic_model
122+
.get_db()
123+
.get_signature_index()
124+
.get(&signature_id)?;
125+
let functions = get_signature_functions(signature);
126+
if functions.iter().any(|func| {
127+
compare_function_types(semantic_model, &call_function, func, &call_expr)
128+
.unwrap_or(false)
129+
}) {
130+
let semantic_model = compilation.get_semantic_model(signature_id.get_file_id())?;
131+
let root = semantic_model.get_root_by_file_id(signature_id.get_file_id())?;
132+
let token = match root.syntax().token_at_offset(signature_id.get_position()) {
133+
TokenAtOffset::Single(token) => token,
134+
TokenAtOffset::Between(left, right) => {
135+
if left.kind() == LuaTokenKind::TkName.into() {
136+
left
137+
} else if left.kind() == LuaTokenKind::TkLeftBracket.into()
138+
&& right.kind() == LuaTokenKind::TkInt.into()
139+
{
140+
left
141+
} else {
142+
right
143+
}
144+
}
145+
TokenAtOffset::None => {
146+
return None;
147+
}
148+
};
149+
let semantic_info = semantic_model
150+
.find_decl(NodeOrToken::Token(token), SemanticDeclLevel::default());
151+
if let Some(semantic_info) = semantic_info {
152+
return Some(semantic_info);
153+
}
154+
}
155+
}
156+
_ => return None,
157+
};
158+
159+
None
160+
}
161+
64162
/// 获取最匹配的函数(并不能确保完全匹配)
65163
fn get_call_function(
66164
semantic_model: &SemanticModel,

crates/emmylua_ls/src/handlers/test/definition_test.rs

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,4 +251,79 @@ mod tests {
251251
}
252252
}
253253
}
254+
255+
#[test]
256+
fn test_goto_export_function() {
257+
let mut ws = ProviderVirtualWorkspace::new();
258+
ws.def_file(
259+
"a.lua",
260+
r#"
261+
local function create()
262+
end
263+
264+
return create
265+
"#,
266+
);
267+
let result = ws
268+
.check_definition(
269+
r#"
270+
local create = require('a')
271+
create<??>()
272+
"#,
273+
)
274+
.unwrap();
275+
match result {
276+
GotoDefinitionResponse::Scalar(location) => {
277+
assert_eq!(location.uri.path().as_str().ends_with("a.lua"), true);
278+
}
279+
_ => {
280+
panic!("expect array");
281+
}
282+
}
283+
}
284+
285+
#[test]
286+
fn test_goto_export_function_2() {
287+
let mut ws = ProviderVirtualWorkspace::new();
288+
ws.def_file(
289+
"a.lua",
290+
r#"
291+
local function testA()
292+
end
293+
294+
local function create()
295+
end
296+
297+
return create
298+
"#,
299+
);
300+
ws.def_file(
301+
"b.lua",
302+
r#"
303+
local Rxlua = {}
304+
local create = require('a')
305+
306+
Rxlua.create = create
307+
return Rxlua
308+
"#,
309+
);
310+
let result = ws
311+
.check_definition(
312+
r#"
313+
local create = require('b').create
314+
create<??>()
315+
"#,
316+
)
317+
.unwrap();
318+
match result {
319+
GotoDefinitionResponse::Array(array) => {
320+
assert_eq!(array.len(), 1);
321+
let location = &array[0];
322+
assert_eq!(location.uri.path().as_str().ends_with("a.lua"), true);
323+
}
324+
_ => {
325+
panic!("expect array");
326+
}
327+
}
328+
}
254329
}

0 commit comments

Comments
 (0)