Skip to content

Commit e8aee81

Browse files
authored
Merge pull request #763 from EmmyLuaLs/refactor-overload
refactor overload
2 parents c8e65f9 + 309c218 commit e8aee81

File tree

2 files changed

+268
-54
lines changed

2 files changed

+268
-54
lines changed

crates/emmylua_code_analysis/src/diagnostic/test/redundant_parameter_test.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ mod test {
9191
fn test_issue_360() {
9292
let mut ws = VirtualWorkspace::new();
9393

94-
assert!(ws.check_code_for(
94+
assert!(!ws.check_code_for(
9595
DiagnosticCode::RedundantParameter,
9696
r#"
9797
---@alias buz number
Lines changed: 267 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,91 @@
1+
// use std::sync::Arc;
2+
3+
// use crate::{
4+
// InferFailReason, check_type_compact,
5+
// db_index::{DbIndex, LuaFunctionType, LuaType},
6+
// semantic::infer::InferCallFuncResult,
7+
// };
8+
9+
// pub fn resolve_signature_by_args(
10+
// db: &DbIndex,
11+
// overloads: &[Arc<LuaFunctionType>],
12+
// expr_types: &[LuaType],
13+
// is_colon_call: bool,
14+
// arg_count: Option<usize>,
15+
// ) -> InferCallFuncResult {
16+
// let arg_count = arg_count.unwrap_or(0);
17+
// let mut opt_funcs = Vec::with_capacity(overloads.len());
18+
19+
// for func in overloads {
20+
// let params = func.get_params();
21+
// if params.len() < arg_count {
22+
// continue;
23+
// }
24+
// let mut total_weight = 0; // 总权重
25+
26+
// let mut fake_expr_len = expr_types.len();
27+
// let jump_param;
28+
// if is_colon_call && !func.is_colon_define() {
29+
// jump_param = 1;
30+
// fake_expr_len += 1;
31+
// } else {
32+
// jump_param = 0;
33+
// };
34+
35+
// // 如果在不计算可空参数的情况下, 参数数量完全匹配, 则认为其权重更高
36+
// if params.len() == fake_expr_len {
37+
// total_weight += params.len() as i32 + 1;
38+
// }
39+
40+
// // 冒号定义且冒号调用
41+
// if is_colon_call && func.is_colon_define() {
42+
// total_weight += 100;
43+
// }
44+
45+
// // 检查每个参数的匹配情况
46+
// for (i, param) in params.iter().enumerate() {
47+
// if i == 0 && jump_param > 0 {
48+
// // 非冒号定义但是冒号调用, 直接认为匹配
49+
// total_weight += 100;
50+
// continue;
51+
// }
52+
// let param_type = param.1.as_ref().unwrap_or(&LuaType::Any);
53+
// let expr_idx = i - jump_param;
54+
55+
// if expr_idx >= expr_types.len() {
56+
// // 没有传入参数, 但参数是可空类型
57+
// if param_type.is_nullable() {
58+
// total_weight += 1;
59+
// fake_expr_len += 1;
60+
// }
61+
// continue;
62+
// }
63+
64+
// let expr_type = &expr_types[expr_idx];
65+
// if *param_type == LuaType::Any || check_type_compact(db, param_type, expr_type).is_ok()
66+
// {
67+
// total_weight += 100; // 类型完全匹配
68+
// }
69+
// }
70+
// // 如果参数数量完全匹配, 则认为其权重更高
71+
// if params.len() == fake_expr_len {
72+
// total_weight += 50000;
73+
// }
74+
75+
// opt_funcs.push((func, total_weight));
76+
// }
77+
78+
// // 按权重降序排序
79+
// opt_funcs.sort_by(|a, b| b.1.cmp(&a.1));
80+
// // 返回权重最高的签名,若无则取最后一个重载作为默认
81+
// opt_funcs
82+
// .first()
83+
// .filter(|(_, weight)| *weight > i32::MIN) // 确保不是无效签名
84+
// .map(|(func, _)| Arc::clone(func))
85+
// .or_else(|| overloads.last().cloned())
86+
// .ok_or(InferFailReason::None)
87+
// }
88+
189
use std::sync::Arc;
290

391
use crate::{
@@ -14,74 +102,200 @@ pub fn resolve_signature_by_args(
14102
arg_count: Option<usize>,
15103
) -> InferCallFuncResult {
16104
let arg_count = arg_count.unwrap_or(0);
17-
let mut opt_funcs = Vec::with_capacity(overloads.len());
105+
let mut need_resolve_funcs = match overloads.len() {
106+
0 => return Err(InferFailReason::None),
107+
1 => return Ok(Arc::clone(&overloads[0])),
108+
_ => overloads
109+
.iter()
110+
.map(|it| Some(it.clone()))
111+
.collect::<Vec<_>>(),
112+
};
113+
114+
let exp_len = expr_types.len();
115+
if exp_len == 0 {
116+
for overload in overloads {
117+
let param_len = overload.get_params().len();
118+
if param_len == 0 {
119+
return Ok(overload.clone());
120+
}
121+
}
122+
}
123+
124+
let mut best_match_result = need_resolve_funcs[0].clone().unwrap();
125+
for arg_index in 0..exp_len {
126+
let mut current_match_result = ParamMatchResult::NotMatch;
127+
for i in 0..need_resolve_funcs.len() {
128+
let opt_func = &need_resolve_funcs[i];
129+
if opt_func.is_none() {
130+
continue;
131+
}
132+
let func = opt_func.as_ref().unwrap();
133+
let param_len = func.get_params().len();
134+
if param_len < arg_count {
135+
need_resolve_funcs[i] = None;
136+
continue;
137+
}
138+
139+
let colon_define = func.is_colon_define();
140+
let mut param_index = arg_index;
141+
match (colon_define, is_colon_call) {
142+
(true, false) => {
143+
if param_index == 0 {
144+
continue;
145+
}
146+
param_index -= 1;
147+
}
148+
(false, true) => {
149+
param_index += 1;
150+
}
151+
_ => {}
152+
}
153+
let expr_type = &expr_types[arg_index];
154+
let param_type = if param_index < param_len {
155+
let param_info = func.get_params().get(param_index);
156+
param_info
157+
.map(|it| it.1.clone().unwrap_or(LuaType::Any))
158+
.unwrap_or(LuaType::Any)
159+
} else if let Some(last_param_info) = func.get_params().last() {
160+
if last_param_info.0 == "..." {
161+
last_param_info.1.clone().unwrap_or(LuaType::Any)
162+
} else {
163+
need_resolve_funcs[i] = None;
164+
continue;
165+
}
166+
} else {
167+
need_resolve_funcs[i] = None;
168+
continue;
169+
};
18170

19-
for func in overloads {
20-
let params = func.get_params();
21-
if params.len() < arg_count {
22-
continue;
171+
let match_result = if param_type.is_any() {
172+
ParamMatchResult::AnyMatch
173+
} else if check_type_compact(db, &param_type, &expr_type).is_ok() {
174+
ParamMatchResult::TypeMatch
175+
} else {
176+
ParamMatchResult::NotMatch
177+
};
178+
179+
if match_result > current_match_result {
180+
current_match_result = match_result;
181+
best_match_result = func.clone();
182+
}
183+
184+
if match_result == ParamMatchResult::NotMatch {
185+
need_resolve_funcs[i] = None;
186+
continue;
187+
}
188+
189+
if match_result > ParamMatchResult::AnyMatch {
190+
if param_index + 1 == func.get_params().len() {
191+
return Ok(func.clone());
192+
}
193+
}
23194
}
24-
let mut total_weight = 0; // 总权重
25-
26-
let mut fake_expr_len = expr_types.len();
27-
let jump_param;
28-
if is_colon_call && !func.is_colon_define() {
29-
jump_param = 1;
30-
fake_expr_len += 1;
31-
} else {
32-
jump_param = 0;
33-
};
34-
35-
// 如果在不计算可空参数的情况下, 参数数量完全匹配, 则认为其权重更高
36-
if params.len() == fake_expr_len {
37-
total_weight += params.len() as i32 + 1;
195+
196+
if current_match_result == ParamMatchResult::NotMatch {
197+
break;
38198
}
199+
}
200+
201+
let mut rest_need_resolve_funcs = need_resolve_funcs
202+
.iter()
203+
.filter_map(|it| it.clone())
204+
.map(|it| Some(it))
205+
.collect::<Vec<_>>();
39206

40-
// 冒号定义且冒号调用
41-
if is_colon_call && func.is_colon_define() {
42-
total_weight += 100;
207+
match rest_need_resolve_funcs.len() {
208+
0 => return Ok(best_match_result),
209+
1 => return Ok(rest_need_resolve_funcs[0].clone().unwrap()),
210+
_ => {}
211+
}
212+
213+
let start_param_index = exp_len;
214+
let mut max_param_len = 0;
215+
for opt_func in &rest_need_resolve_funcs {
216+
if let Some(func) = opt_func {
217+
let param_len = func.get_params().len();
218+
if param_len > max_param_len {
219+
max_param_len = param_len;
220+
}
43221
}
222+
}
44223

45-
// 检查每个参数的匹配情况
46-
for (i, param) in params.iter().enumerate() {
47-
if i == 0 && jump_param > 0 {
48-
// 非冒号定义但是冒号调用, 直接认为匹配
49-
total_weight += 100;
224+
let rest_len = rest_need_resolve_funcs.len();
225+
for param_index in start_param_index..max_param_len {
226+
let mut current_match_result = ParamMatchResult::NotMatch;
227+
for i in 0..rest_len {
228+
let opt_func = &rest_need_resolve_funcs[i];
229+
if opt_func.is_none() {
50230
continue;
51231
}
52-
let param_type = param.1.as_ref().unwrap_or(&LuaType::Any);
53-
let expr_idx = i - jump_param;
54-
55-
if expr_idx >= expr_types.len() {
56-
// 没有传入参数, 但参数是可空类型
57-
if param_type.is_nullable() {
58-
total_weight += 1;
59-
fake_expr_len += 1;
232+
let func = opt_func.as_ref().unwrap();
233+
let param_len = func.get_params().len();
234+
let colon_define = func.is_colon_define();
235+
let mut param_index = param_index;
236+
match (colon_define, is_colon_call) {
237+
(true, false) => {
238+
if param_index == 0 {
239+
continue;
240+
}
241+
param_index -= 1;
242+
}
243+
(false, true) => {
244+
param_index += 1;
245+
}
246+
_ => {}
247+
}
248+
let param_type = if param_index < param_len {
249+
let param_info = func.get_params().get(param_index);
250+
param_info
251+
.map(|it| it.1.clone().unwrap_or(LuaType::Any))
252+
.unwrap_or(LuaType::Any)
253+
} else if let Some(last_param_info) = func.get_params().last() {
254+
if last_param_info.0 == "..." {
255+
last_param_info.1.clone().unwrap_or(LuaType::Any)
256+
} else {
257+
return Ok(func.clone());
60258
}
259+
} else {
260+
return Ok(func.clone());
261+
};
262+
263+
let match_result = if param_type.is_any() {
264+
ParamMatchResult::AnyMatch
265+
} else if param_type.is_nullable() {
266+
ParamMatchResult::TypeMatch
267+
} else {
268+
ParamMatchResult::NotMatch
269+
};
270+
271+
if match_result > current_match_result {
272+
current_match_result = match_result;
273+
best_match_result = func.clone();
274+
}
275+
276+
if match_result == ParamMatchResult::NotMatch {
277+
rest_need_resolve_funcs[i] = None;
61278
continue;
62279
}
63280

64-
let expr_type = &expr_types[expr_idx];
65-
if *param_type == LuaType::Any || check_type_compact(db, param_type, expr_type).is_ok()
66-
{
67-
total_weight += 100; // 类型完全匹配
281+
if match_result >= ParamMatchResult::AnyMatch {
282+
if param_index + 1 == func.get_params().len() {
283+
return Ok(func.clone());
284+
}
68285
}
69286
}
70-
// 如果参数数量完全匹配, 则认为其权重更高
71-
if params.len() == fake_expr_len {
72-
total_weight += 50000;
73-
}
74287

75-
opt_funcs.push((func, total_weight));
288+
if current_match_result == ParamMatchResult::NotMatch {
289+
break;
290+
}
76291
}
77292

78-
// 按权重降序排序
79-
opt_funcs.sort_by(|a, b| b.1.cmp(&a.1));
80-
// 返回权重最高的签名,若无则取最后一个重载作为默认
81-
opt_funcs
82-
.first()
83-
.filter(|(_, weight)| *weight > i32::MIN) // 确保不是无效签名
84-
.map(|(func, _)| Arc::clone(func))
85-
.or_else(|| overloads.last().cloned())
86-
.ok_or(InferFailReason::None)
293+
Ok(best_match_result)
294+
}
295+
296+
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
297+
enum ParamMatchResult {
298+
NotMatch,
299+
AnyMatch,
300+
TypeMatch,
87301
}

0 commit comments

Comments
 (0)