Skip to content

Commit a0cbe97

Browse files
committed
hover generics overload
1 parent 4f1edab commit a0cbe97

File tree

8 files changed

+181
-98
lines changed

8 files changed

+181
-98
lines changed

crates/emmylua_code_analysis/src/db_index/type/types.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,12 @@ impl LuaFunctionType {
481481
.any(|(_, t)| t.as_ref().map_or(false, |t| t.contain_tpl()))
482482
|| self.ret.iter().any(|t| t.contain_tpl())
483483
}
484+
485+
pub fn first_param_is_self(&self) -> bool {
486+
self.params.first().map_or(false, |(_, t)| {
487+
t.as_ref().map_or(false, |t| t.is_self_infer())
488+
})
489+
}
484490
}
485491

486492
impl From<LuaFunctionType> for LuaType {
@@ -563,10 +569,7 @@ impl LuaObjectType {
563569

564570
let mut ty = LuaType::Unknown;
565571
let mut count = 1;
566-
let mut fields = self
567-
.fields
568-
.iter()
569-
.collect::<Vec<_>>();
572+
let mut fields = self.fields.iter().collect::<Vec<_>>();
570573

571574
fields.sort_by(|(a, _), (b, _)| a.cmp(b));
572575

crates/emmylua_code_analysis/src/semantic/call_func/mod.rs

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use emmylua_parser::LuaCallExpr;
44

55
use crate::{
66
DbIndex, LuaFunctionType, LuaGenericType, LuaOperatorMetaMethod, LuaSignatureId, LuaType,
7-
LuaTypeDeclId,
7+
LuaTypeDeclId, LuaUnionType,
88
};
99

1010
use super::{
@@ -53,10 +53,44 @@ pub fn infer_call_expr_func(
5353
infer_guard,
5454
args_count,
5555
),
56+
LuaType::Union(union) => {
57+
// 此时我们将其视为泛型实例化联合体
58+
if union.get_types().len() > 1
59+
&& union
60+
.get_types()
61+
.iter()
62+
.all(|t| matches!(t, LuaType::DocFunction(_)))
63+
{
64+
infer_generic_doc_function_union(db, config, &union, call_expr, args_count)
65+
} else {
66+
None
67+
}
68+
}
5669
_ => return None,
5770
}
5871
}
5972

73+
fn infer_generic_doc_function_union(
74+
db: &DbIndex,
75+
config: &mut LuaInferConfig,
76+
union: &LuaUnionType,
77+
call_expr: LuaCallExpr,
78+
args_count: Option<usize>,
79+
) -> Option<Arc<LuaFunctionType>> {
80+
let overloads = union
81+
.get_types()
82+
.iter()
83+
.filter_map(|typ| match typ {
84+
LuaType::DocFunction(f) => Some(f.clone()),
85+
_ => None,
86+
})
87+
.collect::<Vec<_>>();
88+
89+
let doc_func = resolve_signature(db, config, overloads, call_expr.clone(), false, args_count)?;
90+
91+
Some(doc_func)
92+
}
93+
6094
fn infer_signature_doc_function(
6195
db: &DbIndex,
6296
config: &mut LuaInferConfig,
@@ -74,7 +108,8 @@ fn infer_signature_doc_function(
74108
vec![],
75109
);
76110
if signature.is_generic() {
77-
fake_doc_function = instantiate_func_generic(db, config, &fake_doc_function, call_expr)?;
111+
fake_doc_function =
112+
instantiate_func_generic(db, config, &fake_doc_function, call_expr)?;
78113
}
79114

80115
Some(fake_doc_function.into())

crates/emmylua_code_analysis/src/semantic/instantiate/instantiate_class_generic.rs

Lines changed: 35 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -308,27 +308,42 @@ fn instantiate_signature(
308308
substitutor: &TypeSubstitutor,
309309
) -> LuaType {
310310
if let Some(signature) = db.get_signature_index().get(&signature_id) {
311-
let rets = signature
312-
.return_docs
313-
.iter()
314-
.map(|ret| ret.type_ref.clone())
315-
.collect();
316-
let is_async = if let Some(property) = db
317-
.get_property_index()
318-
.get_property(LuaPropertyOwnerId::Signature(signature_id.clone()))
319-
{
320-
property.is_async
321-
} else {
322-
false
311+
let origin_type = {
312+
let rets = signature
313+
.return_docs
314+
.iter()
315+
.map(|ret| ret.type_ref.clone())
316+
.collect();
317+
let is_async = if let Some(property) = db
318+
.get_property_index()
319+
.get_property(LuaPropertyOwnerId::Signature(signature_id.clone()))
320+
{
321+
property.is_async
322+
} else {
323+
false
324+
};
325+
let fake_doc_function = LuaFunctionType::new(
326+
is_async,
327+
signature.is_colon_define,
328+
signature.get_type_params(),
329+
rets,
330+
);
331+
instantiate_doc_function(db, &fake_doc_function, substitutor)
323332
};
324-
let fake_doc_function = LuaFunctionType::new(
325-
is_async,
326-
signature.is_colon_define,
327-
signature.get_type_params(),
328-
rets,
329-
);
330-
let instantiate_func = instantiate_doc_function(db, &fake_doc_function, substitutor);
331-
return instantiate_func;
333+
if signature.overloads.is_empty() {
334+
return origin_type;
335+
} else {
336+
let mut result = Vec::new();
337+
for overload in signature.overloads.iter() {
338+
result.push(instantiate_doc_function(
339+
db,
340+
&(*overload).clone(),
341+
substitutor,
342+
));
343+
}
344+
result.push(origin_type); // 我们需要将原始类型放到最后
345+
return LuaType::Union(LuaUnionType::new(result).into());
346+
}
332347
}
333348

334349
return LuaType::Signature(signature_id.clone());

crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ pub fn resolve_signature(
2222
for arg in args.get_args() {
2323
expr_types.push(infer_expr(db, infer_config, arg)?);
2424
}
25-
2625
if is_generic {
2726
return resolve_signature_by_generic(
2827
db,
@@ -33,7 +32,13 @@ pub fn resolve_signature(
3332
arg_count,
3433
);
3534
} else {
36-
return resolve_signature_by_args(db, overloads, expr_types, arg_count);
35+
return resolve_signature_by_args(
36+
db,
37+
overloads,
38+
expr_types,
39+
call_expr.is_colon_call(),
40+
arg_count,
41+
);
3742
}
3843
}
3944

@@ -45,11 +50,12 @@ fn resolve_signature_by_generic(
4550
expr_types: Vec<LuaType>,
4651
arg_count: Option<usize>,
4752
) -> Option<Arc<LuaFunctionType>> {
48-
let mut max_match = -1;
53+
let mut max_match: usize = 0;
4954
let mut matched_func: Option<Arc<LuaFunctionType>> = None;
5055
let mut instantiate_funcs = Vec::new();
5156
for func in overloads {
52-
let instantiate_func = instantiate_func_generic(db, infer_config, &func, call_expr.clone())?;
57+
let instantiate_func =
58+
instantiate_func_generic(db, infer_config, &func, call_expr.clone())?;
5359
instantiate_funcs.push(Arc::new(instantiate_func));
5460
}
5561

@@ -80,7 +86,7 @@ fn resolve_signature_by_generic(
8086
}
8187

8288
if matched_func.is_none() && !instantiate_funcs.is_empty() {
83-
matched_func = Some(instantiate_funcs[0].clone());
89+
matched_func = Some(instantiate_funcs.last().cloned().unwrap());
8490
}
8591

8692
matched_func
@@ -90,40 +96,56 @@ fn resolve_signature_by_args(
9096
db: &DbIndex,
9197
overloads: Vec<Arc<LuaFunctionType>>,
9298
expr_types: Vec<LuaType>,
99+
is_colon_call: bool,
93100
arg_count: Option<usize>,
94101
) -> Option<Arc<LuaFunctionType>> {
95-
let mut max_match = -1;
102+
let mut max_match: i32 = -1;
96103
let mut matched_func: Option<Arc<LuaFunctionType>> = None;
97104

98105
for func in &overloads {
99106
let params = func.get_params();
100-
let mut match_count = 0;
107+
// 参数数量不足
101108
if params.len() < arg_count.unwrap_or(0) {
102109
continue;
103110
}
104111

112+
// 冒号调用但不是冒号定义的函数, 需要检查第一个参数是否为`self`
113+
let jump_param = if is_colon_call && !func.is_colon_define() {
114+
match params.get(0).and_then(|p| p.1.as_ref()) {
115+
Some(param_type) if param_type == &LuaType::SelfInfer => 1,
116+
Some(_) => break, // 不是冒号定义的函数, 但是是冒号调用
117+
None => 0,
118+
}
119+
} else {
120+
0
121+
};
122+
123+
let mut match_count = 0;
124+
105125
for (i, param) in params.iter().enumerate() {
106-
if i >= expr_types.len() {
126+
if expr_types.len() <= i - jump_param {
107127
break;
108128
}
129+
if i == 0 && jump_param > 0 {
130+
continue;
131+
}
109132

110133
let param_type = param.1.clone().unwrap_or(LuaType::Any);
111-
let expr_type = &expr_types[i];
112-
if param_type == LuaType::Any {
113-
match_count += 1;
114-
} else if check_type_compact(db, &param_type, expr_type).is_ok() {
134+
let expr_type = &expr_types[i - jump_param];
135+
if param_type == LuaType::Any || check_type_compact(db, &param_type, expr_type).is_ok()
136+
{
115137
match_count += 1;
116138
}
117139
}
140+
118141
if match_count > max_match {
119142
max_match = match_count;
120143
matched_func = Some(func.clone());
144+
if match_count == (params.len() - jump_param) as i32 {
145+
break;
146+
}
121147
}
122148
}
123149

124-
if matched_func.is_none() && !overloads.is_empty() {
125-
matched_func = Some(overloads[0].clone());
126-
}
127-
128-
matched_func
150+
matched_func.or_else(|| overloads.last().cloned())
129151
}

crates/emmylua_ls/src/handlers/hover/build_hover.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,18 @@ fn build_member_hover(
187187

188188
let mut function_member = None;
189189
let mut owner_decl = None;
190-
if typ.is_function() {
190+
if typ.is_function()
191+
|| match &typ {
192+
LuaType::Union(union) => {
193+
union.get_types().len() > 1
194+
&& union
195+
.get_types()
196+
.iter()
197+
.all(|t| matches!(t, LuaType::DocFunction(_)))
198+
}
199+
_ => false,
200+
}
201+
{
191202
let property_owner = get_member_owner(&builder.semantic_model, member_id);
192203
match property_owner {
193204
Some(LuaPropertyOwnerId::Member(member_id)) => {

crates/emmylua_ls/src/handlers/hover/hover_builder.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,12 +117,18 @@ impl<'a> HoverBuilder<'a> {
117117
match call_expr.kind().into() {
118118
LuaSyntaxKind::CallExpr => {
119119
let call_expr = LuaCallExpr::cast(call_expr)?;
120-
let func = self.semantic_model.infer_call_expr_func(call_expr.clone(), None);
120+
let func = self
121+
.semantic_model
122+
.infer_call_expr_func(call_expr.clone(), None);
121123
if let Some(func) = func {
122124
// 确定参数量是否与当前输入的参数数量一致, 因为`infer_call_expr_func`必然返回一个有效的类型, 即使不是完全匹配的
123125
let call_expr_args_count = call_expr.get_args_count();
124-
if let Some(call_expr_args_count) = call_expr_args_count {
126+
if let Some(mut call_expr_args_count) = call_expr_args_count {
125127
let func_params_count = func.get_params().len();
128+
if !func.is_colon_define() && call_expr.is_colon_call() {
129+
// 不是冒号定义的函数, 但是是冒号调用
130+
call_expr_args_count += 1;
131+
}
126132
if call_expr_args_count == func_params_count {
127133
return Some((*func).clone());
128134
}

0 commit comments

Comments
 (0)