Skip to content

Commit 82e57f3

Browse files
committed
support TypeGuard<T>
1 parent 551cc21 commit 82e57f3

File tree

8 files changed

+73
-21
lines changed

8 files changed

+73
-21
lines changed

CHANGELOG.md

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,24 @@ end
5858

5959
`NEW` Support `Lua 5.5` global decl grammar
6060

61+
`NEW` Support `TypeGuard<T>` as return type. For example:
62+
```lua
63+
64+
---@return TypeGuard<string>
65+
local function is_string(value)
66+
return type(value) == "string"
67+
end
68+
69+
local a
70+
71+
if is_string(a) then
72+
print(a:sub(1, 1))
73+
else
74+
print("a is not a string")
75+
end
76+
```
77+
78+
6179
# 0.7.2
6280

6381
`FIX` Fix reading configuration file encoded with UTF-8 BOM

crates/emmylua_code_analysis/resources/std/builtin.lua

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,3 +125,5 @@
125125
---@alias collectgarbage_opt std.collectgarbage_opt
126126

127127
---@alias metatable std.metatable
128+
129+
---@alias TypeGuard<T> boolean

crates/emmylua_code_analysis/src/compilation/analyzer/doc/infer_type.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,12 @@ fn infer_special_generic_type(
270270
));
271271
}
272272
"std.Unpack" => {}
273+
"TypeGuard" => {
274+
let first_doc_param_type = generic_type.get_generic_types()?.get_types().next()?;
275+
let first_param = infer_type(analyzer, first_doc_param_type);
276+
277+
return Some(LuaType::TypeGuard(first_param.into()));
278+
}
273279
_ => {}
274280
}
275281

crates/emmylua_code_analysis/src/db_index/flow/type_assert.rs

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::sync::Arc;
1+
use std::{ops::Deref, sync::Arc};
22

33
use crate::{infer_expr, DbIndex, InferFailReason, LuaInferCache, LuaType, TypeOps};
44
use emmylua_parser::{LuaAstNode, LuaCallExpr, LuaExpr, LuaSyntaxId, LuaSyntaxNode};
@@ -182,28 +182,32 @@ fn call_assertion(
182182
let Some(signature) = db.get_signature_index().get(&signature_id) else {
183183
return Err(InferFailReason::None);
184184
};
185-
// donot change the condition
186-
if !signature.get_return_type().is_boolean() {
187-
return Err(InferFailReason::None);
188-
}
189185

190-
let Some(cast) = db.get_flow_index().get_call_cast(signature_id) else {
191-
return Err(InferFailReason::None);
192-
};
193-
194-
let param_name = if param_idx >= 0 {
195-
let Some(param_name) = signature.get_param_name_by_id(param_idx as usize) else {
196-
return Err(InferFailReason::None);
197-
};
186+
let return_type = signature.get_return_type();
187+
// donot change the condition
188+
match return_type {
189+
LuaType::Boolean => {
190+
let Some(cast) = db.get_flow_index().get_call_cast(signature_id) else {
191+
return Err(InferFailReason::None);
192+
};
193+
194+
let param_name = if param_idx >= 0 {
195+
let Some(param_name) = signature.get_param_name_by_id(param_idx as usize) else {
196+
return Err(InferFailReason::None);
197+
};
198198

199-
param_name
200-
} else {
201-
"self".to_string()
202-
};
199+
param_name
200+
} else {
201+
"self".to_string()
202+
};
203203

204-
let Some(typeassert) = cast.get(&param_name) else {
205-
return Err(InferFailReason::None);
206-
};
204+
let Some(typeassert) = cast.get(&param_name) else {
205+
return Err(InferFailReason::None);
206+
};
207207

208-
Ok(typeassert.clone())
208+
Ok(typeassert.clone())
209+
}
210+
LuaType::TypeGuard(inner) => Ok(TypeAssertion::Force(inner.deref().clone())),
211+
_ => return Err(InferFailReason::None),
212+
}
209213
}

crates/emmylua_code_analysis/src/db_index/type/humanize_type.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,10 @@ pub fn humanize_type(db: &DbIndex, ty: &LuaType, level: RenderLevel) -> String {
8888
LuaType::MultiLineUnion(multi_union) => {
8989
humanize_multi_line_union_type(db, multi_union, level)
9090
}
91+
LuaType::TypeGuard(inner) => {
92+
let type_str = humanize_type(db, inner, level.next_level());
93+
format!("TypeGuard<{}>", type_str)
94+
}
9195
_ => "unknown".to_string(),
9296
}
9397
}

crates/emmylua_code_analysis/src/db_index/type/types.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ pub enum LuaType {
5858
Namespace(ArcIntern<SmolStr>),
5959
Call(Arc<LuaAliasCallType>),
6060
MultiLineUnion(Arc<LuaMultiLineUnion>),
61+
TypeGuard(Arc<LuaType>),
6162
}
6263

6364
impl PartialEq for LuaType {
@@ -103,6 +104,7 @@ impl PartialEq for LuaType {
103104
(LuaType::DocIntegerConst(a), LuaType::DocIntegerConst(b)) => a == b,
104105
(LuaType::Namespace(a), LuaType::Namespace(b)) => a == b,
105106
(LuaType::MultiLineUnion(a), LuaType::MultiLineUnion(b)) => a == b,
107+
(LuaType::TypeGuard(a), LuaType::TypeGuard(b)) => a == b,
106108
_ => false, // 不同变体之间不相等
107109
}
108110
}
@@ -174,6 +176,10 @@ impl Hash for LuaType {
174176
let ptr = Arc::as_ptr(a);
175177
(43, ptr).hash(state)
176178
}
179+
LuaType::TypeGuard(a) => {
180+
let ptr = Arc::as_ptr(a);
181+
(44, ptr).hash(state)
182+
}
177183
}
178184
}
179185
}
@@ -402,6 +408,10 @@ impl LuaType {
402408
pub fn is_member_owner(&self) -> bool {
403409
matches!(self, LuaType::Ref(_) | LuaType::TableConst(_))
404410
}
411+
412+
pub fn is_type_guard(&self) -> bool {
413+
matches!(self, LuaType::TypeGuard(_))
414+
}
405415
}
406416

407417
#[derive(Debug, Clone, Hash, PartialEq, Eq)]

crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -568,6 +568,7 @@ pub(crate) fn unwrapp_return_type(
568568
}
569569
}
570570
}
571+
LuaType::TypeGuard(_) => return Ok(LuaType::Boolean),
571572
_ => {}
572573
}
573574

crates/emmylua_code_analysis/src/semantic/type_check/mod.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,12 @@ fn check_general_type_compact(
112112
compact_type,
113113
check_guard.next_level()?,
114114
),
115+
LuaType::TypeGuard(_) => {
116+
if compact_type.is_boolean() {
117+
return Ok(());
118+
}
119+
return Err(TypeCheckFailReason::TypeNotMatch);
120+
}
115121
_ => Err(TypeCheckFailReason::TypeNotMatch),
116122
}
117123
}
@@ -142,6 +148,7 @@ fn escape_type(db: &DbIndex, typ: &LuaType) -> Option<LuaType> {
142148
let union = multi_union.to_union();
143149
return Some(union);
144150
}
151+
LuaType::TypeGuard(_) => return Some(LuaType::Boolean),
145152
_ => {}
146153
}
147154

0 commit comments

Comments
 (0)