Skip to content

Commit 65f505b

Browse files
committed
支持泛型在适当情况下使用字面量
1 parent 8020a86 commit 65f505b

File tree

8 files changed

+173
-69
lines changed

8 files changed

+173
-69
lines changed

crates/emmylua_code_analysis/src/compilation/test/generic_test.rs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -782,4 +782,38 @@ mod test {
782782
assert_eq!(ws.humanize_type(result_ty), "any[][]");
783783
}
784784
}
785+
786+
#[test]
787+
fn test_constant_decay() {
788+
let mut ws = VirtualWorkspace::new();
789+
ws.def(
790+
r#"
791+
---@alias std.RawGet<T, K> unknown
792+
793+
---@alias std.ConstTpl<T> unknown
794+
795+
---@generic T, K extends keyof T
796+
---@param object T
797+
---@param key K
798+
---@return std.RawGet<T, K>
799+
function pick(object, key)
800+
end
801+
802+
---@class Person
803+
---@field age integer
804+
"#,
805+
);
806+
807+
ws.def(
808+
r#"
809+
---@type Person
810+
local person
811+
812+
result = pick(person, "age")
813+
"#,
814+
);
815+
816+
let result_ty = ws.expr_ty("result");
817+
assert_eq!(ws.humanize_type(result_ty), "integer");
818+
}
785819
}

crates/emmylua_code_analysis/src/semantic/generic/call_constraint.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ pub fn build_call_constraint_context(
2929
DocTypeInferContext::new(semantic_model.get_db(), semantic_model.get_file_id());
3030
for (idx, doc_type) in type_list.get_types().enumerate() {
3131
let ty = infer_doc_type(doc_ctx, &doc_type);
32-
substitutor.insert_type(GenericTplId::Func(idx as u32), ty);
32+
substitutor.insert_type(GenericTplId::Func(idx as u32), ty, true);
3333
}
3434
}
3535

@@ -88,11 +88,14 @@ fn record_generic_assignment(
8888
substitutor: &mut TypeSubstitutor,
8989
) {
9090
match param_type {
91-
LuaType::TplRef(tpl_ref) | LuaType::ConstTplRef(tpl_ref) => {
92-
substitutor.insert_type(tpl_ref.get_tpl_id(), arg_type.clone());
91+
LuaType::TplRef(tpl_ref) => {
92+
substitutor.insert_type(tpl_ref.get_tpl_id(), arg_type.clone(), true);
93+
}
94+
LuaType::ConstTplRef(tpl_ref) => {
95+
substitutor.insert_type(tpl_ref.get_tpl_id(), arg_type.clone(), false);
9396
}
9497
LuaType::StrTplRef(str_tpl_ref) => {
95-
substitutor.insert_type(str_tpl_ref.get_tpl_id(), arg_type.clone());
98+
substitutor.insert_type(str_tpl_ref.get_tpl_id(), arg_type.clone(), true);
9699
}
97100
LuaType::Variadic(variadic) => {
98101
if let Some(inner) = variadic.get_type(0) {

crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use crate::{
1515
instantiate_type::instantiate_doc_function,
1616
tpl_context::TplContext,
1717
tpl_pattern::{
18-
constant_decay, multi_param_tpl_pattern_match_multi_return, tpl_pattern_match,
18+
multi_param_tpl_pattern_match_multi_return, tpl_pattern_match,
1919
variadic_tpl_pattern_match,
2020
},
2121
},
@@ -110,7 +110,7 @@ fn apply_call_generic_type_list(
110110
let typ = infer_doc_type(doc_ctx, &doc_type);
111111
context
112112
.substitutor
113-
.insert_type(GenericTplId::Func(i as u32), typ);
113+
.insert_type(GenericTplId::Func(i as u32), typ, true);
114114
}
115115
}
116116

@@ -166,7 +166,7 @@ fn infer_generic_types_from_call(
166166
let mut arg_types = vec![];
167167
for arg_expr in &arg_exprs[i..] {
168168
let arg_type = infer_expr(db, context.cache, arg_expr.clone())?;
169-
arg_types.push(constant_decay(arg_type));
169+
arg_types.push(arg_type);
170170
}
171171
variadic_tpl_pattern_match(context, variadic, &arg_types)?;
172172
break;

crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_special_generic.rs

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ pub fn instantiate_alias_call(
1717
alias_call: &LuaAliasCallType,
1818
substitutor: &TypeSubstitutor,
1919
) -> LuaType {
20-
let operands = alias_call
21-
.get_operands()
20+
let operand_exprs = alias_call.get_operands();
21+
let operands = operand_exprs
2222
.iter()
2323
.map(|it| instantiate_type_generic(db, it, substitutor))
2424
.collect::<Vec<_>>();
@@ -76,15 +76,33 @@ pub fn instantiate_alias_call(
7676
return LuaType::Unknown;
7777
}
7878

79-
instantiate_rawget_call(db, &operands[0], &operands[1])
79+
let key = resolve_literal_operand(operand_exprs.get(1), substitutor)
80+
.unwrap_or_else(|| operands[1].clone());
81+
82+
instantiate_rawget_call(db, &operands[0], &key)
8083
}
8184
LuaAliasCallKind::Index => {
8285
if operands.len() != 2 {
8386
return LuaType::Unknown;
8487
}
8588

86-
instantiate_index_call(db, &operands[0], &operands[1])
89+
let key = resolve_literal_operand(operand_exprs.get(1), substitutor)
90+
.unwrap_or_else(|| operands[1].clone());
91+
92+
instantiate_index_call(db, &operands[0], &key)
93+
}
94+
}
95+
}
96+
97+
fn resolve_literal_operand(
98+
operand: Option<&LuaType>,
99+
substitutor: &TypeSubstitutor,
100+
) -> Option<LuaType> {
101+
match operand {
102+
Some(LuaType::TplRef(tpl_ref)) | Some(LuaType::ConstTplRef(tpl_ref)) => {
103+
substitutor.get_raw_type(tpl_ref.get_tpl_id()).cloned()
87104
}
105+
_ => None,
88106
}
89107
}
90108

crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ fn instantiate_tuple(db: &DbIndex, tuple: &LuaTupleType, substitutor: &TypeSubst
8888
new_types.push(ty.clone().unwrap_or(LuaType::Unknown));
8989
}
9090
}
91-
SubstitutorValue::Type(ty) => new_types.push(ty.clone()),
91+
SubstitutorValue::Type(ty) => new_types.push(ty.default().clone()),
9292
SubstitutorValue::MultiBase(base) => new_types.push(base.clone()),
9393
}
9494
}
@@ -131,9 +131,10 @@ pub fn instantiate_doc_function(
131131
if let Some(value) = substitutor.get(tpl.get_tpl_id()) {
132132
match value {
133133
SubstitutorValue::Type(ty) => {
134+
let resolved_type = ty.default();
134135
// 如果参数是 `...: T...` 且类型是 tuple, 那么我们将展开 tuple
135136
if origin_param.0 == "..."
136-
&& let LuaType::Tuple(tuple) = ty
137+
&& let LuaType::Tuple(tuple) = resolved_type
137138
{
138139
for (i, typ) in tuple.get_types().iter().enumerate() {
139140
let param_name = format!("var{}", i);
@@ -195,10 +196,7 @@ pub fn instantiate_doc_function(
195196
}
196197
}
197198

198-
// 将 substitutor 中存储的类型的 def 转为 ref
199-
let mut modified_substitutor = substitutor.clone();
200-
modified_substitutor.convert_def_to_ref();
201-
let mut inst_ret_type = instantiate_type_generic(db, tpl_ret, &modified_substitutor);
199+
let mut inst_ret_type = instantiate_type_generic(db, tpl_ret, substitutor);
202200
// 对于可变返回值, 如果实例化是 tuple, 那么我们将展开 tuple
203201
if let LuaType::Variadic(_) = &&tpl_ret
204202
&& let LuaType::Tuple(tuple) = &inst_ret_type
@@ -323,7 +321,7 @@ fn instantiate_tpl_ref(_: &DbIndex, tpl: &GenericTpl, substitutor: &TypeSubstitu
323321
return constraint.clone();
324322
}
325323
}
326-
SubstitutorValue::Type(ty) => return ty.clone(),
324+
SubstitutorValue::Type(ty) => return ty.default().clone(),
327325
SubstitutorValue::MultiTypes(types) => {
328326
return LuaType::Variadic(VariadicType::Multi(types.clone()).into());
329327
}
@@ -385,13 +383,16 @@ fn instantiate_variadic_type(
385383
return LuaType::Never;
386384
}
387385
SubstitutorValue::Type(ty) => {
386+
let resolved_type = ty.default();
388387
if matches!(
389-
ty,
388+
resolved_type,
390389
LuaType::Nil | LuaType::Any | LuaType::Unknown | LuaType::Never
391390
) {
392-
return ty.clone();
391+
return resolved_type.clone();
393392
}
394-
return LuaType::Variadic(VariadicType::Base(ty.clone()).into());
393+
return LuaType::Variadic(
394+
VariadicType::Base(resolved_type.clone()).into(),
395+
);
395396
}
396397
SubstitutorValue::MultiTypes(types) => {
397398
return LuaType::Variadic(VariadicType::Multi(types.clone()).into());
@@ -456,7 +457,14 @@ fn instantiate_conditional(
456457
&& alias_call.get_call_kind() == LuaAliasCallKind::Extends
457458
&& alias_call.get_operands().len() == 2
458459
{
459-
let mut left = instantiate_type_generic(db, &alias_call.get_operands()[0], substitutor);
460+
let left_operand = &alias_call.get_operands()[0];
461+
let mut left = instantiate_type_generic(db, left_operand, substitutor);
462+
// 如果左侧是泛型, 那么我们取字面量类型
463+
if let LuaType::TplRef(tpl_ref) | LuaType::ConstTplRef(tpl_ref) = left_operand {
464+
if let Some(raw) = substitutor.get_raw_type(tpl_ref.get_tpl_id()) {
465+
left = raw.clone();
466+
}
467+
}
460468
let right_origin = &alias_call.get_operands()[1];
461469
let right = instantiate_type_generic(db, right_origin, substitutor);
462470
// 如果存在 new 标记与左侧为类定义, 那么我们需要的是他的构造函数签名
@@ -506,7 +514,7 @@ fn instantiate_conditional(
506514
let tpl_id_map = resolve_infer_tpl_ids(conditional, substitutor, &infer_names);
507515
for (name, ty) in infer_assignments.iter() {
508516
if let Some(tpl_id) = tpl_id_map.get(name.as_str()) {
509-
true_substitutor.insert_type(*tpl_id, ty.clone());
517+
true_substitutor.insert_type(*tpl_id, ty.clone(), true);
510518
}
511519
}
512520
}
@@ -881,7 +889,7 @@ fn instantiate_mapped_value(
881889
replacement: &LuaType,
882890
) -> LuaType {
883891
let mut local_substitutor = substitutor.clone();
884-
local_substitutor.insert_type(tpl_id, replacement.clone());
892+
local_substitutor.insert_type(tpl_id, replacement.clone(), true);
885893
let mut result = instantiate_type_generic(db, &mapped.value, &local_substitutor);
886894
// 根据 readonly 和 optional 属性进行处理
887895
if mapped.is_optional {

crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -120,12 +120,14 @@ pub fn tpl_pattern_match(
120120
if tpl.get_tpl_id().is_func() {
121121
context
122122
.substitutor
123-
.insert_type(tpl.get_tpl_id(), constant_decay(target));
123+
.insert_type(tpl.get_tpl_id(), target.clone(), true);
124124
}
125125
}
126126
LuaType::ConstTplRef(tpl) => {
127127
if tpl.get_tpl_id().is_func() {
128-
context.substitutor.insert_type(tpl.get_tpl_id(), target);
128+
context
129+
.substitutor
130+
.insert_type(tpl.get_tpl_id(), target, false);
129131
}
130132
}
131133
LuaType::StrTplRef(str_tpl) => {
@@ -135,7 +137,7 @@ pub fn tpl_pattern_match(
135137
let type_name = SmolStr::new(format!("{}{}{}", prefix, s, suffix));
136138
context
137139
.substitutor
138-
.insert_type(str_tpl.get_tpl_id(), type_name.into());
140+
.insert_type(str_tpl.get_tpl_id(), type_name.into(), true);
139141
}
140142
}
141143
LuaType::Array(array_type) => {
@@ -600,7 +602,7 @@ fn param_type_list_pattern_match_type_list(
600602
if i >= targets.len() {
601603
if let VariadicType::Base(LuaType::TplRef(tpl_ref)) = inner.deref() {
602604
let tpl_id = tpl_ref.get_tpl_id();
603-
context.substitutor.insert_type(tpl_id, LuaType::Nil);
605+
context.substitutor.insert_type(tpl_id, LuaType::Nil, true);
604606
}
605607
break;
606608
}
@@ -671,7 +673,9 @@ fn return_type_pattern_match_target_type(
671673
VariadicType::Base(source_base) => {
672674
if let LuaType::TplRef(type_ref) = source_base {
673675
let tpl_id = type_ref.get_tpl_id();
674-
context.substitutor.insert_type(tpl_id, target_base.clone());
676+
context
677+
.substitutor
678+
.insert_type(tpl_id, target_base.clone(), true);
675679
}
676680
}
677681
VariadicType::Multi(source_multi) => {
@@ -682,16 +686,22 @@ fn return_type_pattern_match_target_type(
682686
&& let LuaType::TplRef(type_ref) = base
683687
{
684688
let tpl_id = type_ref.get_tpl_id();
685-
context
686-
.substitutor
687-
.insert_type(tpl_id, target_base.clone());
689+
context.substitutor.insert_type(
690+
tpl_id,
691+
target_base.clone(),
692+
true,
693+
);
688694
}
689695

690696
break;
691697
}
692698
LuaType::TplRef(tpl_ref) => {
693699
let tpl_id = tpl_ref.get_tpl_id();
694-
context.substitutor.insert_type(tpl_id, target_base.clone());
700+
context.substitutor.insert_type(
701+
tpl_id,
702+
target_base.clone(),
703+
true,
704+
);
695705
}
696706
_ => {}
697707
}
@@ -756,12 +766,12 @@ pub fn variadic_tpl_pattern_match(
756766
let tpl_id = tpl_ref.get_tpl_id();
757767
match target_rest_types.len() {
758768
0 => {
759-
context.substitutor.insert_type(tpl_id, LuaType::Nil);
769+
context.substitutor.insert_type(tpl_id, LuaType::Nil, true);
760770
}
761771
1 => {
762772
context
763773
.substitutor
764-
.insert_type(tpl_id, constant_decay(target_rest_types[0].clone()));
774+
.insert_type(tpl_id, target_rest_types[0].clone(), true);
765775
}
766776
_ => {
767777
context.substitutor.insert_multi_types(
@@ -778,12 +788,14 @@ pub fn variadic_tpl_pattern_match(
778788
let tpl_id = tpl_ref.get_tpl_id();
779789
match target_rest_types.len() {
780790
0 => {
781-
context.substitutor.insert_type(tpl_id, LuaType::Nil);
791+
context.substitutor.insert_type(tpl_id, LuaType::Nil, false);
782792
}
783793
1 => {
784-
context
785-
.substitutor
786-
.insert_type(tpl_id, target_rest_types[0].clone());
794+
context.substitutor.insert_type(
795+
tpl_id,
796+
target_rest_types[0].clone(),
797+
false,
798+
);
787799
}
788800
_ => {
789801
context
@@ -808,9 +820,7 @@ pub fn variadic_tpl_pattern_match(
808820
let tpl_id = tpl_ref.get_tpl_id();
809821
match target_rest_types.get(i) {
810822
Some(t) => {
811-
context
812-
.substitutor
813-
.insert_type(tpl_id, constant_decay(t.clone()));
823+
context.substitutor.insert_type(tpl_id, t.clone(), true);
814824
}
815825
None => {
816826
break;

0 commit comments

Comments
 (0)