Skip to content

Commit e69821d

Browse files
committed
Fix #790
1 parent 9cf96cc commit e69821d

File tree

2 files changed

+116
-10
lines changed

2 files changed

+116
-10
lines changed

crates/emmylua_code_analysis/src/semantic/type_check/generic_type.rs

Lines changed: 77 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
use std::{collections::HashMap, sync::Arc};
22

33
use crate::{
4-
LuaGenericType, LuaMemberOwner, LuaType, LuaTypeCache, RenderLevel, TypeSubstitutor,
5-
humanize_type,
4+
LuaGenericType, LuaMemberOwner, LuaType, LuaTypeCache, LuaTypeDeclId, RenderLevel,
5+
TypeSubstitutor, humanize_type,
66
semantic::{
77
member::find_members,
88
type_check::{is_sub_type_of, type_check_context::TypeCheckContext},
99
},
1010
};
1111

1212
use super::{
13-
TypeCheckResult, check_general_type_compact, check_ref_type_compact,
14-
type_check_fail_reason::TypeCheckFailReason, type_check_guard::TypeCheckGuard,
13+
TypeCheckResult, check_general_type_compact, type_check_fail_reason::TypeCheckFailReason,
14+
type_check_guard::TypeCheckGuard,
1515
};
1616

1717
pub fn check_generic_type_compact(
@@ -48,27 +48,50 @@ pub fn check_generic_type_compact(
4848
if is_tpl {
4949
return Ok(());
5050
}
51-
check_generic_type_compact_generic(
51+
let first_result = check_generic_type_compact_generic(
5252
context,
5353
source_generic,
5454
compact_generic,
5555
check_guard.next_level()?,
56-
)
56+
);
57+
if first_result.is_ok() {
58+
return Ok(());
59+
}
60+
61+
if let Some(supers) = context
62+
.db
63+
.get_type_index()
64+
.get_super_types(&compact_generic.get_base_type_id())
65+
{
66+
for super_type in supers {
67+
let result = check_generic_type_compact(
68+
context,
69+
source_generic,
70+
&super_type,
71+
check_guard.next_level()?,
72+
);
73+
if result.is_ok() {
74+
return Ok(());
75+
}
76+
}
77+
}
78+
79+
first_result
5780
}
5881
LuaType::TableConst(range) => check_generic_type_compact_table(
5982
context,
6083
source_generic,
6184
LuaMemberOwner::Element(range.clone()),
6285
check_guard.next_level()?,
6386
),
64-
LuaType::Ref(_) | LuaType::Def(_) => {
87+
LuaType::Ref(ref_id) | LuaType::Def(ref_id) => {
6588
if is_tpl {
6689
return Ok(());
6790
}
68-
check_ref_type_compact(
91+
check_generic_type_compact_ref_type(
6992
context,
70-
&source_generic.get_base_type_id(),
71-
compact_type,
93+
source_generic,
94+
ref_id,
7295
check_guard.next_level()?,
7396
)
7497
}
@@ -196,3 +219,47 @@ fn check_generic_type_compact_table(
196219

197220
Ok(())
198221
}
222+
223+
fn check_generic_type_compact_ref_type(
224+
context: &TypeCheckContext,
225+
source_generic: &LuaGenericType,
226+
ref_id: &LuaTypeDeclId,
227+
check_guard: TypeCheckGuard,
228+
) -> TypeCheckResult {
229+
let type_decl = context
230+
.db
231+
.get_type_index()
232+
.get_type_decl(ref_id)
233+
.ok_or(TypeCheckFailReason::TypeNotMatch)?;
234+
235+
if type_decl.is_alias() {
236+
if let Some(origin_type) = type_decl.get_alias_origin(context.db, None) {
237+
return check_general_type_compact(
238+
context,
239+
&LuaType::Generic(source_generic.clone().into()),
240+
&origin_type,
241+
check_guard.next_level()?,
242+
);
243+
}
244+
}
245+
246+
for super_type in context
247+
.db
248+
.get_type_index()
249+
.get_super_types(ref_id)
250+
.unwrap_or_default()
251+
{
252+
if check_generic_type_compact(
253+
context,
254+
source_generic,
255+
&super_type,
256+
check_guard.next_level()?,
257+
)
258+
.is_ok()
259+
{
260+
return Ok(());
261+
}
262+
}
263+
264+
Err(TypeCheckFailReason::TypeNotMatch)
265+
}

crates/emmylua_code_analysis/src/semantic/type_check/test.rs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,4 +169,43 @@ mod test {
169169
"#
170170
));
171171
}
172+
173+
#[test]
174+
fn test_issue_790() {
175+
let mut ws = VirtualWorkspace::new();
176+
ws.def(
177+
r#"
178+
---@class Holder<T>
179+
180+
---@class StringHolder: Holder<string>
181+
182+
---@class NumberHolder: Holder<number>
183+
184+
---@class StringHolderWith<T>: Holder<string>
185+
186+
---@generic T
187+
---@param a T
188+
---@param b T
189+
function test(a, b) end
190+
"#,
191+
);
192+
193+
assert!(!ws.check_code_for(
194+
DiagnosticCode::ParamTypeNotMatch,
195+
r#"
196+
---@type Holder<string>, NumberHolder
197+
local a, b
198+
test(a, b)
199+
"#
200+
));
201+
202+
assert!(ws.check_code_for(
203+
DiagnosticCode::ParamTypeNotMatch,
204+
r#"
205+
---@type Holder<string>, StringHolderWith<table>
206+
local a, b
207+
test(a, b)
208+
"#
209+
));
210+
}
172211
}

0 commit comments

Comments
 (0)