Skip to content

Commit b95eae8

Browse files
committed
support tpl_ref infer and completion
fix #646
1 parent 4f45518 commit b95eae8

File tree

5 files changed

+190
-5
lines changed

5 files changed

+190
-5
lines changed

crates/emmylua_code_analysis/src/compilation/test/generic_test.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,4 +96,28 @@ mod test {
9696
let expected = ws.ty("Observable<number>");
9797
assert_eq!(a_ty, expected);
9898
}
99+
100+
#[test]
101+
fn test_issue_646() {
102+
let mut ws = VirtualWorkspace::new();
103+
ws.def(
104+
r#"
105+
---@class Base
106+
---@field a string
107+
"#,
108+
);
109+
ws.def(
110+
r#"
111+
---@generic T: Base
112+
---@param file T
113+
function dirname(file)
114+
A = file.a
115+
end
116+
"#,
117+
);
118+
119+
let a_ty = ws.expr_ty("A");
120+
let expected = ws.ty("string");
121+
assert_eq!(a_ty, expected);
122+
}
99123
}

crates/emmylua_code_analysis/src/semantic/generic/mod.rs

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,111 @@ mod tpl_context;
66
mod tpl_pattern;
77
mod type_substitutor;
88

9+
use emmylua_parser::LuaAstNode;
10+
use emmylua_parser::LuaExpr;
911
pub use instantiate_func_generic::build_self_type;
1012
pub use instantiate_func_generic::infer_self_type;
1113
pub use instantiate_func_generic::instantiate_func_generic;
1214
pub use instantiate_type_generic::instantiate_doc_function;
1315
pub use instantiate_type_generic::instantiate_type_generic;
16+
use rowan::NodeOrToken;
1417
pub use tpl_context::TplContext;
1518
pub use tpl_pattern::tpl_pattern_match_args;
1619
pub use type_substitutor::TypeSubstitutor;
20+
21+
use crate::DbIndex;
22+
use crate::GenericTplId;
23+
use crate::LuaDeclExtra;
24+
use crate::LuaInferCache;
25+
use crate::LuaMemberOwner;
26+
use crate::LuaSemanticDeclId;
27+
use crate::LuaType;
28+
use crate::SemanticDeclLevel;
29+
use crate::TypeOps;
30+
use crate::infer_node_semantic_decl;
31+
use crate::semantic::semantic_info::infer_token_semantic_decl;
32+
33+
pub fn get_tpl_ref_extend_type(
34+
db: &DbIndex,
35+
cache: &mut LuaInferCache,
36+
arg_type: &LuaType,
37+
arg_expr: LuaExpr,
38+
depth: usize,
39+
) -> Option<LuaType> {
40+
match arg_type {
41+
LuaType::TplRef(tpl_ref) => {
42+
let node_or_token = arg_expr.syntax().clone().into();
43+
let semantic_decl = match node_or_token {
44+
NodeOrToken::Node(node) => {
45+
infer_node_semantic_decl(db, cache, node, SemanticDeclLevel::default())
46+
}
47+
NodeOrToken::Token(token) => {
48+
infer_token_semantic_decl(db, cache, token, SemanticDeclLevel::default())
49+
}
50+
}?;
51+
52+
match tpl_ref.get_tpl_id() {
53+
GenericTplId::Func(tpl_id) => {
54+
if let LuaSemanticDeclId::LuaDecl(decl_id) = semantic_decl {
55+
let decl = db.get_decl_index().get_decl(&decl_id)?;
56+
match decl.extra {
57+
LuaDeclExtra::Param { signature_id, .. } => {
58+
let signature = db.get_signature_index().get(&signature_id)?;
59+
if let Some((_, param_type)) =
60+
signature.generic_params.get(tpl_id as usize)
61+
{
62+
return param_type.clone();
63+
}
64+
}
65+
_ => return None,
66+
}
67+
}
68+
None
69+
}
70+
GenericTplId::Type(tpl_id) => {
71+
if let LuaSemanticDeclId::LuaDecl(decl_id) = semantic_decl {
72+
let decl = db.get_decl_index().get_decl(&decl_id)?;
73+
match decl.extra {
74+
LuaDeclExtra::Param {
75+
owner_member_id, ..
76+
} => {
77+
let owner_member_id = owner_member_id?;
78+
let parent_owner =
79+
db.get_member_index().get_current_owner(&owner_member_id)?;
80+
match parent_owner {
81+
LuaMemberOwner::Type(type_id) => {
82+
let generic_params =
83+
db.get_type_index().get_generic_params(&type_id)?;
84+
return generic_params.get(tpl_id as usize)?.1.clone();
85+
}
86+
_ => return None,
87+
}
88+
}
89+
_ => return None,
90+
}
91+
}
92+
None
93+
}
94+
}
95+
}
96+
LuaType::Union(union_type) => {
97+
if depth > 1 {
98+
return None;
99+
}
100+
let mut result = LuaType::Unknown;
101+
for union_member_type in union_type.into_vec().iter() {
102+
let extend_type = get_tpl_ref_extend_type(
103+
db,
104+
cache,
105+
union_member_type,
106+
arg_expr.clone(),
107+
depth + 1,
108+
)
109+
.unwrap_or(union_member_type.clone());
110+
result = TypeOps::Union.apply(db, &result, &extend_type);
111+
}
112+
Some(result)
113+
}
114+
_ => None,
115+
}
116+
}

crates/emmylua_code_analysis/src/semantic/infer/infer_index.rs

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@ use rowan::TextRange;
99
use smol_str::SmolStr;
1010

1111
use crate::{
12-
CacheEntry, InFiled, LuaArrayLen, LuaArrayType, LuaDeclOrMemberId, LuaInferCache,
12+
CacheEntry, GenericTpl, InFiled, LuaArrayLen, LuaArrayType, LuaDeclOrMemberId, LuaInferCache,
1313
LuaInstanceType, LuaMemberOwner, LuaOperatorOwner, TypeOps,
1414
db_index::{
1515
DbIndex, LuaGenericType, LuaIntersectionType, LuaMemberKey, LuaObjectType,
1616
LuaOperatorMetaMethod, LuaTupleType, LuaType, LuaTypeDeclId, LuaUnionType,
1717
},
18-
enum_variable_is_param,
18+
enum_variable_is_param, get_tpl_ref_extend_type,
1919
semantic::{
2020
InferGuard,
2121
generic::{TypeSubstitutor, instantiate_type_generic},
@@ -173,6 +173,7 @@ pub fn infer_member_by_member_key(
173173
LuaType::Instance(inst) => infer_instance_member(db, cache, inst, index_expr, infer_guard),
174174
LuaType::Namespace(ns) => infer_namespace_member(db, cache, ns, index_expr),
175175
LuaType::Array(array_type) => infer_array_member(db, cache, array_type, index_expr),
176+
LuaType::TplRef(tpl) => infer_tpl_ref_member(db, cache, tpl, index_expr, infer_guard),
176177
_ => Err(InferFailReason::FieldNotFound),
177178
}
178179
}
@@ -1233,3 +1234,25 @@ fn expr_to_member_key(
12331234
}
12341235
Some(keys)
12351236
}
1237+
1238+
fn infer_tpl_ref_member(
1239+
db: &DbIndex,
1240+
cache: &mut LuaInferCache,
1241+
generic: &GenericTpl,
1242+
index_expr: LuaIndexMemberExpr,
1243+
infer_guard: &mut InferGuard,
1244+
) -> InferResult {
1245+
let extend_type = get_tpl_ref_extend_type(
1246+
db,
1247+
cache,
1248+
&LuaType::TplRef(generic.clone().into()),
1249+
index_expr
1250+
.get_index_expr()
1251+
.ok_or(InferFailReason::None)?
1252+
.get_prefix_expr()
1253+
.ok_or(InferFailReason::None)?,
1254+
0,
1255+
)
1256+
.ok_or(InferFailReason::None)?;
1257+
return infer_member_by_member_key(db, cache, &extend_type, index_expr.clone(), infer_guard);
1258+
}

crates/emmylua_ls/src/handlers/completion/providers/member_provider.rs

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use emmylua_code_analysis::{
2-
DbIndex, LuaMemberInfo, LuaMemberKey, LuaSemanticDeclId, LuaType, LuaTypeDeclId, SemanticModel,
3-
enum_variable_is_param,
2+
DbIndex, LuaMemberInfo, LuaSemanticDeclId, LuaType, LuaTypeDeclId, SemanticModel,
3+
enum_variable_is_param, get_tpl_ref_extend_type,
44
};
55
use emmylua_parser::{LuaAstNode, LuaAstToken, LuaIndexExpr, LuaStringToken};
66
use std::collections::HashMap;
@@ -28,7 +28,20 @@ pub fn add_completion(builder: &mut CompletionBuilder) -> Option<()> {
2828
};
2929

3030
let prefix_expr = index_expr.get_prefix_expr()?;
31-
let prefix_type = builder.semantic_model.infer_expr(prefix_expr.into()).ok()?;
31+
let prefix_type = match builder
32+
.semantic_model
33+
.infer_expr(prefix_expr.clone())
34+
.ok()?
35+
{
36+
LuaType::TplRef(tpl) => get_tpl_ref_extend_type(
37+
builder.semantic_model.get_db(),
38+
&mut builder.semantic_model.get_cache().borrow_mut(),
39+
&LuaType::TplRef(tpl.clone().into()),
40+
prefix_expr.clone(),
41+
0,
42+
)?,
43+
prefix_type => prefix_type,
44+
};
3245
// 如果是枚举类型且为函数参数, 则不进行补全
3346
if enum_variable_is_param(
3447
builder.semantic_model.get_db(),

crates/emmylua_ls/src/handlers/test/completion_test.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2013,4 +2013,29 @@ mod tests {
20132013

20142014
Ok(())
20152015
}
2016+
2017+
#[test]
2018+
fn test_issue_646() {
2019+
let mut ws = ProviderVirtualWorkspace::new();
2020+
ws.def(
2021+
r#"
2022+
---@class Base
2023+
---@field a string
2024+
"#,
2025+
);
2026+
assert!(ws.check_completion(
2027+
r#"
2028+
---@generic T: Base
2029+
---@param file T
2030+
function dirname(file)
2031+
file.<??>
2032+
end
2033+
"#,
2034+
vec![VirtualCompletionItem {
2035+
label: "a".to_string(),
2036+
kind: CompletionItemKind::VARIABLE,
2037+
..Default::default()
2038+
},],
2039+
));
2040+
}
20162041
}

0 commit comments

Comments
 (0)