Skip to content

Commit 58cbe4d

Browse files
committed
feat: return_cast class field
1 parent 34b5cb4 commit 58cbe4d

File tree

7 files changed

+434
-22
lines changed

7 files changed

+434
-22
lines changed

crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_ref_tags.rs

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -266,8 +266,16 @@ pub fn analyze_return(analyzer: &mut DocAnalyzer, tag: LuaDocTagReturn) -> Optio
266266

267267
pub fn analyze_return_cast(analyzer: &mut DocAnalyzer, tag: LuaDocTagReturnCast) -> Option<()> {
268268
if let Some(LuaSemanticDeclId::Signature(signature_id)) = get_owner_id(analyzer, None, false) {
269-
let name_token = tag.get_name_token()?;
270-
let name = name_token.get_name_text();
269+
// Extract name from either name_token or key_expr
270+
let name = if let Some(key_expr) = tag.get_key_expr() {
271+
// Handle multi-level expressions like self.xxx
272+
extract_name_from_expr(&key_expr)
273+
} else if let Some(name_token) = tag.get_name_token() {
274+
// Fallback to simple name token
275+
name_token.get_name_text().to_string()
276+
} else {
277+
return None;
278+
};
271279

272280
let op_types: Vec<_> = tag.get_op_types().collect();
273281
let cast_op_type = op_types.first()?;
@@ -297,7 +305,7 @@ pub fn analyze_return_cast(analyzer: &mut DocAnalyzer, tag: LuaDocTagReturnCast)
297305
analyzer.db.get_flow_index_mut().add_signature_cast(
298306
analyzer.file_id,
299307
signature_id,
300-
name.to_string(),
308+
name,
301309
cast_op_type.to_ptr(),
302310
fallback_cast,
303311
);
@@ -308,6 +316,45 @@ pub fn analyze_return_cast(analyzer: &mut DocAnalyzer, tag: LuaDocTagReturnCast)
308316
Some(())
309317
}
310318

319+
// Helper function to extract name string from expression
320+
fn extract_name_from_expr(expr: &LuaExpr) -> String {
321+
match expr {
322+
LuaExpr::NameExpr(name_expr) => {
323+
if let Some(token) = name_expr.get_name_token() {
324+
token.get_name_text().to_string()
325+
} else {
326+
String::new()
327+
}
328+
}
329+
LuaExpr::IndexExpr(index_expr) => {
330+
// Recursively build the path like "self.xxx" or "self.a.b"
331+
let prefix = if let Some(prefix_expr) = index_expr.get_prefix_expr() {
332+
extract_name_from_expr(&prefix_expr)
333+
} else {
334+
String::new()
335+
};
336+
337+
let suffix = if let Some(key) = index_expr.get_index_key() {
338+
match key {
339+
emmylua_parser::LuaIndexKey::Name(name_token) => name_token.get_name_text().to_string(),
340+
_ => String::new(),
341+
}
342+
} else {
343+
String::new()
344+
};
345+
346+
if prefix.is_empty() {
347+
suffix
348+
} else if suffix.is_empty() {
349+
prefix
350+
} else {
351+
format!("{}.{}", prefix, suffix)
352+
}
353+
}
354+
_ => String::new(),
355+
}
356+
}
357+
311358
pub fn analyze_overload(analyzer: &mut DocAnalyzer, tag: LuaDocTagOverload) -> Option<()> {
312359
if let Some(decl_id) = analyzer.current_type_id.clone() {
313360
let type_ref = infer_type(analyzer, tag.get_type()?);

crates/emmylua_code_analysis/src/compilation/test/flow.rs

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1476,4 +1476,135 @@ _2 = a[1]
14761476
"#,
14771477
);
14781478
}
1479+
1480+
#[test]
1481+
fn test_return_cast_self_field() {
1482+
let mut ws = VirtualWorkspace::new();
1483+
1484+
ws.def(
1485+
r#"
1486+
---@class MyClass
1487+
---@field value string|number
1488+
local MyClass = {}
1489+
1490+
---Check if value field is string
1491+
---@param self MyClass
1492+
---@return_cast self.value string
1493+
function MyClass:check_string()
1494+
return type(self.value) == "string"
1495+
end
1496+
1497+
---@param obj MyClass
1498+
function test(obj)
1499+
if obj:check_string() then
1500+
a = obj.value
1501+
else
1502+
b = obj.value
1503+
end
1504+
end
1505+
"#,
1506+
);
1507+
1508+
let a = ws.expr_ty("a");
1509+
let a_expected = ws.ty("string");
1510+
assert_eq!(a, a_expected);
1511+
1512+
let b = ws.expr_ty("b");
1513+
let b_expected = ws.ty("number");
1514+
assert_eq!(b, b_expected);
1515+
}
1516+
1517+
#[test]
1518+
fn test_return_cast_self_field_with_fallback() {
1519+
let mut ws = VirtualWorkspace::new();
1520+
1521+
ws.def(
1522+
r#"
1523+
---@class MyClass
1524+
---@field data table|nil
1525+
local MyClass = {}
1526+
1527+
---Check if data exists
1528+
---@param self MyClass
1529+
---@return_cast self.data table else nil
1530+
function MyClass:has_data()
1531+
return self.data ~= nil
1532+
end
1533+
1534+
---@param obj MyClass
1535+
function test(obj)
1536+
if obj:has_data() then
1537+
c = obj.data
1538+
else
1539+
d = obj.data
1540+
end
1541+
end
1542+
"#,
1543+
);
1544+
1545+
let c = ws.expr_ty("c");
1546+
let c_str = ws.humanize_type(c);
1547+
assert_eq!(c_str, "table");
1548+
1549+
let d = ws.expr_ty("d");
1550+
let d_expected = ws.ty("nil");
1551+
assert_eq!(d, d_expected);
1552+
}
1553+
1554+
#[test]
1555+
fn test_return_cast_self_field_complex() {
1556+
let mut ws = VirtualWorkspace::new();
1557+
1558+
ws.def(
1559+
r#"
1560+
---@class Vehicle
1561+
---@field type "car"|"bike"|"truck"
1562+
---@field engine string|nil
1563+
local Vehicle = {}
1564+
1565+
---@param self Vehicle
1566+
---@return_cast self.type "car"
1567+
function Vehicle:is_car()
1568+
return self.type == "car"
1569+
end
1570+
1571+
---@param self Vehicle
1572+
---@return_cast self.engine string else nil
1573+
function Vehicle:has_engine()
1574+
return self.engine ~= nil
1575+
end
1576+
1577+
---@param v Vehicle
1578+
function test(v)
1579+
if v:is_car() then
1580+
e = v.type
1581+
else
1582+
f = v.type
1583+
end
1584+
1585+
if v:has_engine() then
1586+
g = v.engine
1587+
else
1588+
h = v.engine
1589+
end
1590+
end
1591+
"#,
1592+
);
1593+
1594+
let e = ws.expr_ty("e");
1595+
let e_expected = ws.ty("\"car\"");
1596+
assert_eq!(e, e_expected);
1597+
1598+
let f = ws.expr_ty("f");
1599+
let f_expected = ws.ty("\"bike\"|\"truck\"");
1600+
assert_eq!(f, f_expected);
1601+
1602+
let g = ws.expr_ty("g");
1603+
let g_expected = ws.ty("string");
1604+
assert_eq!(g, g_expected);
1605+
1606+
let h = ws.expr_ty("h");
1607+
let h_expected = ws.ty("nil");
1608+
assert_eq!(h, h_expected);
1609+
}
14791610
}

crates/emmylua_code_analysis/src/diagnostic/checker/undefined_global.rs

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
use std::collections::HashSet;
22

3-
use emmylua_parser::{LuaAstNode, LuaClosureExpr, LuaNameExpr};
3+
use emmylua_parser::{
4+
LuaAst, LuaAstNode, LuaClosureExpr, LuaComment, LuaDocTagReturnCast, LuaNameExpr,
5+
};
46
use rowan::TextRange;
57

68
use crate::{DiagnosticCode, LuaSignatureId, SemanticModel};
@@ -94,6 +96,7 @@ fn check_name_expr(
9496
}
9597

9698
fn check_self_name(semantic_model: &SemanticModel, name_expr: LuaNameExpr) -> Option<()> {
99+
// Check if self is in a method context (regular Lua code)
97100
let closure_expr = name_expr.ancestors::<LuaClosureExpr>();
98101
for closure_expr in closure_expr {
99102
let signature_id =
@@ -105,6 +108,47 @@ fn check_self_name(semantic_model: &SemanticModel, name_expr: LuaNameExpr) -> Op
105108
if signature.is_method(semantic_model, None) {
106109
return Some(());
107110
}
111+
112+
// Check if self is a parameter of this function (from @param self)
113+
if signature.find_param_idx("self").is_some() {
114+
return Some(());
115+
}
108116
}
117+
118+
// Check if self is in @return_cast tag
119+
// The name_expr might be inside a doc comment, not inside actual Lua code
120+
for ancestor in name_expr.syntax().ancestors() {
121+
if let Some(return_cast_tag) = LuaDocTagReturnCast::cast(ancestor.clone()) {
122+
// Find the LuaComment that contains this tag
123+
for comment_ancestor in return_cast_tag.syntax().ancestors() {
124+
if let Some(comment) = LuaComment::cast(comment_ancestor) {
125+
// Get the owner (function) of this comment
126+
if let Some(owner) = comment.get_owner() {
127+
if let LuaAst::LuaClosureExpr(closure) = owner {
128+
let sig_id = LuaSignatureId::from_closure(
129+
semantic_model.get_file_id(),
130+
&closure,
131+
);
132+
if let Some(sig) =
133+
semantic_model.get_db().get_signature_index().get(&sig_id)
134+
{
135+
// Check if the owner function is a method
136+
if sig.is_method(semantic_model, None) {
137+
return Some(());
138+
}
139+
// Check if self is a parameter
140+
if sig.find_param_idx("self").is_some() {
141+
return Some(());
142+
}
143+
}
144+
}
145+
}
146+
break;
147+
}
148+
}
149+
break;
150+
}
151+
}
152+
109153
None
110154
}

crates/emmylua_code_analysis/src/diagnostic/test/undefined_global_test.rs

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,41 @@ mod test {
2020
"#
2121
));
2222
}
23+
24+
#[test]
25+
fn test_return_cast_self_no_undefined_global() {
26+
let mut ws = VirtualWorkspace::new();
27+
// @return_cast self should not produce undefined-global error in methods
28+
assert!(!ws.check_code_for(
29+
DiagnosticCode::UndefinedGlobal,
30+
r#"
31+
---@class MyClass
32+
local MyClass = {}
33+
34+
---@return_cast self MyClass
35+
function MyClass:check1()
36+
return true
37+
end
38+
"#
39+
));
40+
}
41+
42+
#[test]
43+
fn test_return_cast_self_field_no_undefined_global() {
44+
let mut ws = VirtualWorkspace::new();
45+
// @return_cast self.field should not produce undefined-global error in methods
46+
assert!(!ws.check_code_for(
47+
DiagnosticCode::UndefinedGlobal,
48+
r#"
49+
---@class MyClass
50+
---@field value string|number
51+
local MyClass = {}
52+
53+
---@return_cast self.value string
54+
function MyClass:check_string()
55+
return type(self.value) == "string"
56+
end
57+
"#
58+
));
59+
}
2360
}

0 commit comments

Comments
 (0)