Skip to content

Commit 99c57ba

Browse files
committed
update check_field
1 parent 05a7d70 commit 99c57ba

File tree

3 files changed

+89
-35
lines changed

3 files changed

+89
-35
lines changed

crates/emmylua_code_analysis/src/diagnostic/checker/check_field.rs

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@ use emmylua_parser::{
77
};
88

99
use crate::{
10-
DiagnosticCode, InferFailReason, LuaMemberKey, LuaSemanticDeclId, LuaType, ModuleInfo,
11-
SemanticDeclLevel, SemanticModel, enum_variable_is_param, parse_require_module_info,
10+
DbIndex, DiagnosticCode, InferFailReason, LuaAliasCallKind, LuaAliasCallType, LuaMemberKey,
11+
LuaSemanticDeclId, LuaType, ModuleInfo, SemanticDeclLevel, SemanticModel,
12+
enum_variable_is_param, get_keyof_members, parse_require_module_info,
1213
};
1314

1415
use super::{Checker, DiagnosticContext, humanize_lint_type};
@@ -262,7 +263,7 @@ fn is_valid_member(
262263
local field
263264
local a = Class[field]
264265
*/
265-
let key_types = get_key_types(&key_type);
266+
let key_types = get_key_types(&semantic_model.get_db(), &key_type);
266267
if key_types.is_empty() {
267268
return None;
268269
}
@@ -358,7 +359,7 @@ fn get_prefix_types(prefix_typ: &LuaType) -> HashSet<LuaType> {
358359
type_set
359360
}
360361

361-
fn get_key_types(typ: &LuaType) -> HashSet<LuaType> {
362+
fn get_key_types(db: &DbIndex, typ: &LuaType) -> HashSet<LuaType> {
362363
let mut type_set = HashSet::new();
363364
let mut stack = vec![typ.clone()];
364365
let mut visited = HashSet::new();
@@ -383,6 +384,16 @@ fn get_key_types(typ: &LuaType) -> HashSet<LuaType> {
383384
LuaType::StrTplRef(_) | LuaType::Ref(_) => {
384385
type_set.insert(current_type);
385386
}
387+
LuaType::DocStringConst(_) | LuaType::DocIntegerConst(_) => {
388+
type_set.insert(current_type);
389+
}
390+
LuaType::Call(alias_call) => {
391+
if let Some(key_types) = get_key_of_keys(db, alias_call) {
392+
for t in key_types {
393+
stack.push(t.clone());
394+
}
395+
}
396+
}
386397
_ => {}
387398
}
388399
}
@@ -535,3 +546,23 @@ pub fn parse_require_expr_module_info<'a>(
535546
.get_module_index()
536547
.find_module(&module_path)
537548
}
549+
550+
fn get_key_of_keys(db: &DbIndex, alias_call: &LuaAliasCallType) -> Option<Vec<LuaType>> {
551+
if alias_call.get_call_kind() != LuaAliasCallKind::KeyOf {
552+
return None;
553+
}
554+
let source_operands = alias_call.get_operands().iter().collect::<Vec<_>>();
555+
if source_operands.len() != 1 {
556+
return None;
557+
}
558+
let members = get_keyof_members(db, &source_operands[0]).unwrap_or_default();
559+
let key_types = members
560+
.iter()
561+
.filter_map(|m| match &m.key {
562+
LuaMemberKey::Integer(i) => Some(LuaType::DocIntegerConst(*i)),
563+
LuaMemberKey::Name(s) => Some(LuaType::DocStringConst(s.clone().into())),
564+
_ => None,
565+
})
566+
.collect::<Vec<_>>();
567+
Some(key_types)
568+
}

crates/emmylua_code_analysis/src/diagnostic/test/undefined_field_test.rs

Lines changed: 50 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -736,31 +736,54 @@ mod test {
736736
));
737737
}
738738

739-
// #[test]
740-
// fn test_export() {
741-
// let mut ws = VirtualWorkspace::new();
742-
// ws.def_file(
743-
// "a.lua",
744-
// r#"
745-
// ---@export
746-
// local export = {}
747-
748-
// return export
749-
// "#,
750-
// );
751-
// assert!(!ws.check_code_for(
752-
// DiagnosticCode::UndefinedField,
753-
// r#"
754-
// local a = require("a")
755-
// a.func()
756-
// "#,
757-
// ));
758-
759-
// assert!(!ws.check_code_for(
760-
// DiagnosticCode::UndefinedField,
761-
// r#"
762-
// local a = require("a").ABC
763-
// "#,
764-
// ));
765-
// }
739+
#[test]
740+
fn test_export() {
741+
let mut ws = VirtualWorkspace::new();
742+
ws.def_file(
743+
"a.lua",
744+
r#"
745+
---@export
746+
local export = {}
747+
748+
return export
749+
"#,
750+
);
751+
assert!(!ws.check_code_for(
752+
DiagnosticCode::UndefinedField,
753+
r#"
754+
local a = require("a")
755+
a.func()
756+
"#,
757+
));
758+
759+
assert!(!ws.check_code_for(
760+
DiagnosticCode::UndefinedField,
761+
r#"
762+
local a = require("a").ABC
763+
"#,
764+
));
765+
}
766+
767+
#[test]
768+
fn test_keyof_type() {
769+
let mut ws = VirtualWorkspace::new();
770+
ws.def(
771+
r#"
772+
---@class SuiteHooks
773+
---@field beforeAll string
774+
775+
---@type SuiteHooks
776+
hooks = {}
777+
778+
---@type keyof SuiteHooks
779+
name = "beforeAll"
780+
"#,
781+
);
782+
assert!(ws.check_code_for(
783+
DiagnosticCode::UndefinedField,
784+
r#"
785+
local a = hooks[name]
786+
"#
787+
));
788+
}
766789
}

crates/emmylua_code_analysis/src/semantic/type_check/complex_type/call_type_check.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,10 @@ pub fn check_call_type_compact(
3232
}
3333

3434
let source_key_types = LuaType::Union(Arc::new(LuaUnionType::from_vec(
35-
get_key_of_key_types(context, &source_operands[0]),
35+
get_key_of_keys(context, &source_operands[0]),
3636
)));
3737
let compact_key_types = LuaType::Union(Arc::new(LuaUnionType::from_vec(
38-
get_key_of_key_types(context, &compact_operands[0]),
38+
get_key_of_keys(context, &compact_operands[0]),
3939
)));
4040
return check_general_type_compact(
4141
context,
@@ -46,7 +46,7 @@ pub fn check_call_type_compact(
4646
}
4747
}
4848
_ => {
49-
let key_types = get_key_of_key_types(context, &source_operands[0]);
49+
let key_types = get_key_of_keys(context, &source_operands[0]);
5050
for key_type in &key_types {
5151
match check_general_type_compact(
5252
context,
@@ -68,7 +68,7 @@ pub fn check_call_type_compact(
6868
Ok(())
6969
}
7070

71-
fn get_key_of_key_types(context: &TypeCheckContext, prefix_type: &LuaType) -> Vec<LuaType> {
71+
fn get_key_of_keys(context: &TypeCheckContext, prefix_type: &LuaType) -> Vec<LuaType> {
7272
let members = get_keyof_members(context.db, prefix_type).unwrap_or_default();
7373
let key_types = members
7474
.iter()

0 commit comments

Comments
 (0)