Skip to content

Commit ac52d6e

Browse files
committed
fix type_check simple_type enum
1 parent e6cab73 commit ac52d6e

File tree

3 files changed

+107
-12
lines changed

3 files changed

+107
-12
lines changed

crates/emmylua_code_analysis/src/db_index/type/type_decl.rs

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@ use rowan::TextRange;
44
use serde::{Deserialize, Deserializer, Serialize, Serializer};
55
use smol_str::SmolStr;
66

7-
use crate::{instantiate_type_generic, DbIndex, FileId, TypeSubstitutor};
7+
use crate::{
8+
instantiate_type_generic, DbIndex, FileId, LuaMemberKey, LuaMemberOwner, TypeSubstitutor,
9+
};
810

9-
use super::LuaType;
11+
use super::{LuaType, LuaUnionType};
1012

1113
#[derive(Debug, Eq, PartialEq, Hash, Clone, Copy)]
1214
pub enum LuaDeclTypeKind {
@@ -173,6 +175,46 @@ impl LuaTypeDecl {
173175
pub fn merge_decl(&mut self, other: LuaTypeDecl) {
174176
self.locations.extend(other.locations);
175177
}
178+
179+
/// 获取枚举字段的类型
180+
pub fn get_enum_field_type(&self, db: &DbIndex) -> Option<LuaType> {
181+
if !self.is_enum() {
182+
return None;
183+
}
184+
185+
let enum_member_owner = LuaMemberOwner::Type(self.get_id());
186+
let enum_members = db.get_member_index().get_members(&enum_member_owner)?;
187+
188+
let mut union_types = Vec::new();
189+
if self.is_enum_key() {
190+
for enum_member in enum_members {
191+
let member_key = enum_member.get_key();
192+
let fake_type = match member_key {
193+
LuaMemberKey::Name(name) => LuaType::DocStringConst(name.clone().into()),
194+
LuaMemberKey::Integer(i) => LuaType::IntegerConst(i.clone()),
195+
LuaMemberKey::None | LuaMemberKey::Expr(_) => continue,
196+
};
197+
198+
union_types.push(fake_type);
199+
}
200+
} else {
201+
for member in enum_members {
202+
if let Some(type_cache) =
203+
db.get_type_index().get_type_cache(&member.get_id().into())
204+
{
205+
let member_fake_type = match type_cache.as_type() {
206+
LuaType::StringConst(s) => LuaType::DocStringConst(s.clone().into()),
207+
LuaType::IntegerConst(i) => LuaType::DocIntegerConst(i.clone()),
208+
_ => type_cache.as_type().clone(),
209+
};
210+
211+
union_types.push(member_fake_type);
212+
}
213+
}
214+
}
215+
216+
return Some(LuaType::Union(LuaUnionType::new(union_types).into()));
217+
}
176218
}
177219

178220
#[derive(Debug, Eq, PartialEq, Hash, Clone)]

crates/emmylua_code_analysis/src/diagnostic/test/param_type_check_test.rs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -649,6 +649,41 @@ mod test {
649649
));
650650
}
651651

652+
#[test]
653+
fn test_alias_union_enum_2() {
654+
let mut ws = VirtualWorkspace::new();
655+
assert!(!ws.check_code_for(
656+
DiagnosticCode::ParamTypeNotMatch,
657+
r#"
658+
---@alias EventType
659+
---| GlobalEventType
660+
---| UIEventType
661+
662+
---@enum UIEventType
663+
local UIEventType = {
664+
['UI_CREATE'] = "ET_UI_PREFAB_CREATE_EVENT",
665+
['UI_DELETE'] = "ET_UI_PREFAB_DEL_EVENT",
666+
}
667+
668+
---@enum GlobalEventType
669+
local GlobalEventType = {
670+
['GAME_INIT'] = 1,
671+
['GAME_PAUSE'] = "ET_GAME_PAUSE",
672+
}
673+
674+
---@param event_name string
675+
local function get_py_event_name(event_name)
676+
end
677+
678+
---@param a EventType
679+
local function test(a)
680+
get_py_event_name(a)
681+
end
682+
683+
"#
684+
));
685+
}
686+
652687
#[test]
653688
fn test_empty_class() {
654689
let mut ws = VirtualWorkspace::new();

crates/emmylua_code_analysis/src/semantic/type_check/simple_type.rs

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
use crate::{semantic::type_check::is_sub_type_of, DbIndex, LuaType};
1+
use crate::{semantic::type_check::is_sub_type_of, DbIndex, LuaType, LuaTypeDeclId};
22

33
use super::{
4-
check_general_type_compact, ref_type::check_ref_type_compact, sub_type::get_base_type_id,
4+
check_general_type_compact, sub_type::get_base_type_id,
55
type_check_fail_reason::TypeCheckFailReason, type_check_guard::TypeCheckGuard, TypeCheckResult,
66
};
77

@@ -263,14 +263,8 @@ fn check_base_type_for_ref_compact(
263263
}
264264
if let Some(decl) = db.get_type_index().get_type_decl(type_decl_id) {
265265
if decl.is_enum() {
266-
// TODO: 优化, 不经过`check_ref_type_compact`
267-
if check_ref_type_compact(
268-
db,
269-
type_decl_id,
270-
compact_type,
271-
check_guard.next_level()?,
272-
)
273-
.is_ok()
266+
if check_enum_fields_match_source(db, source, type_decl_id, check_guard)
267+
.is_ok()
274268
{
275269
return Ok(());
276270
}
@@ -282,3 +276,27 @@ fn check_base_type_for_ref_compact(
282276
}
283277
Err(TypeCheckFailReason::TypeNotMatch)
284278
}
279+
280+
/// 检查`enum`的所有字段是否匹配`source`
281+
fn check_enum_fields_match_source(
282+
db: &DbIndex,
283+
source: &LuaType,
284+
enum_type_decl_id: &LuaTypeDeclId,
285+
check_guard: TypeCheckGuard,
286+
) -> TypeCheckResult {
287+
if let Some(decl) = db.get_type_index().get_type_decl(enum_type_decl_id) {
288+
if let Some(LuaType::Union(enum_fields)) = decl.get_enum_field_type(db) {
289+
let is_match = enum_fields.get_types().iter().all(|field| {
290+
let next_guard = check_guard.next_level();
291+
if next_guard.is_err() {
292+
return false;
293+
}
294+
check_general_type_compact(db, source, field, next_guard.unwrap()).is_ok()
295+
});
296+
if is_match {
297+
return Ok(());
298+
}
299+
}
300+
}
301+
Err(TypeCheckFailReason::TypeNotMatch)
302+
}

0 commit comments

Comments
 (0)