Skip to content

Commit aef7e02

Browse files
committed
update infer conditional
1 parent b500bab commit aef7e02

File tree

2 files changed

+51
-10
lines changed

2 files changed

+51
-10
lines changed

crates/emmylua_code_analysis/src/compilation/test/generic_test.rs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -719,4 +719,36 @@ mod test {
719719
// assert_eq!(result_ty, ws.ty("number"));
720720
// }
721721
}
722+
723+
#[test]
724+
fn test_generic_extends_function_params() {
725+
let mut ws = VirtualWorkspace::new();
726+
ws.def(
727+
r#"
728+
---@alias ConstructorParameters<T> T extends new (fun(...: infer P): any) and P or never
729+
---@alias Parameters<T extends function> T extends (fun(...: infer P): any) and P or never
730+
731+
---@alias Procedure fun(...: any[]): any
732+
733+
---@alias MockParameters<T> T extends table and ConstructorParameters<T> or T extends Procedure and Parameters<T> or never
734+
735+
---@class Mock<T>
736+
---@field calls MockParameters<T>[]
737+
"#,
738+
);
739+
ws.def(
740+
r#"
741+
---@generic T: Procedure
742+
---@param a T?
743+
---@return Mock<T>
744+
local function fn(a)
745+
end
746+
747+
result = fn().calls
748+
"#,
749+
);
750+
751+
let result_ty = ws.expr_ty("result");
752+
assert_eq!(ws.humanize_type(result_ty), "any[][]");
753+
}
722754
}

crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -639,17 +639,16 @@ fn collect_infer_assignments(
639639
};
640640
rest_types.push(source_ty.clone());
641641
}
642+
let ty = match rest_types.len() {
643+
0 => LuaType::Never,
644+
1 => rest_types[0].clone(),
645+
_ => LuaType::Tuple(
646+
LuaTupleType::new(rest_types, LuaTupleStatus::InferResolve)
647+
.into(),
648+
),
649+
};
642650

643-
let tuple_ty = LuaType::Tuple(
644-
LuaTupleType::new(rest_types, LuaTupleStatus::InferResolve)
645-
.into(),
646-
);
647-
if !collect_infer_assignments(
648-
db,
649-
&tuple_ty,
650-
pattern_ty,
651-
assignments,
652-
) {
651+
if !collect_infer_assignments(db, &ty, pattern_ty, assignments) {
653652
return false;
654653
}
655654
}
@@ -683,6 +682,16 @@ fn collect_infer_assignments(
683682
false
684683
}
685684
}
685+
LuaType::Ref(type_decl_id) => {
686+
if let Some(type_decl) = db.get_type_index().get_type_decl(type_decl_id) {
687+
if type_decl.is_alias()
688+
&& let Some(origin) = type_decl.get_alias_origin(db, None)
689+
{
690+
return collect_infer_assignments(db, &origin, &pattern, assignments);
691+
}
692+
}
693+
false
694+
}
686695
_ => false,
687696
}
688697
}

0 commit comments

Comments
 (0)