Skip to content

Commit 3115208

Browse files
committed
feature: add class_default_call
将`signature`添加为默认的`__call`调用目标, 但如果已经显式为`class`声明了`---@overload`则不会添加. 配置以下选项: runtime.class_default_call.function_name = "__init" untime.class_default_call.force_non_colon = true runtime.class_default_call.force_return_self ```lua ---@Class MyClass local M = {} function M:__init(a) end A = M() -- `A` is `MyClass` ```
1 parent 4fd95d0 commit 3115208

File tree

10 files changed

+208
-10
lines changed

10 files changed

+208
-10
lines changed

crates/emmylua_code_analysis/resources/schema.json

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,11 @@
119119
},
120120
"runtime": {
121121
"default": {
122+
"classDefaultCall": {
123+
"forceNonColon": false,
124+
"forceReturnSelf": false,
125+
"functionName": ""
126+
},
122127
"extensions": [],
123128
"frameworkVersions": [],
124129
"requireLikeFunction": [],
@@ -184,6 +189,26 @@
184189
}
185190
},
186191
"definitions": {
192+
"ClassDefaultCall": {
193+
"type": "object",
194+
"properties": {
195+
"forceNonColon": {
196+
"description": "Mandatory non`:` definition. When `function_name` is not empty, it takes effect.",
197+
"default": true,
198+
"type": "boolean"
199+
},
200+
"forceReturnSelf": {
201+
"description": "Force to return `self`.",
202+
"default": true,
203+
"type": "boolean"
204+
},
205+
"functionName": {
206+
"description": "class default overload function. eg. \"__init\".",
207+
"default": "",
208+
"type": "string"
209+
}
210+
}
211+
},
187212
"DiagnosticCode": {
188213
"oneOf": [
189214
{
@@ -836,6 +861,19 @@
836861
"EmmyrcRuntime": {
837862
"type": "object",
838863
"properties": {
864+
"classDefaultCall": {
865+
"description": "class default overload function.",
866+
"default": {
867+
"forceNonColon": false,
868+
"forceReturnSelf": false,
869+
"functionName": ""
870+
},
871+
"allOf": [
872+
{
873+
"$ref": "#/definitions/ClassDefaultCall"
874+
}
875+
]
876+
},
839877
"extensions": {
840878
"description": "file Extensions. eg: .lua, .lua.txt",
841879
"default": [],

crates/emmylua_code_analysis/src/compilation/analyzer/lua/for_range_stat.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ pub fn infer_for_range_iter_expr_func(
118118
.iter()
119119
.filter_map(|overload_id| {
120120
let operator = operator_index.get_operator(overload_id)?;
121-
let func = operator.get_operator_func();
121+
let func = operator.get_operator_func(db);
122122
match func {
123123
LuaType::DocFunction(f) => Some(f.clone()),
124124
_ => None,

crates/emmylua_code_analysis/src/compilation/analyzer/lua/mod.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ use crate::{
2424
db_index::{DbIndex, LuaType},
2525
profile::Profile,
2626
semantic::infer_expr,
27-
FileId, InferFailReason,
27+
Emmyrc, FileId, InferFailReason,
2828
};
2929

3030
use super::AnalyzeContext;
@@ -101,6 +101,10 @@ impl LuaAnalyzer<'_> {
101101
context,
102102
}
103103
}
104+
105+
pub fn get_emmyrc(&self) -> &Emmyrc {
106+
self.db.get_emmyrc()
107+
}
104108
}
105109

106110
impl LuaAnalyzer<'_> {

crates/emmylua_code_analysis/src/compilation/analyzer/lua/stats.rs

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ use crate::{
99
unresolve::{UnResolveDecl, UnResolveMember},
1010
},
1111
db_index::{LuaDeclId, LuaMemberId, LuaMemberOwner, LuaType},
12-
InFiled, InferFailReason, LuaTypeCache, LuaTypeOwner,
12+
InFiled, InferFailReason, LuaOperator, LuaOperatorMetaMethod, LuaOperatorOwner, LuaTypeCache,
13+
LuaTypeOwner, OperatorFunction,
1314
};
1415

1516
use super::LuaAnalyzer;
@@ -417,6 +418,8 @@ pub fn analyze_func_stat(analyzer: &mut LuaAnalyzer, func_stat: LuaFuncStat) ->
417418
.get_type_index_mut()
418419
.bind_type(type_owner, LuaTypeCache::InferType(signature_type.clone()));
419420

421+
try_add_class_default_call(analyzer, func_name, signature_type);
422+
420423
Some(())
421424
}
422425

@@ -500,3 +503,56 @@ fn special_assign_pattern(
500503

501504
Some(())
502505
}
506+
507+
pub fn try_add_class_default_call(
508+
analyzer: &mut LuaAnalyzer,
509+
func_name: LuaVarExpr,
510+
signature_type: LuaType,
511+
) -> Option<()> {
512+
let LuaType::Signature(signature_id) = signature_type else {
513+
return None;
514+
};
515+
516+
let default_name = &analyzer
517+
.get_emmyrc()
518+
.runtime
519+
.class_default_call
520+
.function_name;
521+
522+
if default_name.is_empty() {
523+
return None;
524+
}
525+
if let LuaVarExpr::IndexExpr(index_expr) = func_name {
526+
let index_key = index_expr.get_index_key()?;
527+
if index_key.get_path_part() == *default_name {
528+
let prefix_expr = index_expr.get_prefix_expr()?;
529+
match analyzer.infer_expr(&prefix_expr.into()) {
530+
Ok(prefix_type) => match prefix_type {
531+
LuaType::Def(decl_id) => {
532+
// 如果已经存在, 则不添加
533+
let call = analyzer.db.get_operator_index().get_operators(
534+
&LuaOperatorOwner::Type(decl_id.clone()),
535+
LuaOperatorMetaMethod::Call,
536+
);
537+
if call.is_some() {
538+
return None;
539+
}
540+
541+
let operator = LuaOperator::new(
542+
decl_id.into(),
543+
LuaOperatorMetaMethod::Call,
544+
analyzer.file_id,
545+
index_expr.get_range(),
546+
OperatorFunction::DefaultCall(signature_id),
547+
);
548+
analyzer.db.get_operator_index_mut().add_operator(operator);
549+
}
550+
_ => {}
551+
},
552+
Err(_) => {}
553+
}
554+
}
555+
}
556+
557+
Some(())
558+
}

crates/emmylua_code_analysis/src/compilation/test/overload_test.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#[cfg(test)]
22
mod test {
3+
use std::{ops::Deref, sync::Arc};
4+
35
use crate::{DiagnosticCode, VirtualWorkspace};
46

57
#[test]
@@ -25,4 +27,30 @@ mod test {
2527
"#
2628
));
2729
}
30+
31+
#[test]
32+
fn test_class_default_call() {
33+
let mut ws = VirtualWorkspace::new();
34+
let mut emmyrc = ws.analysis.emmyrc.deref().clone();
35+
emmyrc.runtime.class_default_call.function_name = "__init".to_string();
36+
emmyrc.runtime.class_default_call.force_non_colon = true;
37+
emmyrc.runtime.class_default_call.force_return_self = true;
38+
ws.analysis.update_config(Arc::new(emmyrc));
39+
40+
ws.def(
41+
r#"
42+
---@class MyClass
43+
local M = {}
44+
45+
function M:__init(a)
46+
end
47+
48+
A = M()
49+
"#,
50+
);
51+
52+
let ty = ws.expr_ty("A");
53+
let expected = ws.ty("MyClass");
54+
assert_eq!(ws.humanize_type(ty), ws.humanize_type(expected));
55+
}
2856
}

crates/emmylua_code_analysis/src/config/configs/runtime.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ pub struct EmmyrcRuntime {
2020
#[serde(default)]
2121
/// Require pattern. eg. "?.lua", "?/init.lua"
2222
pub require_pattern: Vec<String>,
23+
#[serde(default)]
24+
/// class default overload function.
25+
pub class_default_call: ClassDefaultCall,
2326
}
2427

2528
impl Default for EmmyrcRuntime {
@@ -30,6 +33,7 @@ impl Default for EmmyrcRuntime {
3033
framework_versions: Default::default(),
3134
extensions: Default::default(),
3235
require_pattern: Default::default(),
36+
class_default_call: Default::default(),
3337
}
3438
}
3539
}
@@ -79,6 +83,24 @@ impl EmmyrcLuaVersion {
7983
}
8084
}
8185

86+
#[derive(Serialize, Deserialize, Debug, JsonSchema, Clone, Default)]
87+
#[serde(rename_all = "camelCase")]
88+
pub struct ClassDefaultCall {
89+
#[serde(default)]
90+
/// class default overload function. eg. "__init".
91+
pub function_name: String,
92+
#[serde(default = "default_true")]
93+
/// Mandatory non`:` definition. When `function_name` is not empty, it takes effect.
94+
pub force_non_colon: bool,
95+
/// Force to return `self`.
96+
#[serde(default = "default_true")]
97+
pub force_return_self: bool,
98+
}
99+
100+
fn default_true() -> bool {
101+
true
102+
}
103+
82104
#[cfg(test)]
83105
mod tests {
84106
use super::*;

crates/emmylua_code_analysis/src/db_index/operators/lua_operator.rs

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ pub struct LuaOperator {
2323
pub enum OperatorFunction {
2424
Func(Arc<LuaFunctionType>),
2525
Signature(LuaSignatureId),
26+
DefaultCall(LuaSignatureId),
2627
}
2728

2829
impl LuaOperator {
@@ -71,6 +72,8 @@ impl LuaOperator {
7172

7273
LuaType::Any
7374
}
75+
// 只有 .field 才有`operand`, call 不会有这个
76+
OperatorFunction::DefaultCall(_) => LuaType::Unknown,
7477
}
7578
}
7679

@@ -92,15 +95,62 @@ impl LuaOperator {
9295
}
9396
}
9497

98+
Ok(LuaType::Any)
99+
}
100+
OperatorFunction::DefaultCall(signature_id) => {
101+
let emmyrc = db.get_emmyrc();
102+
if emmyrc.runtime.class_default_call.force_return_self {
103+
return Ok(LuaType::SelfInfer);
104+
}
105+
106+
if let Some(signature) = db.get_signature_index().get(signature_id) {
107+
if signature.resolve_return == SignatureReturnStatus::UnResolve {
108+
return Err(InferFailReason::UnResolveSignatureReturn(
109+
signature_id.clone(),
110+
));
111+
}
112+
113+
let return_type = signature.return_docs.get(0);
114+
if let Some(return_type) = return_type {
115+
return Ok(return_type.type_ref.clone());
116+
}
117+
}
118+
95119
Ok(LuaType::Any)
96120
}
97121
}
98122
}
99123

100-
pub fn get_operator_func(&self) -> LuaType {
124+
pub fn get_operator_func(&self, db: &DbIndex) -> LuaType {
101125
match &self.func {
102126
OperatorFunction::Func(func) => LuaType::DocFunction(func.clone()),
103127
OperatorFunction::Signature(signature) => LuaType::Signature(*signature),
128+
OperatorFunction::DefaultCall(signature_id) => {
129+
let emmyrc = db.get_emmyrc();
130+
131+
if let Some(signature) = db.get_signature_index().get(signature_id) {
132+
let params = signature.get_type_params();
133+
let is_colon_define = if emmyrc.runtime.class_default_call.force_non_colon {
134+
false
135+
} else {
136+
signature.is_colon_define
137+
};
138+
let return_type = if emmyrc.runtime.class_default_call.force_return_self {
139+
LuaType::SelfInfer
140+
} else {
141+
signature.get_return_type()
142+
};
143+
let func_type = LuaFunctionType::new(
144+
signature.is_async,
145+
is_colon_define,
146+
params,
147+
return_type,
148+
);
149+
return LuaType::DocFunction(Arc::new(func_type));
150+
}
151+
152+
LuaType::Signature(*signature_id)
153+
}
104154
}
105155
}
106156

crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ fn infer_type_doc_function(
252252
let operator = operator_index
253253
.get_operator(overload_id)
254254
.ok_or(InferFailReason::None)?;
255-
let func = operator.get_operator_func();
255+
let func = operator.get_operator_func(db);
256256
match func {
257257
LuaType::DocFunction(f) => {
258258
if f.contain_self() {
@@ -326,7 +326,7 @@ fn infer_generic_type_doc_function(
326326
let operator = operator_index
327327
.get_operator(overload_id)
328328
.ok_or(InferFailReason::None)?;
329-
let func = operator.get_operator_func();
329+
let func = operator.get_operator_func(db);
330330
match func {
331331
LuaType::DocFunction(_) => {
332332
let new_f = instantiate_type_generic(db, &func, &substitutor);
@@ -403,7 +403,7 @@ fn infer_table_type_doc_function(db: &DbIndex, table: InFiled<TextRange>) -> Inf
403403
.get_operator_index()
404404
.get_operator(operator_id)
405405
.ok_or(InferFailReason::None)?;
406-
let func = operator.get_operator_func();
406+
let func = operator.get_operator_func(db);
407407
match func {
408408
LuaType::DocFunction(func) => {
409409
return Ok(func.into());

crates/emmylua_code_analysis/src/semantic/type_check/func_type.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ fn check_doc_func_type_compact_for_custom_type(
198198
.get_operator_index()
199199
.get_operator(operator_id)
200200
.ok_or(TypeCheckFailReason::TypeNotMatch)?;
201-
let call_type = operator.get_operator_func();
201+
let call_type = operator.get_operator_func(db);
202202
match call_type {
203203
LuaType::DocFunction(doc_func) => {
204204
match check_doc_func_type_compact_for_params(

crates/emmylua_ls/src/handlers/signature_helper/build_signature_helper.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ fn build_type_signature_help(
276276

277277
for operator_id in operator_ids {
278278
let operator = db.get_operator_index().get_operator(operator_id)?;
279-
let call_type = operator.get_operator_func();
279+
let call_type = operator.get_operator_func(db);
280280
match call_type {
281281
LuaType::DocFunction(func_type) => {
282282
return build_doc_function_signature_help(
@@ -336,7 +336,7 @@ fn build_table_call_signature_help(
336336
.get_operators(&operator_owner, LuaOperatorMetaMethod::Call)?
337337
.first()?;
338338
let operator = db.get_operator_index().get_operator(operator_ids)?;
339-
let call_type = operator.get_operator_func();
339+
let call_type = operator.get_operator_func(db);
340340
match call_type {
341341
LuaType::DocFunction(func_type) => {
342342
return build_doc_function_signature_help(builder, &func_type, colon_call, current_idx);

0 commit comments

Comments
 (0)