Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -250,19 +250,38 @@ pub fn analyze_return_cast(analyzer: &mut DocAnalyzer, tag: LuaDocTagReturnCast)
if let Some(LuaSemanticDeclId::Signature(signature_id)) = get_owner_id(analyzer) {
let name_token = tag.get_name_token()?;
let name = name_token.get_name_text();
let cast_op_type = tag.get_op_type()?;

let op_types: Vec<_> = tag.get_op_types().collect();
let cast_op_type = op_types.first()?;

// Bind the true condition type
if let Some(node_type) = cast_op_type.get_type() {
let typ = infer_type(analyzer, node_type.clone());
let infiled_syntax_id = InFiled::new(analyzer.file_id, node_type.get_syntax_id());
let type_owner = LuaTypeOwner::SyntaxId(infiled_syntax_id);
bind_type(analyzer.db, type_owner, LuaTypeCache::DocType(typ));
};

// Bind the false condition type if present
let fallback_cast = if op_types.len() > 1 {
let fallback_op_type = &op_types[1];
if let Some(node_type) = fallback_op_type.get_type() {
let typ = infer_type(analyzer, node_type.clone());
let infiled_syntax_id = InFiled::new(analyzer.file_id, node_type.get_syntax_id());
let type_owner = LuaTypeOwner::SyntaxId(infiled_syntax_id);
bind_type(analyzer.db, type_owner, LuaTypeCache::DocType(typ));
}
Some(fallback_op_type.to_ptr())
} else {
None
};
Comment on lines +257 to +277

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There's some code duplication when binding the types for the true and false conditions. This can be refactored into a local closure to improve readability and maintainability. Using op_types.get(1).map(...) is also more idiomatic than checking length and then indexing.

        let bind_op_type = |op_type: &emmylua_parser::LuaDocOpType| {
            if let Some(node_type) = op_type.get_type() {
                let typ = infer_type(analyzer, node_type.clone());
                let infiled_syntax_id = InFiled::new(analyzer.file_id, node_type.get_syntax_id());
                let type_owner = LuaTypeOwner::SyntaxId(infiled_syntax_id);
                bind_type(analyzer.db, type_owner, LuaTypeCache::DocType(typ));
            }
        };

        // Bind the true condition type
        bind_op_type(cast_op_type);

        // Bind the false condition type if present
        let fallback_cast = op_types.get(1).map(|fallback_op_type| {
            bind_op_type(fallback_op_type);
            fallback_op_type.to_ptr()
        });


analyzer.db.get_flow_index_mut().add_signature_cast(
analyzer.file_id,
signature_id,
name.to_string(),
cast_op_type.to_ptr(),
fallback_cast,
);
} else {
report_orphan_tag(analyzer, &tag);
Expand Down
104 changes: 104 additions & 0 deletions crates/emmylua_code_analysis/src/compilation/test/flow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1355,4 +1355,108 @@ _2 = a[1]
"#
));
}

#[test]
fn test_return_cast_with_fallback() {
let mut ws = VirtualWorkspace::new();

ws.def(
r#"
---@class Creature

---@class Player: Creature

---@class Monster: Creature

---@return boolean
---@return_cast creature Player else Monster
local function isPlayer(creature)
return true
end

local creature ---@type Creature

if isPlayer(creature) then
a = creature
else
b = creature
end
"#,
);

let a = ws.expr_ty("a");
let a_expected = ws.ty("Player");
assert_eq!(a, a_expected);

let b = ws.expr_ty("b");
let b_expected = ws.ty("Monster");
assert_eq!(b, b_expected);
}

#[test]
fn test_return_cast_with_fallback_self() {
let mut ws = VirtualWorkspace::new();

ws.def(
r#"
---@class Creature

---@class Player: Creature

---@class Monster: Creature
local m = {}

---@return boolean
---@return_cast self Player else Monster
function m:isPlayer()
end

if m:isPlayer() then
a = m
else
b = m
end
"#,
);

let a = ws.expr_ty("a");
let a_expected = ws.ty("Player");
assert_eq!(a, a_expected);

let b = ws.expr_ty("b");
let b_expected = ws.ty("Monster");
assert_eq!(b, b_expected);
}

#[test]
fn test_return_cast_backward_compatibility() {
let mut ws = VirtualWorkspace::new();

ws.def(
r#"
---@return boolean
---@return_cast n integer
local function isInteger(n)
return true
end

local a ---@type integer | string

if isInteger(a) then
d = a
else
e = a
end
"#,
);

let d = ws.expr_ty("d");
let d_expected = ws.ty("integer");
assert_eq!(d, d_expected);

// Should still use the original behavior (remove integer from union)
let e = ws.expr_ty("e");
let e_expected = ws.ty("string");
assert_eq!(e, e_expected);
}
}
10 changes: 9 additions & 1 deletion crates/emmylua_code_analysis/src/db_index/flow/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,19 @@ impl LuaFlowIndex {
signature_id: LuaSignatureId,
name: String,
cast: LuaAstPtr<LuaDocOpType>,
fallback_cast: Option<LuaAstPtr<LuaDocOpType>>,
) {
self.signature_cast_cache
.entry(file_id)
.or_default()
.insert(signature_id, LuaSignatureCast { name, cast });
.insert(
signature_id,
LuaSignatureCast {
name,
cast,
fallback_cast,
},
);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ use emmylua_parser::{LuaAstPtr, LuaDocOpType};
pub struct LuaSignatureCast {
pub name: String,
pub cast: LuaAstPtr<LuaDocOpType>,
pub fallback_cast: Option<LuaAstPtr<LuaDocOpType>>,
}
Original file line number Diff line number Diff line change
Expand Up @@ -228,17 +228,50 @@ fn get_type_at_call_expr_by_signature_self(
};

let signature_root = syntax_tree.get_chunk_node();
let Some(cast_op_type) = signature_cast.cast.to_node(&signature_root) else {
return Ok(ResultTypeOrContinue::Continue);

// Choose the appropriate cast based on condition_flow and whether fallback exists
let result_type = match condition_flow {
InferConditionFlow::TrueCondition => {
let Some(cast_op_type) = signature_cast.cast.to_node(&signature_root) else {
return Ok(ResultTypeOrContinue::Continue);
};
cast_type(
db,
signature_id.get_file_id(),
cast_op_type,
antecedent_type,
condition_flow,
)?
}
InferConditionFlow::FalseCondition => {
// Use fallback_cast if available, otherwise use the default behavior
if let Some(fallback_cast_ptr) = &signature_cast.fallback_cast {
let Some(fallback_op_type) = fallback_cast_ptr.to_node(&signature_root) else {
return Ok(ResultTypeOrContinue::Continue);
};
cast_type(
db,
signature_id.get_file_id(),
fallback_op_type,
antecedent_type.clone(),
InferConditionFlow::TrueCondition, // Apply fallback as force cast
)?
} else {
// Original behavior: remove the true type from antecedent
let Some(cast_op_type) = signature_cast.cast.to_node(&signature_root) else {
return Ok(ResultTypeOrContinue::Continue);
};
cast_type(
db,
signature_id.get_file_id(),
cast_op_type,
antecedent_type,
condition_flow,
)?
}
}
};
Comment on lines +232 to 273

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This block of logic for choosing and applying the type cast is duplicated in get_type_at_call_expr_by_signature_param_name (lines 341-382). To improve maintainability and reduce redundancy, you should extract this logic into a new private helper function. This function could take db, signature_id, signature_cast, signature_root, antecedent_type, and condition_flow as arguments.


let result_type = cast_type(
db,
signature_id.get_file_id(),
cast_op_type,
antecedent_type,
condition_flow,
)?;
Ok(ResultTypeOrContinue::Result(result_type))
}

Expand Down Expand Up @@ -304,17 +337,50 @@ fn get_type_at_call_expr_by_signature_param_name(
};

let signature_root = syntax_tree.get_chunk_node();
let Some(cast_op_type) = signature_cast.cast.to_node(&signature_root) else {
return Ok(ResultTypeOrContinue::Continue);

// Choose the appropriate cast based on condition_flow and whether fallback exists
let result_type = match condition_flow {
InferConditionFlow::TrueCondition => {
let Some(cast_op_type) = signature_cast.cast.to_node(&signature_root) else {
return Ok(ResultTypeOrContinue::Continue);
};
cast_type(
db,
signature_id.get_file_id(),
cast_op_type,
antecedent_type,
condition_flow,
)?
}
InferConditionFlow::FalseCondition => {
// Use fallback_cast if available, otherwise use the default behavior
if let Some(fallback_cast_ptr) = &signature_cast.fallback_cast {
let Some(fallback_op_type) = fallback_cast_ptr.to_node(&signature_root) else {
return Ok(ResultTypeOrContinue::Continue);
};
cast_type(
db,
signature_id.get_file_id(),
fallback_op_type,
antecedent_type.clone(),
InferConditionFlow::TrueCondition, // Apply fallback as force cast
)?
} else {
// Original behavior: remove the true type from antecedent
let Some(cast_op_type) = signature_cast.cast.to_node(&signature_root) else {
return Ok(ResultTypeOrContinue::Continue);
};
cast_type(
db,
signature_id.get_file_id(),
cast_op_type,
antecedent_type,
condition_flow,
)?
}
}
};

let result_type = cast_type(
db,
signature_id.get_file_id(),
cast_op_type,
antecedent_type,
condition_flow,
)?;
Ok(ResultTypeOrContinue::Result(result_type))
}

Expand Down
8 changes: 8 additions & 0 deletions crates/emmylua_parser/src/grammar/doc/tag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -353,13 +353,21 @@ fn parse_tag_return(p: &mut LuaDocParser) -> DocParseResult {
}

// ---@return_cast <param name> <type>
// ---@return_cast <param name> <true_type> else <false_type>
fn parse_tag_return_cast(p: &mut LuaDocParser) -> DocParseResult {
p.set_state(LuaDocLexerState::Normal);
let m = p.mark(LuaSyntaxKind::DocTagReturnCast);
p.bump();
expect_token(p, LuaTokenKind::TkName)?;

parse_op_type(p)?;

// Allow optional second type after 'else' for false condition
if p.current_token() == LuaTokenKind::TkDocElse {
p.bump();
parse_op_type(p)?;
}

p.set_state(LuaDocLexerState::Description);
parse_description(p);
Ok(m.complete(p))
Expand Down
1 change: 1 addition & 0 deletions crates/emmylua_parser/src/kind/lua_token_kind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ pub enum LuaTokenKind {
TkDocAs, // as
TkDocIn, // in
TkDocInfer, // infer
TkDocElse, // else (for return_cast)
TkDocContinue, // ---
TkDocContinueOr, // ---| or ---|+ or ---|>
TkDocDetail, // a description
Expand Down
1 change: 1 addition & 0 deletions crates/emmylua_parser/src/lexer/lua_doc_lexer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,7 @@ fn to_token_or_name(text: &str) -> LuaTokenKind {
"as" => LuaTokenKind::TkDocAs,
"and" => LuaTokenKind::TkAnd,
"or" => LuaTokenKind::TkOr,
"else" => LuaTokenKind::TkDocElse,
_ => LuaTokenKind::TkName,
}
}
Expand Down
4 changes: 2 additions & 2 deletions crates/emmylua_parser/src/syntax/node/doc/tag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1458,8 +1458,8 @@ impl LuaAstNode for LuaDocTagReturnCast {
impl LuaDocDescriptionOwner for LuaDocTagReturnCast {}

impl LuaDocTagReturnCast {
pub fn get_op_type(&self) -> Option<LuaDocOpType> {
self.child()
pub fn get_op_types(&self) -> LuaAstChildren<LuaDocOpType> {
self.children()
}

pub fn get_name_token(&self) -> Option<LuaNameToken> {
Expand Down