diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_ref_tags.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_ref_tags.rs index 8b4e15393..deb76cee5 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_ref_tags.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_ref_tags.rs @@ -250,7 +250,11 @@ pub fn analyze_return_cast(analyzer: &mut DocAnalyzer, tag: LuaDocTagReturnCast) if let Some(LuaSemanticDeclId::Signature(signature_id)) = get_owner_id(analyzer) { let name_token = tag.get_name_token()?; let name = name_token.get_name_text(); - let cast_op_type = tag.get_op_type()?; + + let op_types: Vec<_> = tag.get_op_types().collect(); + let cast_op_type = op_types.first()?; + + // Bind the true condition type if let Some(node_type) = cast_op_type.get_type() { let typ = infer_type(analyzer, node_type.clone()); let infiled_syntax_id = InFiled::new(analyzer.file_id, node_type.get_syntax_id()); @@ -258,11 +262,26 @@ pub fn analyze_return_cast(analyzer: &mut DocAnalyzer, tag: LuaDocTagReturnCast) bind_type(analyzer.db, type_owner, LuaTypeCache::DocType(typ)); }; + // Bind the false condition type if present + let fallback_cast = if op_types.len() > 1 { + let fallback_op_type = &op_types[1]; + if let Some(node_type) = fallback_op_type.get_type() { + let typ = infer_type(analyzer, node_type.clone()); + let infiled_syntax_id = InFiled::new(analyzer.file_id, node_type.get_syntax_id()); + let type_owner = LuaTypeOwner::SyntaxId(infiled_syntax_id); + bind_type(analyzer.db, type_owner, LuaTypeCache::DocType(typ)); + } + Some(fallback_op_type.to_ptr()) + } else { + None + }; + analyzer.db.get_flow_index_mut().add_signature_cast( analyzer.file_id, signature_id, name.to_string(), cast_op_type.to_ptr(), + fallback_cast, ); } else { report_orphan_tag(analyzer, &tag); diff --git a/crates/emmylua_code_analysis/src/compilation/test/flow.rs b/crates/emmylua_code_analysis/src/compilation/test/flow.rs index 1f49b0f59..ac6615ced 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/flow.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/flow.rs @@ -1355,4 +1355,108 @@ _2 = a[1] "# )); } + + #[test] + fn test_return_cast_with_fallback() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@class Creature + + ---@class Player: Creature + + ---@class Monster: Creature + + ---@return boolean + ---@return_cast creature Player else Monster + local function isPlayer(creature) + return true + end + + local creature ---@type Creature + + if isPlayer(creature) then + a = creature + else + b = creature + end + "#, + ); + + let a = ws.expr_ty("a"); + let a_expected = ws.ty("Player"); + assert_eq!(a, a_expected); + + let b = ws.expr_ty("b"); + let b_expected = ws.ty("Monster"); + assert_eq!(b, b_expected); + } + + #[test] + fn test_return_cast_with_fallback_self() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@class Creature + + ---@class Player: Creature + + ---@class Monster: Creature + local m = {} + + ---@return boolean + ---@return_cast self Player else Monster + function m:isPlayer() + end + + if m:isPlayer() then + a = m + else + b = m + end + "#, + ); + + let a = ws.expr_ty("a"); + let a_expected = ws.ty("Player"); + assert_eq!(a, a_expected); + + let b = ws.expr_ty("b"); + let b_expected = ws.ty("Monster"); + assert_eq!(b, b_expected); + } + + #[test] + fn test_return_cast_backward_compatibility() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@return boolean + ---@return_cast n integer + local function isInteger(n) + return true + end + + local a ---@type integer | string + + if isInteger(a) then + d = a + else + e = a + end + "#, + ); + + let d = ws.expr_ty("d"); + let d_expected = ws.ty("integer"); + assert_eq!(d, d_expected); + + // Should still use the original behavior (remove integer from union) + let e = ws.expr_ty("e"); + let e_expected = ws.ty("string"); + assert_eq!(e, e_expected); + } } diff --git a/crates/emmylua_code_analysis/src/db_index/flow/mod.rs b/crates/emmylua_code_analysis/src/db_index/flow/mod.rs index 7ec45ef3b..95ef92cd4 100644 --- a/crates/emmylua_code_analysis/src/db_index/flow/mod.rs +++ b/crates/emmylua_code_analysis/src/db_index/flow/mod.rs @@ -52,11 +52,19 @@ impl LuaFlowIndex { signature_id: LuaSignatureId, name: String, cast: LuaAstPtr, + fallback_cast: Option>, ) { self.signature_cast_cache .entry(file_id) .or_default() - .insert(signature_id, LuaSignatureCast { name, cast }); + .insert( + signature_id, + LuaSignatureCast { + name, + cast, + fallback_cast, + }, + ); } } diff --git a/crates/emmylua_code_analysis/src/db_index/flow/signature_cast.rs b/crates/emmylua_code_analysis/src/db_index/flow/signature_cast.rs index f8d206747..c792c4dfa 100644 --- a/crates/emmylua_code_analysis/src/db_index/flow/signature_cast.rs +++ b/crates/emmylua_code_analysis/src/db_index/flow/signature_cast.rs @@ -4,4 +4,5 @@ use emmylua_parser::{LuaAstPtr, LuaDocOpType}; pub struct LuaSignatureCast { pub name: String, pub cast: LuaAstPtr, + pub fallback_cast: Option>, } diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs index 48c5efda1..8ad1b8b0e 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs @@ -228,17 +228,50 @@ fn get_type_at_call_expr_by_signature_self( }; let signature_root = syntax_tree.get_chunk_node(); - let Some(cast_op_type) = signature_cast.cast.to_node(&signature_root) else { - return Ok(ResultTypeOrContinue::Continue); + + // Choose the appropriate cast based on condition_flow and whether fallback exists + let result_type = match condition_flow { + InferConditionFlow::TrueCondition => { + let Some(cast_op_type) = signature_cast.cast.to_node(&signature_root) else { + return Ok(ResultTypeOrContinue::Continue); + }; + cast_type( + db, + signature_id.get_file_id(), + cast_op_type, + antecedent_type, + condition_flow, + )? + } + InferConditionFlow::FalseCondition => { + // Use fallback_cast if available, otherwise use the default behavior + if let Some(fallback_cast_ptr) = &signature_cast.fallback_cast { + let Some(fallback_op_type) = fallback_cast_ptr.to_node(&signature_root) else { + return Ok(ResultTypeOrContinue::Continue); + }; + cast_type( + db, + signature_id.get_file_id(), + fallback_op_type, + antecedent_type.clone(), + InferConditionFlow::TrueCondition, // Apply fallback as force cast + )? + } else { + // Original behavior: remove the true type from antecedent + let Some(cast_op_type) = signature_cast.cast.to_node(&signature_root) else { + return Ok(ResultTypeOrContinue::Continue); + }; + cast_type( + db, + signature_id.get_file_id(), + cast_op_type, + antecedent_type, + condition_flow, + )? + } + } }; - let result_type = cast_type( - db, - signature_id.get_file_id(), - cast_op_type, - antecedent_type, - condition_flow, - )?; Ok(ResultTypeOrContinue::Result(result_type)) } @@ -304,17 +337,50 @@ fn get_type_at_call_expr_by_signature_param_name( }; let signature_root = syntax_tree.get_chunk_node(); - let Some(cast_op_type) = signature_cast.cast.to_node(&signature_root) else { - return Ok(ResultTypeOrContinue::Continue); + + // Choose the appropriate cast based on condition_flow and whether fallback exists + let result_type = match condition_flow { + InferConditionFlow::TrueCondition => { + let Some(cast_op_type) = signature_cast.cast.to_node(&signature_root) else { + return Ok(ResultTypeOrContinue::Continue); + }; + cast_type( + db, + signature_id.get_file_id(), + cast_op_type, + antecedent_type, + condition_flow, + )? + } + InferConditionFlow::FalseCondition => { + // Use fallback_cast if available, otherwise use the default behavior + if let Some(fallback_cast_ptr) = &signature_cast.fallback_cast { + let Some(fallback_op_type) = fallback_cast_ptr.to_node(&signature_root) else { + return Ok(ResultTypeOrContinue::Continue); + }; + cast_type( + db, + signature_id.get_file_id(), + fallback_op_type, + antecedent_type.clone(), + InferConditionFlow::TrueCondition, // Apply fallback as force cast + )? + } else { + // Original behavior: remove the true type from antecedent + let Some(cast_op_type) = signature_cast.cast.to_node(&signature_root) else { + return Ok(ResultTypeOrContinue::Continue); + }; + cast_type( + db, + signature_id.get_file_id(), + cast_op_type, + antecedent_type, + condition_flow, + )? + } + } }; - let result_type = cast_type( - db, - signature_id.get_file_id(), - cast_op_type, - antecedent_type, - condition_flow, - )?; Ok(ResultTypeOrContinue::Result(result_type)) } diff --git a/crates/emmylua_parser/src/grammar/doc/tag.rs b/crates/emmylua_parser/src/grammar/doc/tag.rs index b2441390c..9e015cb89 100644 --- a/crates/emmylua_parser/src/grammar/doc/tag.rs +++ b/crates/emmylua_parser/src/grammar/doc/tag.rs @@ -353,6 +353,7 @@ fn parse_tag_return(p: &mut LuaDocParser) -> DocParseResult { } // ---@return_cast +// ---@return_cast else fn parse_tag_return_cast(p: &mut LuaDocParser) -> DocParseResult { p.set_state(LuaDocLexerState::Normal); let m = p.mark(LuaSyntaxKind::DocTagReturnCast); @@ -360,6 +361,13 @@ fn parse_tag_return_cast(p: &mut LuaDocParser) -> DocParseResult { expect_token(p, LuaTokenKind::TkName)?; parse_op_type(p)?; + + // Allow optional second type after 'else' for false condition + if p.current_token() == LuaTokenKind::TkDocElse { + p.bump(); + parse_op_type(p)?; + } + p.set_state(LuaDocLexerState::Description); parse_description(p); Ok(m.complete(p)) diff --git a/crates/emmylua_parser/src/kind/lua_token_kind.rs b/crates/emmylua_parser/src/kind/lua_token_kind.rs index c0e11dde3..0fe1ecae6 100644 --- a/crates/emmylua_parser/src/kind/lua_token_kind.rs +++ b/crates/emmylua_parser/src/kind/lua_token_kind.rs @@ -143,6 +143,7 @@ pub enum LuaTokenKind { TkDocAs, // as TkDocIn, // in TkDocInfer, // infer + TkDocElse, // else (for return_cast) TkDocContinue, // --- TkDocContinueOr, // ---| or ---|+ or ---|> TkDocDetail, // a description diff --git a/crates/emmylua_parser/src/lexer/lua_doc_lexer.rs b/crates/emmylua_parser/src/lexer/lua_doc_lexer.rs index 8adb74d29..aaac08bcf 100644 --- a/crates/emmylua_parser/src/lexer/lua_doc_lexer.rs +++ b/crates/emmylua_parser/src/lexer/lua_doc_lexer.rs @@ -638,6 +638,7 @@ fn to_token_or_name(text: &str) -> LuaTokenKind { "as" => LuaTokenKind::TkDocAs, "and" => LuaTokenKind::TkAnd, "or" => LuaTokenKind::TkOr, + "else" => LuaTokenKind::TkDocElse, _ => LuaTokenKind::TkName, } } diff --git a/crates/emmylua_parser/src/syntax/node/doc/tag.rs b/crates/emmylua_parser/src/syntax/node/doc/tag.rs index 5b6b3edf8..525a185e8 100644 --- a/crates/emmylua_parser/src/syntax/node/doc/tag.rs +++ b/crates/emmylua_parser/src/syntax/node/doc/tag.rs @@ -1458,8 +1458,8 @@ impl LuaAstNode for LuaDocTagReturnCast { impl LuaDocDescriptionOwner for LuaDocTagReturnCast {} impl LuaDocTagReturnCast { - pub fn get_op_type(&self) -> Option { - self.child() + pub fn get_op_types(&self) -> LuaAstChildren { + self.children() } pub fn get_name_token(&self) -> Option {