1
1
use emmylua_code_analysis:: {
2
- LuaCompilation , LuaDeclId , LuaFunctionType , LuaMemberId , LuaSemanticDeclId , LuaSignature ,
3
- LuaSignatureId , LuaType , SemanticDeclLevel , SemanticModel , instantiate_func_generic,
2
+ LuaCompilation , LuaDeclId , LuaFunctionType , LuaSemanticDeclId , LuaSignature , LuaSignatureId ,
3
+ LuaType , SemanticDeclLevel , SemanticModel , instantiate_func_generic,
4
+ } ;
5
+ use emmylua_parser:: {
6
+ LuaAstNode , LuaCallExpr , LuaExpr , LuaLiteralToken , LuaSyntaxToken , LuaTokenKind ,
4
7
} ;
5
- use emmylua_parser:: { LuaAstNode , LuaCallExpr , LuaSyntaxToken , LuaTokenKind } ;
6
8
use rowan:: { NodeOrToken , TokenAtOffset } ;
7
9
use std:: sync:: Arc ;
8
10
@@ -27,17 +29,44 @@ pub fn find_matching_function_definitions(
27
29
let mut has_match = false ;
28
30
29
31
for ( decl, member_id) in member_decls {
30
- if process_member_function_type (
31
- semantic_model,
32
- compilation,
33
- & call_function,
34
- & call_expr,
35
- decl,
36
- member_id,
37
- & mut result,
38
- & mut has_match,
39
- ) {
40
- continue ;
32
+ let typ = semantic_model. get_type ( member_id. clone ( ) . into ( ) ) ;
33
+ match typ {
34
+ LuaType :: DocFunction ( func) => {
35
+ if compare_function_types ( semantic_model, & call_function, & func, & call_expr)
36
+ . unwrap_or ( false )
37
+ {
38
+ result. push ( decl. clone ( ) ) ;
39
+ has_match = true ;
40
+ }
41
+ }
42
+ LuaType :: Signature ( signature_id) => {
43
+ let signature = match semantic_model
44
+ . get_db ( )
45
+ . get_signature_index ( )
46
+ . get ( & signature_id)
47
+ {
48
+ Some ( sig) => sig,
49
+ None => continue ,
50
+ } ;
51
+ let functions = get_signature_functions ( signature) ;
52
+
53
+ if functions. iter ( ) . any ( |func| {
54
+ compare_function_types ( semantic_model, & call_function, func, & call_expr)
55
+ . unwrap_or ( false )
56
+ } ) {
57
+ has_match = true ;
58
+ }
59
+
60
+ // 无论是否匹配, 都需要将真实的定义添加到结果中
61
+ // 如果存在原始定义, 则优先使用原始定义
62
+ let origin = extract_semantic_decl_from_signature ( compilation, & signature_id) ;
63
+ if let Some ( origin) = origin {
64
+ result. insert ( 0 , origin) ;
65
+ } else {
66
+ result. insert ( 0 , decl. clone ( ) ) ;
67
+ }
68
+ }
69
+ _ => continue ,
41
70
}
42
71
}
43
72
@@ -48,58 +77,6 @@ pub fn find_matching_function_definitions(
48
77
}
49
78
}
50
79
51
- fn process_member_function_type (
52
- semantic_model : & SemanticModel ,
53
- compilation : & LuaCompilation ,
54
- call_function : & Arc < LuaFunctionType > ,
55
- call_expr : & LuaCallExpr ,
56
- decl : & LuaSemanticDeclId ,
57
- member_id : & LuaMemberId ,
58
- result : & mut Vec < LuaSemanticDeclId > ,
59
- has_match : & mut bool ,
60
- ) -> bool {
61
- let typ = semantic_model. get_type ( member_id. clone ( ) . into ( ) ) ;
62
- match typ {
63
- LuaType :: DocFunction ( func) => {
64
- if compare_function_types ( semantic_model, call_function, & func, call_expr)
65
- . unwrap_or ( false )
66
- {
67
- result. push ( decl. clone ( ) ) ;
68
- * has_match = true ;
69
- }
70
- }
71
- LuaType :: Signature ( signature_id) => {
72
- let signature = match semantic_model
73
- . get_db ( )
74
- . get_signature_index ( )
75
- . get ( & signature_id)
76
- {
77
- Some ( sig) => sig,
78
- None => return false ,
79
- } ;
80
- let functions = get_signature_functions ( signature) ;
81
-
82
- if functions. iter ( ) . any ( |func| {
83
- compare_function_types ( semantic_model, call_function, func, call_expr)
84
- . unwrap_or ( false )
85
- } ) {
86
- * has_match = true ;
87
- }
88
-
89
- // 无论是否匹配,都需要将真实的定义添加到结果中
90
- // 如果存在原始定义,则优先使用原始定义
91
- let origin = extract_semantic_decl_from_signature ( compilation, & signature_id) ;
92
- if let Some ( origin) = origin {
93
- result. insert ( 0 , origin) ;
94
- } else {
95
- result. insert ( 0 , decl. clone ( ) ) ;
96
- }
97
- }
98
- _ => return false ,
99
- }
100
- true
101
- }
102
-
103
80
pub fn find_function_call_origin (
104
81
semantic_model : & SemanticModel ,
105
82
compilation : & LuaCompilation ,
@@ -205,28 +182,99 @@ fn get_call_function(
205
182
) -> Option < Arc < LuaFunctionType > > {
206
183
let func = semantic_model. infer_call_expr_func ( call_expr. clone ( ) , None ) ;
207
184
if let Some ( func) = func {
208
- let call_expr_args_count = call_expr. get_args_count ( ) ;
209
- if let Some ( mut call_expr_args_count) = call_expr_args_count {
210
- let mut func_params_count = func. get_params ( ) . len ( ) ;
211
- if !func. is_colon_define ( ) && call_expr. is_colon_call ( ) {
212
- // 不是冒号定义的函数, 但是是冒号调用
213
- call_expr_args_count += 1 ;
185
+ if check_params_count_is_match ( semantic_model, & func, call_expr. clone ( ) ) . unwrap_or ( false ) {
186
+ return Some ( func) ;
187
+ }
188
+ }
189
+ None
190
+ }
191
+
192
+ fn check_params_count_is_match (
193
+ semantic_model : & SemanticModel ,
194
+ call_function : & LuaFunctionType ,
195
+ call_expr : LuaCallExpr ,
196
+ ) -> Option < bool > {
197
+ let mut fake_params = call_function. get_params ( ) . to_vec ( ) ;
198
+ let call_args = call_expr. get_args_list ( ) ?. get_args ( ) . collect :: < Vec < _ > > ( ) ;
199
+ let mut call_args_count = call_args. len ( ) ;
200
+ let colon_call = call_expr. is_colon_call ( ) ;
201
+ let colon_define = call_function. is_colon_define ( ) ;
202
+ match ( colon_call, colon_define) {
203
+ ( true , true ) | ( false , false ) => { }
204
+ ( false , true ) => {
205
+ fake_params. insert ( 0 , ( "self" . to_string ( ) , Some ( LuaType :: SelfInfer ) ) ) ;
206
+ }
207
+ ( true , false ) => {
208
+ call_args_count += 1 ;
209
+ }
210
+ }
211
+ if call_args_count < fake_params. len ( ) {
212
+ // 调用参数包含 `...`
213
+ for arg in call_args. iter ( ) {
214
+ if let LuaExpr :: LiteralExpr ( literal_expr) = arg {
215
+ if let Some ( literal_token) = literal_expr. get_literal ( ) {
216
+ if let LuaLiteralToken :: Dots ( _) = literal_token {
217
+ return Some ( true ) ;
218
+ }
219
+ }
214
220
}
215
- // 如果参数有可空参数, 则需要减去
216
- for ( _, param_type) in func. get_params ( ) . iter ( ) {
217
- if let Some ( param_type) = param_type {
218
- if param_type. is_optional ( ) {
219
- func_params_count -= 1 ;
221
+ }
222
+ // 对调用参数的最后一个参数进行特殊处理
223
+ if let Some ( last_arg) = call_args. last ( ) {
224
+ if let Ok ( LuaType :: Variadic ( variadic) ) = semantic_model. infer_expr ( last_arg. clone ( ) ) {
225
+ let len = match variadic. get_max_len ( ) {
226
+ Some ( len) => len,
227
+ None => {
228
+ return Some ( true ) ;
220
229
}
230
+ } ;
231
+ call_args_count = call_args_count + len as usize - 1 ;
232
+ if call_args_count >= fake_params. len ( ) {
233
+ return Some ( true ) ;
234
+ }
235
+ }
236
+ }
237
+
238
+ for i in call_args_count..fake_params. len ( ) {
239
+ let param_info = fake_params. get ( i) ?;
240
+ if param_info. 0 == "..." {
241
+ return Some ( true ) ;
242
+ }
243
+
244
+ let typ = param_info. 1 . clone ( ) ;
245
+ if let Some ( typ) = typ {
246
+ if !typ. is_optional ( ) {
247
+ return Some ( false ) ;
221
248
}
222
249
}
250
+ }
251
+ } else if call_args_count > fake_params. len ( ) {
252
+ // 参数定义中最后一个参数是 `...`
253
+ if fake_params. last ( ) . map_or ( false , |( name, typ) | {
254
+ name == "..."
255
+ || if let Some ( typ) = typ {
256
+ typ. is_variadic ( )
257
+ } else {
258
+ false
259
+ }
260
+ } ) {
261
+ return Some ( true ) ;
262
+ }
223
263
224
- if call_expr_args_count == func_params_count {
225
- return Some ( func) ;
264
+ let mut adjusted_index = 0 ;
265
+ if colon_call != colon_define {
266
+ adjusted_index = if colon_define && !colon_call { -1 } else { 1 } ;
267
+ }
268
+
269
+ for ( i, _) in call_args. iter ( ) . enumerate ( ) {
270
+ let param_index = i as isize + adjusted_index;
271
+ if param_index < 0 || param_index < fake_params. len ( ) as isize {
272
+ continue ;
226
273
}
274
+ return Some ( false ) ;
227
275
}
228
276
}
229
- None
277
+ Some ( true )
230
278
}
231
279
232
280
fn get_signature_functions ( signature : & LuaSignature ) -> Vec < Arc < LuaFunctionType > > {
0 commit comments