Skip to content

Commit cca6a90

Browse files
committed
support generic typeguard
1 parent 5f6351c commit cca6a90

File tree

5 files changed

+92
-16
lines changed

5 files changed

+92
-16
lines changed

crates/emmylua_code_analysis/src/compilation/test/flow.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -919,4 +919,30 @@ end
919919
let b_expected = ws.ty("string");
920920
assert_eq!(b, b_expected);
921921
}
922+
923+
#[test]
924+
fn test_feature_generic_type_guard() {
925+
let mut ws = VirtualWorkspace::new();
926+
927+
ws.def(
928+
r#"
929+
---@generic T
930+
---@param type `T`
931+
---@return TypeGuard<T>
932+
local function instanceOf(inst, type)
933+
return true
934+
end
935+
936+
local ret --- @type string | nil
937+
938+
if instanceOf(ret, "string") then
939+
a = ret
940+
end
941+
"#,
942+
);
943+
944+
let a = ws.expr_ty("a");
945+
let a_expected = ws.ty("string");
946+
assert_eq!(a, a_expected);
947+
}
922948
}

crates/emmylua_code_analysis/src/db_index/type/types.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,7 @@ impl LuaType {
291291
LuaType::Nil | LuaType::Boolean | LuaType::Any | LuaType::Unknown => false,
292292
LuaType::BooleanConst(boolean) | LuaType::DocBooleanConst(boolean) => boolean.clone(),
293293
LuaType::Union(u) => u.is_always_truthy(),
294+
LuaType::TypeGuard(_) => false,
294295
_ => true,
295296
}
296297
}
@@ -299,6 +300,7 @@ impl LuaType {
299300
match self {
300301
LuaType::Nil | LuaType::BooleanConst(false) | LuaType::DocBooleanConst(false) => true,
301302
LuaType::Union(u) => u.is_always_falsy(),
303+
LuaType::TypeGuard(_) => false,
302304
_ => false,
303305
}
304306
}
@@ -400,6 +402,7 @@ impl LuaType {
400402
LuaType::StrTplRef(_) => true,
401403
LuaType::SelfInfer => true,
402404
LuaType::MultiLineUnion(inner) => inner.contain_tpl(),
405+
LuaType::TypeGuard(inner) => inner.contain_tpl(),
403406
_ => false,
404407
}
405408
}

crates/emmylua_code_analysis/src/semantic/generic/instantiate_type_generic.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,10 @@ pub fn instantiate_type_generic(
4545
LuaType::SelfInfer
4646
}
4747
}
48+
LuaType::TypeGuard(guard) => {
49+
let inner = instantiate_type_generic(db, guard.deref(), substitutor);
50+
LuaType::TypeGuard(inner.into())
51+
}
4852
_ => ty.clone(),
4953
}
5054
}

crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -107,15 +107,19 @@ pub fn infer_call_expr_func(
107107
};
108108

109109
let result = if let Ok(func_ty) = result {
110-
unwrapp_return_type(db, cache, func_ty.get_ret().clone(), call_expr).map(|new_ret| {
111-
LuaFunctionType::new(
112-
func_ty.is_async(),
113-
func_ty.is_colon_define(),
114-
func_ty.get_params().to_vec(),
115-
new_ret,
116-
)
117-
.into()
118-
})
110+
let func_ret = func_ty.get_ret();
111+
match func_ret {
112+
LuaType::TypeGuard(_) => Ok(func_ty),
113+
_ => unwrapp_return_type(db, cache, func_ret.clone(), call_expr).map(|new_ret| {
114+
LuaFunctionType::new(
115+
func_ty.is_async(),
116+
func_ty.is_colon_define(),
117+
func_ty.get_params().to_vec(),
118+
new_ret,
119+
)
120+
.into()
121+
}),
122+
}
119123
} else {
120124
result
121125
};

crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs

Lines changed: 46 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
use std::ops::Deref;
1+
use std::{ops::Deref, sync::Arc};
22

33
use emmylua_parser::{LuaCallExpr, LuaChunk, LuaExpr};
44

55
use crate::{
6-
infer_expr,
6+
infer_call_expr_func, infer_expr,
77
semantic::infer::{
88
narrow::{
99
condition_flow::InferConditionFlow, get_single_antecedent,
@@ -12,8 +12,8 @@ use crate::{
1212
},
1313
VarRefId,
1414
},
15-
DbIndex, FlowNode, FlowTree, InferFailReason, LuaInferCache, LuaSignatureCast, LuaSignatureId,
16-
LuaType, TypeOps,
15+
DbIndex, FlowNode, FlowTree, InferFailReason, InferGuard, LuaFunctionType, LuaInferCache,
16+
LuaSignatureCast, LuaSignatureId, LuaType, TypeOps,
1717
};
1818

1919
pub fn get_type_at_call_expr(
@@ -35,15 +35,15 @@ pub fn get_type_at_call_expr(
3535
LuaType::DocFunction(f) => {
3636
let return_type = f.get_ret();
3737
match return_type {
38-
LuaType::TypeGuard(guard_type) => get_type_at_call_expr_by_type_guard(
38+
LuaType::TypeGuard(_) => get_type_at_call_expr_by_type_guard(
3939
db,
4040
tree,
4141
cache,
4242
root,
4343
var_ref_id,
4444
flow_node,
4545
call_expr,
46-
guard_type.deref().clone(),
46+
f,
4747
condition_flow,
4848
),
4949
_ => {
@@ -53,6 +53,25 @@ pub fn get_type_at_call_expr(
5353
}
5454
}
5555
LuaType::Signature(signature_id) => {
56+
let Some(signature) = db.get_signature_index().get(&signature_id) else {
57+
return Ok(ResultTypeOrContinue::Continue);
58+
};
59+
60+
let ret = signature.get_return_type();
61+
if let LuaType::TypeGuard(_) = ret {
62+
return get_type_at_call_expr_by_type_guard(
63+
db,
64+
tree,
65+
cache,
66+
root,
67+
var_ref_id,
68+
flow_node,
69+
call_expr,
70+
signature.to_doc_func_type(),
71+
condition_flow,
72+
);
73+
}
74+
5675
let Some(signature_cast) = db
5776
.get_flow_index()
5877
.get_signature_cast(&cache.get_file_id(), &signature_id)
@@ -102,7 +121,7 @@ fn get_type_at_call_expr_by_type_guard(
102121
var_ref_id: &VarRefId,
103122
flow_node: &FlowNode,
104123
call_expr: LuaCallExpr,
105-
guard_type: LuaType,
124+
func_type: Arc<LuaFunctionType>,
106125
condition_flow: InferConditionFlow,
107126
) -> Result<ResultTypeOrContinue, InferFailReason> {
108127
let Some(arg_list) = call_expr.get_args_list() else {
@@ -121,6 +140,26 @@ fn get_type_at_call_expr_by_type_guard(
121140
return Ok(ResultTypeOrContinue::Continue);
122141
}
123142

143+
let mut return_type = func_type.get_ret().clone();
144+
if return_type.contain_tpl() {
145+
let call_expr_type = LuaType::DocFunction(func_type);
146+
let inst_func = infer_call_expr_func(
147+
db,
148+
cache,
149+
call_expr,
150+
call_expr_type,
151+
&mut InferGuard::new(),
152+
None,
153+
)?;
154+
155+
return_type = inst_func.get_ret().clone();
156+
}
157+
158+
let guard_type = match return_type {
159+
LuaType::TypeGuard(guard) => guard.deref().clone(),
160+
_ => return Ok(ResultTypeOrContinue::Continue),
161+
};
162+
124163
match condition_flow {
125164
InferConditionFlow::TrueCondition => Ok(ResultTypeOrContinue::Result(guard_type)),
126165
InferConditionFlow::FalseCondition => {

0 commit comments

Comments
 (0)