Skip to content

Commit fced8b5

Browse files
committed
fix hover union function
1 parent 63d946b commit fced8b5

File tree

5 files changed

+92
-18
lines changed

5 files changed

+92
-18
lines changed

crates/emmylua_code_analysis/src/semantic/infer/infer_table.rs

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::ops::Deref;
1+
use std::{ops::Deref, sync::Arc};
22

33
use emmylua_parser::{
44
LuaAssignStat, LuaAst, LuaAstNode, LuaCallArgList, LuaCallExpr, LuaExpr, LuaIndexMemberExpr,
@@ -8,7 +8,7 @@ use emmylua_parser::{
88
use crate::{
99
db_index::{DbIndex, LuaType},
1010
infer_call_expr_func, infer_expr, InferGuard, LuaArrayType, LuaDeclId, LuaInferCache,
11-
LuaMemberId, LuaTupleStatus, LuaTupleType, VariadicType,
11+
LuaMemberId, LuaTupleStatus, LuaTupleType, LuaUnionType, TypeOps, VariadicType,
1212
};
1313

1414
use super::{
@@ -221,12 +221,37 @@ fn infer_table_type_by_calleee(
221221
call_arg_number -= 1;
222222
}
223223
}
224-
Ok(param_types
224+
let typ = param_types
225225
.get(call_arg_number)
226226
.ok_or(InferFailReason::None)?
227227
.1
228228
.clone()
229-
.unwrap_or(LuaType::Any))
229+
.unwrap_or(LuaType::Any);
230+
match &typ {
231+
LuaType::TableConst(_) => {}
232+
LuaType::Union(union) => {
233+
// TODO: 假设存在多个匹配项, 我们需要根据字段的匹配情况来确定最终的类型
234+
return Ok(union_remove_non_table_type(db, union));
235+
}
236+
_ => {}
237+
}
238+
239+
Ok(typ)
240+
}
241+
242+
/// 移除掉一些非`table`类型
243+
fn union_remove_non_table_type(db: &DbIndex, union: &Arc<LuaUnionType>) -> LuaType {
244+
let mut result = LuaType::Unknown;
245+
for typ in union.into_set().into_iter() {
246+
match typ {
247+
LuaType::Signature(_) | LuaType::DocFunction(_) => {}
248+
_ if typ.is_string() || typ.is_number() || typ.is_boolean() => {}
249+
_ => {
250+
result = TypeOps::Union.apply(db, &result, &typ);
251+
}
252+
}
253+
}
254+
result
230255
}
231256

232257
fn infer_table_field_type_by_parent(

crates/emmylua_ls/src/handlers/hover/build_hover.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,7 @@ fn build_member_hover(
217217
true,
218218
)
219219
.get_types(&builder.semantic_model);
220+
220221
replace_semantic_type(&mut semantic_decls, &typ);
221222
let member_name = match member.get_key() {
222223
LuaMemberKey::Name(name) => name.to_string(),

crates/emmylua_ls/src/handlers/hover/find_origin.rs

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -200,10 +200,7 @@ fn resolve_member_owner(
200200
// 非类, 那么通过右值推断
201201
let value_expr = table_field.get_value_expr()?;
202202
let value_node = value_expr.get_syntax_id().to_node_from_root(&root)?;
203-
semantic_model.find_decl(
204-
value_node.into(),
205-
emmylua_code_analysis::SemanticDeclLevel::default(),
206-
)
203+
semantic_model.find_decl(value_node.into(), SemanticDeclLevel::default())
207204
} else {
208205
None
209206
}
@@ -217,10 +214,8 @@ fn resolve_member_owner(
217214
for (var, expr) in vars.iter().zip(exprs.iter()) {
218215
if var.syntax().text_range() == current_node.text_range() {
219216
let expr_node = expr.get_syntax_id().to_node_from_root(&root)?;
220-
result = semantic_model.find_decl(
221-
expr_node.into(),
222-
emmylua_code_analysis::SemanticDeclLevel::default(),
223-
);
217+
result =
218+
semantic_model.find_decl(expr_node.into(), SemanticDeclLevel::default());
224219
break;
225220
}
226221
}
@@ -255,7 +250,7 @@ fn resolve_table_field_through_type_inference(
255250

256251
if !matches!(
257252
table_type,
258-
emmylua_code_analysis::LuaType::Ref(_) | emmylua_code_analysis::LuaType::Def(_)
253+
LuaType::Ref(_) | LuaType::Def(_) | LuaType::Generic(_)
259254
) {
260255
return None;
261256
}

crates/emmylua_ls/src/handlers/hover/function_humanize.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ pub fn hover_function_type(
5353
// 已处理过的 semantic_decl_id, 用于解决`test_issue_499_3`
5454
let mut handled_semantic_decl_ids = HashSet::new();
5555
let mut type_descs: Vec<HoverFunctionInfo> = Vec::with_capacity(semantic_decls.len());
56-
// 记录已处理过的类型用于在 Union 中跳过重复类型.
56+
// 记录已处理过的类型, 用于在 Union 中跳过重复类型.
5757
// 这是为了解决最后一个类型可能是前面所有类型的联合类型的情况
5858
let mut processed_types = HashSet::new();
5959

@@ -122,7 +122,7 @@ pub fn hover_function_type(
122122
}
123123
}
124124

125-
// 如果当前类型是 Union传入已处理的类型集合
125+
// 如果当前类型是 Union, 传入已处理的类型集合
126126
let result = match typ {
127127
LuaType::Union(_) => process_single_function_type_with_exclusions(
128128
builder,
@@ -158,10 +158,10 @@ pub fn hover_function_type(
158158
function_info = info;
159159
}
160160
ProcessFunctionTypeResult::Multiple(infos) => {
161-
// 对于 Union 类型将每个子类型的结果都添加到 type_descs 中
161+
// 对于 Union 类型, 将每个子类型的结果都添加到 type_descs 中
162162
let infos_len = infos.len();
163163
for (index, mut info) in infos.into_iter().enumerate() {
164-
// 合并描述信息只有最后一个才设置描述
164+
// 合并描述信息, 只有最后一个才设置描述
165165
if function_info.description.is_some()
166166
&& info.description.is_none()
167167
&& index == infos_len - 1
@@ -754,7 +754,7 @@ fn process_single_function_type_with_exclusions(
754754
}
755755
}
756756
_ => {
757-
// 对于非 Union 类型直接调用原函数
757+
// 对于非 Union 类型, 直接调用原函数
758758
process_single_function_type(
759759
builder,
760760
db,

crates/emmylua_ls/src/handlers/test/hover_test.rs

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,4 +245,57 @@ mod tests {
245245
},
246246
));
247247
}
248+
249+
#[test]
250+
fn test_field_key() {
251+
let mut ws = ProviderVirtualWorkspace::new();
252+
ws.def(
253+
r#"
254+
---@class ObserverParams
255+
---@field next fun() # 测试
256+
257+
---@param params fun() | ObserverParams
258+
function test(params)
259+
end
260+
"#,
261+
);
262+
assert!(ws.check_hover(
263+
r#"
264+
test({
265+
<??>next = function()
266+
end
267+
})
268+
"#,
269+
VirtualHoverResult {
270+
value: "```lua\n(field) ObserverParams.next()\n```\n\n---\n\n测试".to_string(),
271+
},
272+
));
273+
}
274+
275+
#[test]
276+
fn test_field_key_for_generic() {
277+
let mut ws = ProviderVirtualWorkspace::new();
278+
ws.def(
279+
r#"
280+
---@class ObserverParams<T>
281+
---@field next fun() # 测试
282+
283+
---@generic T
284+
---@param params fun() | ObserverParams<T>
285+
function test(params)
286+
end
287+
"#,
288+
);
289+
assert!(ws.check_hover(
290+
r#"
291+
test({
292+
<??>next = function()
293+
end
294+
})
295+
"#,
296+
VirtualHoverResult {
297+
value: "```lua\n(field) ObserverParams.next()\n```\n\n---\n\n测试".to_string(),
298+
},
299+
));
300+
}
248301
}

0 commit comments

Comments
 (0)