Skip to content

Commit 26377c8

Browse files
committed
Infer the value of an index expression where the member key is keyof.
1 parent b5cdab2 commit 26377c8

File tree

4 files changed

+93
-39
lines changed

4 files changed

+93
-39
lines changed

crates/emmylua_code_analysis/src/compilation/test/member_infer_test.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,4 +102,28 @@ mod test {
102102
assert_eq!(e_ty, LuaType::Integer);
103103
assert_eq!(f_ty, LuaType::Integer);
104104
}
105+
106+
#[test]
107+
fn test_keyof() {
108+
let mut ws = VirtualWorkspace::new();
109+
110+
ws.def(
111+
r#"
112+
---@class SuiteHooks
113+
---@field beforeAll string
114+
---@field afterAll number
115+
116+
---@type SuiteHooks
117+
local hooks = {}
118+
119+
---@type keyof SuiteHooks
120+
local name = "beforeAll"
121+
122+
A = hooks[name]
123+
"#,
124+
);
125+
126+
let ty = ws.expr_ty("A");
127+
assert_eq!(ws.humanize_type(ty), "(number|string)");
128+
}
105129
}

crates/emmylua_code_analysis/src/diagnostic/checker/check_field.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,7 @@ fn get_key_types(db: &DbIndex, typ: &LuaType) -> HashSet<LuaType> {
388388
type_set.insert(current_type);
389389
}
390390
LuaType::Call(alias_call) => {
391-
if let Some(key_types) = get_key_of_keys(db, alias_call) {
391+
if let Some(key_types) = get_keyof_keys(db, alias_call) {
392392
for t in key_types {
393393
stack.push(t.clone());
394394
}
@@ -547,7 +547,7 @@ pub fn parse_require_expr_module_info<'a>(
547547
.find_module(&module_path)
548548
}
549549

550-
fn get_key_of_keys(db: &DbIndex, alias_call: &LuaAliasCallType) -> Option<Vec<LuaType>> {
550+
fn get_keyof_keys(db: &DbIndex, alias_call: &LuaAliasCallType) -> Option<Vec<LuaType>> {
551551
if alias_call.get_call_kind() != LuaAliasCallKind::KeyOf {
552552
return None;
553553
}

crates/emmylua_code_analysis/src/semantic/infer/infer_index.rs

Lines changed: 63 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@ use rowan::TextRange;
99
use smol_str::SmolStr;
1010

1111
use crate::{
12-
CacheEntry, GenericTpl, InFiled, InferGuardRef, LuaArrayLen, LuaArrayType, LuaDeclOrMemberId,
13-
LuaInferCache, LuaInstanceType, LuaMemberOwner, LuaOperatorOwner, TypeOps,
12+
CacheEntry, GenericTpl, InFiled, InferGuardRef, LuaAliasCallKind, LuaArrayLen, LuaArrayType,
13+
LuaDeclOrMemberId, LuaInferCache, LuaInstanceType, LuaMemberOwner, LuaOperatorOwner, TypeOps,
1414
db_index::{
1515
DbIndex, LuaGenericType, LuaIntersectionType, LuaMemberKey, LuaObjectType,
1616
LuaOperatorMetaMethod, LuaTupleType, LuaType, LuaTypeDeclId, LuaUnionType,
1717
},
18-
enum_variable_is_param, get_tpl_ref_extend_type,
18+
enum_variable_is_param, get_keyof_members, get_tpl_ref_extend_type,
1919
semantic::{
2020
InferGuard,
2121
generic::{TypeSubstitutor, instantiate_type_generic},
@@ -420,26 +420,9 @@ fn infer_custom_type_member(
420420
return member_item.resolve_type(db);
421421
}
422422

423-
if type_decl.is_class()
424-
&& let Some(super_types) = type_index.get_super_types(&prefix_type_id)
425-
{
426-
for super_type in super_types {
427-
let result =
428-
infer_member_by_member_key(db, cache, &super_type, index_expr.clone(), infer_guard);
429-
430-
match result {
431-
Ok(member_type) => {
432-
return Ok(member_type);
433-
}
434-
Err(InferFailReason::FieldNotFound) => {}
435-
Err(err) => return Err(err),
436-
}
437-
}
438-
}
439-
440423
// 解决`key`为表达式的情况
441424
if let LuaIndexKey::Expr(expr) = index_key
442-
&& let Some(keys) = expr_to_member_key(db, cache, &expr)
425+
&& let Some(keys) = get_expr_member_key(db, cache, &expr)
443426
{
444427
let mut result_types = Vec::new();
445428
for key in keys {
@@ -462,6 +445,23 @@ fn infer_custom_type_member(
462445
}
463446
}
464447

448+
if type_decl.is_class()
449+
&& let Some(super_types) = type_index.get_super_types(&prefix_type_id)
450+
{
451+
for super_type in super_types {
452+
let result =
453+
infer_member_by_member_key(db, cache, &super_type, index_expr.clone(), infer_guard);
454+
455+
match result {
456+
Ok(member_type) => {
457+
return Ok(member_type);
458+
}
459+
Err(InferFailReason::FieldNotFound) => {}
460+
Err(err) => return Err(err),
461+
}
462+
}
463+
}
464+
465465
Err(InferFailReason::FieldNotFound)
466466
}
467467

@@ -1233,46 +1233,76 @@ fn infer_namespace_member(
12331233
))
12341234
}
12351235

1236-
fn expr_to_member_key(
1236+
fn get_expr_member_key(
12371237
db: &DbIndex,
12381238
cache: &mut LuaInferCache,
12391239
expr: &LuaExpr,
1240-
) -> Option<HashSet<LuaMemberKey>> {
1240+
) -> Option<Vec<LuaMemberKey>> {
12411241
let expr_type = infer_expr(db, cache, expr.clone()).ok()?;
12421242
let mut keys: HashSet<LuaMemberKey> = HashSet::new();
12431243
let mut stack = vec![expr_type.clone()];
12441244
let mut visited = HashSet::new();
12451245

12461246
while let Some(current_type) = stack.pop() {
1247-
if visited.contains(&current_type) {
1247+
if !visited.insert(current_type.clone()) {
12481248
continue;
12491249
}
1250-
visited.insert(current_type.clone());
12511250
match &current_type {
12521251
LuaType::StringConst(name) | LuaType::DocStringConst(name) => {
1253-
keys.insert(name.as_ref().to_string().into());
1252+
keys.insert(LuaMemberKey::Name((**name).clone()));
12541253
}
12551254
LuaType::IntegerConst(i) | LuaType::DocIntegerConst(i) => {
1256-
keys.insert((*i).into());
1255+
keys.insert(LuaMemberKey::Integer(*i));
1256+
}
1257+
LuaType::Call(alias_call) => {
1258+
if alias_call.get_call_kind() == LuaAliasCallKind::KeyOf {
1259+
let operands = alias_call.get_operands();
1260+
if operands.len() == 1 {
1261+
if let Some(members) = get_keyof_members(db, &operands[0]) {
1262+
keys.extend(members.into_iter().map(|member| member.key));
1263+
}
1264+
}
1265+
}
1266+
}
1267+
LuaType::MultiLineUnion(multi_union) => {
1268+
for (typ, _) in multi_union.get_unions() {
1269+
if !visited.contains(typ) {
1270+
stack.push(typ.clone());
1271+
}
1272+
}
12571273
}
12581274
LuaType::Union(union_typ) => {
12591275
for t in union_typ.into_vec() {
1260-
stack.push(t.clone());
1276+
if !visited.contains(&t) {
1277+
stack.push(t.clone());
1278+
}
12611279
}
12621280
}
12631281
LuaType::TableConst(_) | LuaType::Tuple(_) => {
1264-
keys.insert(LuaMemberKey::ExprType(expr_type.clone()));
1282+
keys.insert(LuaMemberKey::ExprType(current_type.clone()));
12651283
}
12661284
LuaType::Ref(id) => {
1267-
if let Some(type_decl) = db.get_type_index().get_type_decl(id)
1268-
&& (type_decl.is_enum() || type_decl.is_alias())
1269-
{
1270-
keys.insert(LuaMemberKey::ExprType(current_type.clone()));
1285+
if let Some(type_decl) = db.get_type_index().get_type_decl(id) {
1286+
if type_decl.is_alias() {
1287+
if let Some(origin_type) = type_decl.get_alias_origin(db, None) {
1288+
if !visited.contains(&origin_type) {
1289+
stack.push(origin_type);
1290+
}
1291+
continue;
1292+
}
1293+
}
1294+
if type_decl.is_enum() || type_decl.is_alias() {
1295+
keys.insert(LuaMemberKey::ExprType(current_type.clone()));
1296+
}
12711297
}
12721298
}
12731299
_ => {}
12741300
}
12751301
}
1302+
1303+
// 转换为 Vec 并排序以确保顺序确定性
1304+
let mut keys: Vec<_> = keys.into_iter().collect();
1305+
keys.sort();
12761306
Some(keys)
12771307
}
12781308

crates/emmylua_code_analysis/src/semantic/type_check/complex_type/call_type_check.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,10 @@ pub fn check_call_type_compact(
3232
}
3333

3434
let source_key_types = LuaType::Union(Arc::new(LuaUnionType::from_vec(
35-
get_key_of_keys(context, &source_operands[0]),
35+
get_keyof_keys(context, &source_operands[0]),
3636
)));
3737
let compact_key_types = LuaType::Union(Arc::new(LuaUnionType::from_vec(
38-
get_key_of_keys(context, &compact_operands[0]),
38+
get_keyof_keys(context, &compact_operands[0]),
3939
)));
4040
return check_general_type_compact(
4141
context,
@@ -46,7 +46,7 @@ pub fn check_call_type_compact(
4646
}
4747
}
4848
_ => {
49-
let key_types = get_key_of_keys(context, &source_operands[0]);
49+
let key_types = get_keyof_keys(context, &source_operands[0]);
5050
for key_type in &key_types {
5151
match check_general_type_compact(
5252
context,
@@ -68,7 +68,7 @@ pub fn check_call_type_compact(
6868
Ok(())
6969
}
7070

71-
fn get_key_of_keys(context: &TypeCheckContext, prefix_type: &LuaType) -> Vec<LuaType> {
71+
fn get_keyof_keys(context: &TypeCheckContext, prefix_type: &LuaType) -> Vec<LuaType> {
7272
let members = get_keyof_members(context.db, prefix_type).unwrap_or_default();
7373
let key_types = members
7474
.iter()

0 commit comments

Comments
 (0)