11use emmylua_code_analysis:: {
2- LuaCompilation , LuaDeclId , LuaFunctionType , LuaMemberId , LuaSemanticDeclId , LuaSignature ,
3- LuaSignatureId , LuaType , SemanticDeclLevel , SemanticModel , instantiate_func_generic,
2+ LuaCompilation , LuaDeclId , LuaFunctionType , LuaSemanticDeclId , LuaSignature , LuaSignatureId ,
3+ LuaType , SemanticDeclLevel , SemanticModel , instantiate_func_generic,
4+ } ;
5+ use emmylua_parser:: {
6+ LuaAstNode , LuaCallExpr , LuaExpr , LuaLiteralToken , LuaSyntaxToken , LuaTokenKind ,
47} ;
5- use emmylua_parser:: { LuaAstNode , LuaCallExpr , LuaSyntaxToken , LuaTokenKind } ;
68use rowan:: { NodeOrToken , TokenAtOffset } ;
79use std:: sync:: Arc ;
810
@@ -27,17 +29,44 @@ pub fn find_matching_function_definitions(
2729 let mut has_match = false ;
2830
2931 for ( decl, member_id) in member_decls {
30- if process_member_function_type (
31- semantic_model,
32- compilation,
33- & call_function,
34- & call_expr,
35- decl,
36- member_id,
37- & mut result,
38- & mut has_match,
39- ) {
40- continue ;
32+ let typ = semantic_model. get_type ( member_id. clone ( ) . into ( ) ) ;
33+ match typ {
34+ LuaType :: DocFunction ( func) => {
35+ if compare_function_types ( semantic_model, & call_function, & func, & call_expr)
36+ . unwrap_or ( false )
37+ {
38+ result. push ( decl. clone ( ) ) ;
39+ has_match = true ;
40+ }
41+ }
42+ LuaType :: Signature ( signature_id) => {
43+ let signature = match semantic_model
44+ . get_db ( )
45+ . get_signature_index ( )
46+ . get ( & signature_id)
47+ {
48+ Some ( sig) => sig,
49+ None => continue ,
50+ } ;
51+ let functions = get_signature_functions ( signature) ;
52+
53+ if functions. iter ( ) . any ( |func| {
54+ compare_function_types ( semantic_model, & call_function, func, & call_expr)
55+ . unwrap_or ( false )
56+ } ) {
57+ has_match = true ;
58+ }
59+
60+ // 无论是否匹配, 都需要将真实的定义添加到结果中
61+ // 如果存在原始定义, 则优先使用原始定义
62+ let origin = extract_semantic_decl_from_signature ( compilation, & signature_id) ;
63+ if let Some ( origin) = origin {
64+ result. insert ( 0 , origin) ;
65+ } else {
66+ result. insert ( 0 , decl. clone ( ) ) ;
67+ }
68+ }
69+ _ => continue ,
4170 }
4271 }
4372
@@ -48,58 +77,6 @@ pub fn find_matching_function_definitions(
4877 }
4978}
5079
51- fn process_member_function_type (
52- semantic_model : & SemanticModel ,
53- compilation : & LuaCompilation ,
54- call_function : & Arc < LuaFunctionType > ,
55- call_expr : & LuaCallExpr ,
56- decl : & LuaSemanticDeclId ,
57- member_id : & LuaMemberId ,
58- result : & mut Vec < LuaSemanticDeclId > ,
59- has_match : & mut bool ,
60- ) -> bool {
61- let typ = semantic_model. get_type ( member_id. clone ( ) . into ( ) ) ;
62- match typ {
63- LuaType :: DocFunction ( func) => {
64- if compare_function_types ( semantic_model, call_function, & func, call_expr)
65- . unwrap_or ( false )
66- {
67- result. push ( decl. clone ( ) ) ;
68- * has_match = true ;
69- }
70- }
71- LuaType :: Signature ( signature_id) => {
72- let signature = match semantic_model
73- . get_db ( )
74- . get_signature_index ( )
75- . get ( & signature_id)
76- {
77- Some ( sig) => sig,
78- None => return false ,
79- } ;
80- let functions = get_signature_functions ( signature) ;
81-
82- if functions. iter ( ) . any ( |func| {
83- compare_function_types ( semantic_model, call_function, func, call_expr)
84- . unwrap_or ( false )
85- } ) {
86- * has_match = true ;
87- }
88-
89- // 无论是否匹配,都需要将真实的定义添加到结果中
90- // 如果存在原始定义,则优先使用原始定义
91- let origin = extract_semantic_decl_from_signature ( compilation, & signature_id) ;
92- if let Some ( origin) = origin {
93- result. insert ( 0 , origin) ;
94- } else {
95- result. insert ( 0 , decl. clone ( ) ) ;
96- }
97- }
98- _ => return false ,
99- }
100- true
101- }
102-
10380pub fn find_function_call_origin (
10481 semantic_model : & SemanticModel ,
10582 compilation : & LuaCompilation ,
@@ -205,28 +182,99 @@ fn get_call_function(
205182) -> Option < Arc < LuaFunctionType > > {
206183 let func = semantic_model. infer_call_expr_func ( call_expr. clone ( ) , None ) ;
207184 if let Some ( func) = func {
208- let call_expr_args_count = call_expr. get_args_count ( ) ;
209- if let Some ( mut call_expr_args_count) = call_expr_args_count {
210- let mut func_params_count = func. get_params ( ) . len ( ) ;
211- if !func. is_colon_define ( ) && call_expr. is_colon_call ( ) {
212- // 不是冒号定义的函数, 但是是冒号调用
213- call_expr_args_count += 1 ;
185+ if check_params_count_is_match ( semantic_model, & func, call_expr. clone ( ) ) . unwrap_or ( false ) {
186+ return Some ( func) ;
187+ }
188+ }
189+ None
190+ }
191+
192+ fn check_params_count_is_match (
193+ semantic_model : & SemanticModel ,
194+ call_function : & LuaFunctionType ,
195+ call_expr : LuaCallExpr ,
196+ ) -> Option < bool > {
197+ let mut fake_params = call_function. get_params ( ) . to_vec ( ) ;
198+ let call_args = call_expr. get_args_list ( ) ?. get_args ( ) . collect :: < Vec < _ > > ( ) ;
199+ let mut call_args_count = call_args. len ( ) ;
200+ let colon_call = call_expr. is_colon_call ( ) ;
201+ let colon_define = call_function. is_colon_define ( ) ;
202+ match ( colon_call, colon_define) {
203+ ( true , true ) | ( false , false ) => { }
204+ ( false , true ) => {
205+ fake_params. insert ( 0 , ( "self" . to_string ( ) , Some ( LuaType :: SelfInfer ) ) ) ;
206+ }
207+ ( true , false ) => {
208+ call_args_count += 1 ;
209+ }
210+ }
211+ if call_args_count < fake_params. len ( ) {
212+ // 调用参数包含 `...`
213+ for arg in call_args. iter ( ) {
214+ if let LuaExpr :: LiteralExpr ( literal_expr) = arg {
215+ if let Some ( literal_token) = literal_expr. get_literal ( ) {
216+ if let LuaLiteralToken :: Dots ( _) = literal_token {
217+ return Some ( true ) ;
218+ }
219+ }
214220 }
215- // 如果参数有可空参数, 则需要减去
216- for ( _, param_type) in func. get_params ( ) . iter ( ) {
217- if let Some ( param_type) = param_type {
218- if param_type. is_optional ( ) {
219- func_params_count -= 1 ;
221+ }
222+ // 对调用参数的最后一个参数进行特殊处理
223+ if let Some ( last_arg) = call_args. last ( ) {
224+ if let Ok ( LuaType :: Variadic ( variadic) ) = semantic_model. infer_expr ( last_arg. clone ( ) ) {
225+ let len = match variadic. get_max_len ( ) {
226+ Some ( len) => len,
227+ None => {
228+ return Some ( true ) ;
220229 }
230+ } ;
231+ call_args_count = call_args_count + len as usize - 1 ;
232+ if call_args_count >= fake_params. len ( ) {
233+ return Some ( true ) ;
234+ }
235+ }
236+ }
237+
238+ for i in call_args_count..fake_params. len ( ) {
239+ let param_info = fake_params. get ( i) ?;
240+ if param_info. 0 == "..." {
241+ return Some ( true ) ;
242+ }
243+
244+ let typ = param_info. 1 . clone ( ) ;
245+ if let Some ( typ) = typ {
246+ if !typ. is_optional ( ) {
247+ return Some ( false ) ;
221248 }
222249 }
250+ }
251+ } else if call_args_count > fake_params. len ( ) {
252+ // 参数定义中最后一个参数是 `...`
253+ if fake_params. last ( ) . map_or ( false , |( name, typ) | {
254+ name == "..."
255+ || if let Some ( typ) = typ {
256+ typ. is_variadic ( )
257+ } else {
258+ false
259+ }
260+ } ) {
261+ return Some ( true ) ;
262+ }
223263
224- if call_expr_args_count == func_params_count {
225- return Some ( func) ;
264+ let mut adjusted_index = 0 ;
265+ if colon_call != colon_define {
266+ adjusted_index = if colon_define && !colon_call { -1 } else { 1 } ;
267+ }
268+
269+ for ( i, _) in call_args. iter ( ) . enumerate ( ) {
270+ let param_index = i as isize + adjusted_index;
271+ if param_index < 0 || param_index < fake_params. len ( ) as isize {
272+ continue ;
226273 }
274+ return Some ( false ) ;
227275 }
228276 }
229- None
277+ Some ( true )
230278}
231279
232280fn get_signature_functions ( signature : & LuaSignature ) -> Vec < Arc < LuaFunctionType > > {
0 commit comments