Skip to content

Commit aabd699

Browse files
committed
refactor infer_guard
Fix #793
1 parent 4c50b71 commit aabd699

File tree

14 files changed

+258
-165
lines changed

14 files changed

+258
-165
lines changed

crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/find_decl_function.rs

Lines changed: 50 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::sync::Arc;
2+
13
use emmylua_parser::{LuaAstNode, LuaIndexKey, LuaIndexMemberExpr};
24
use rowan::TextRange;
35
use smol_str::SmolStr;
@@ -24,11 +26,11 @@ pub struct FindFunctionType {
2426
}
2527

2628
#[derive(Debug)]
27-
struct DeepGuard {
29+
struct DeepLevel {
2830
deep: usize,
2931
}
3032

31-
impl DeepGuard {
33+
impl DeepLevel {
3234
pub fn new() -> Self {
3335
Self { deep: 0 }
3436
}
@@ -64,13 +66,13 @@ pub fn find_decl_function_type(
6466
index_member_expr
6567
.get_prefix_expr()
6668
.ok_or(InferFailReason::None)?;
67-
let mut deep_guard = DeepGuard::new();
69+
let mut deep_guard = DeepLevel::new();
6870
let reason = match find_function_type_by_member_key(
6971
db,
7072
cache,
7173
prefix_type,
7274
index_member_expr.clone(),
73-
&mut InferGuard::new(),
75+
&InferGuard::new(),
7476
&mut deep_guard,
7577
) {
7678
Ok(member_type) => {
@@ -83,13 +85,13 @@ pub fn find_decl_function_type(
8385
Err(err) => return Err(err),
8486
};
8587

86-
let mut deep_guard = DeepGuard::new();
88+
let mut deep_guard = DeepLevel::new();
8789
match find_function_type_by_operator(
8890
db,
8991
cache,
9092
prefix_type,
9193
index_member_expr,
92-
&mut InferGuard::new(),
94+
&InferGuard::new(),
9395
&mut deep_guard,
9496
) {
9597
Ok(member_type) => {
@@ -110,8 +112,8 @@ fn find_function_type_by_member_key(
110112
cache: &mut LuaInferCache,
111113
prefix_type: &LuaType,
112114
index_expr: LuaIndexMemberExpr,
113-
infer_guard: &mut InferGuard,
114-
deep_guard: &mut DeepGuard,
115+
infer_guard: &Arc<InferGuard>,
116+
deep_guard: &mut DeepLevel,
115117
) -> FunctionTypeResult {
116118
match &prefix_type {
117119
LuaType::Ref(decl_id) => find_custom_type_function_member(
@@ -135,10 +137,10 @@ fn find_function_type_by_member_key(
135137
find_object_function_member(db, cache, object_type, index_expr)
136138
}
137139
LuaType::Union(union_type) => {
138-
find_union_function_member(db, cache, union_type, index_expr, deep_guard)
140+
find_union_function_member(db, cache, union_type, index_expr, infer_guard, deep_guard)
139141
}
140142
LuaType::Generic(generic_type) => {
141-
find_generic_member(db, cache, generic_type, index_expr, deep_guard)
143+
find_generic_member(db, cache, generic_type, index_expr, infer_guard, deep_guard)
142144
}
143145
LuaType::Instance(inst) => {
144146
find_instance_member_decl_type(db, cache, inst, index_expr, infer_guard, deep_guard)
@@ -177,8 +179,8 @@ fn find_custom_type_function_member(
177179
cache: &mut LuaInferCache,
178180
prefix_type_id: LuaTypeDeclId,
179181
index_expr: LuaIndexMemberExpr,
180-
infer_guard: &mut InferGuard,
181-
deep_guard: &mut DeepGuard,
182+
infer_guard: &Arc<InferGuard>,
183+
deep_guard: &mut DeepLevel,
182184
) -> FunctionTypeResult {
183185
infer_guard.check(&prefix_type_id)?;
184186
let type_index = db.get_type_index();
@@ -321,7 +323,8 @@ fn find_union_function_member(
321323
cache: &mut LuaInferCache,
322324
union_type: &LuaUnionType,
323325
index_expr: LuaIndexMemberExpr,
324-
deep_guard: &mut DeepGuard,
326+
infer_guard: &Arc<InferGuard>,
327+
deep_guard: &mut DeepLevel,
325328
) -> FunctionTypeResult {
326329
let mut member_types = Vec::new();
327330
for sub_type in union_type.into_vec() {
@@ -330,7 +333,7 @@ fn find_union_function_member(
330333
cache,
331334
&sub_type,
332335
index_expr.clone(),
333-
&mut InferGuard::new(),
336+
infer_guard,
334337
deep_guard,
335338
);
336339
if let Ok(typ) = result
@@ -349,7 +352,8 @@ fn index_generic_members_from_super_generics(
349352
type_decl_id: &LuaTypeDeclId,
350353
substitutor: &TypeSubstitutor,
351354
index_expr: LuaIndexMemberExpr,
352-
deep_guard: &mut DeepGuard,
355+
infer_guard: &Arc<InferGuard>,
356+
deep_guard: &mut DeepLevel,
353357
) -> Option<LuaType> {
354358
let type_index = db.get_type_index();
355359

@@ -367,7 +371,7 @@ fn index_generic_members_from_super_generics(
367371
cache,
368372
&super_type,
369373
index_expr.clone(),
370-
&mut InferGuard::new(),
374+
&infer_guard.fork(),
371375
deep_guard,
372376
)
373377
.ok()
@@ -382,7 +386,8 @@ fn find_generic_member(
382386
cache: &mut LuaInferCache,
383387
generic_type: &LuaGenericType,
384388
index_expr: LuaIndexMemberExpr,
385-
deep_guard: &mut DeepGuard,
389+
infer_guard: &Arc<InferGuard>,
390+
deep_guard: &mut DeepLevel,
386391
) -> FunctionTypeResult {
387392
let base_type = generic_type.get_base_type();
388393

@@ -395,6 +400,7 @@ fn find_generic_member(
395400
base_type_decl_id,
396401
&substitutor,
397402
index_expr.clone(),
403+
infer_guard,
398404
deep_guard,
399405
);
400406
if let Some(result) = result {
@@ -407,7 +413,7 @@ fn find_generic_member(
407413
cache,
408414
&base_type,
409415
index_expr,
410-
&mut InferGuard::new(),
416+
infer_guard,
411417
deep_guard,
412418
)?;
413419

@@ -419,8 +425,8 @@ fn find_instance_member_decl_type(
419425
cache: &mut LuaInferCache,
420426
inst: &LuaInstanceType,
421427
index_expr: LuaIndexMemberExpr,
422-
infer_guard: &mut InferGuard,
423-
deep_guard: &mut DeepGuard,
428+
infer_guard: &Arc<InferGuard>,
429+
deep_guard: &mut DeepLevel,
424430
) -> FunctionTypeResult {
425431
let origin_type = inst.get_base();
426432
find_function_type_by_member_key(
@@ -438,8 +444,8 @@ fn find_function_type_by_operator(
438444
cache: &mut LuaInferCache,
439445
prefix_type: &LuaType,
440446
index_expr: LuaIndexMemberExpr,
441-
infer_guard: &mut InferGuard,
442-
deep_guard: &mut DeepGuard,
447+
infer_guard: &Arc<InferGuard>,
448+
deep_guard: &mut DeepLevel,
443449
) -> FunctionTypeResult {
444450
match &prefix_type {
445451
LuaType::TableConst(in_filed) => {
@@ -467,13 +473,18 @@ fn find_function_type_by_operator(
467473
}
468474
LuaType::Object(object) => infer_member_by_index_object(db, cache, object, index_expr),
469475
LuaType::Union(union) => {
470-
find_member_by_index_union(db, cache, union, index_expr, deep_guard)
471-
}
472-
LuaType::Intersection(intersection) => {
473-
find_member_by_index_intersection(db, cache, intersection, index_expr, deep_guard)
476+
find_member_by_index_union(db, cache, union, index_expr, infer_guard, deep_guard)
474477
}
478+
LuaType::Intersection(intersection) => find_member_by_index_intersection(
479+
db,
480+
cache,
481+
intersection,
482+
index_expr,
483+
infer_guard,
484+
deep_guard,
485+
),
475486
LuaType::Generic(generic) => {
476-
find_member_by_index_generic(db, cache, generic, index_expr, deep_guard)
487+
find_member_by_index_generic(db, cache, generic, index_expr, infer_guard, deep_guard)
477488
}
478489
LuaType::TableGeneric(table_generic) => {
479490
find_member_by_index_table_generic(db, cache, table_generic, index_expr)
@@ -566,8 +577,8 @@ fn find_member_by_index_custom_type(
566577
cache: &mut LuaInferCache,
567578
prefix_type_id: &LuaTypeDeclId,
568579
index_expr: LuaIndexMemberExpr,
569-
infer_guard: &mut InferGuard,
570-
deep_guard: &mut DeepGuard,
580+
infer_guard: &Arc<InferGuard>,
581+
deep_guard: &mut DeepLevel,
571582
) -> FunctionTypeResult {
572583
infer_guard.check(prefix_type_id)?;
573584
let type_index = db.get_type_index();
@@ -680,7 +691,8 @@ fn find_member_by_index_union(
680691
cache: &mut LuaInferCache,
681692
union: &LuaUnionType,
682693
index_expr: LuaIndexMemberExpr,
683-
deep_guard: &mut DeepGuard,
694+
infer_guard: &Arc<InferGuard>,
695+
deep_guard: &mut DeepLevel,
684696
) -> FunctionTypeResult {
685697
let mut member_type = LuaType::Unknown;
686698
for member in union.into_vec() {
@@ -689,7 +701,7 @@ fn find_member_by_index_union(
689701
cache,
690702
&member,
691703
index_expr.clone(),
692-
&mut InferGuard::new(),
704+
&infer_guard.fork(),
693705
deep_guard,
694706
);
695707
match result {
@@ -715,15 +727,16 @@ fn find_member_by_index_intersection(
715727
cache: &mut LuaInferCache,
716728
intersection: &LuaIntersectionType,
717729
index_expr: LuaIndexMemberExpr,
718-
deep_guard: &mut DeepGuard,
730+
infer_guard: &Arc<InferGuard>,
731+
deep_guard: &mut DeepLevel,
719732
) -> FunctionTypeResult {
720733
for member in intersection.get_types() {
721734
match find_function_type_by_operator(
722735
db,
723736
cache,
724737
member,
725738
index_expr.clone(),
726-
&mut InferGuard::new(),
739+
&infer_guard.fork(),
727740
deep_guard,
728741
) {
729742
Ok(ty) => return Ok(ty),
@@ -742,7 +755,8 @@ fn find_member_by_index_generic(
742755
cache: &mut LuaInferCache,
743756
generic: &LuaGenericType,
744757
index_expr: LuaIndexMemberExpr,
745-
deep_guard: &mut DeepGuard,
758+
infer_guard: &Arc<InferGuard>,
759+
deep_guard: &mut DeepLevel,
746760
) -> FunctionTypeResult {
747761
let base_type = generic.get_base_type();
748762
let type_decl_id = if let LuaType::Ref(id) = base_type {
@@ -763,7 +777,7 @@ fn find_member_by_index_generic(
763777
cache,
764778
&instantiate_type_generic(db, &origin_type, &substitutor),
765779
index_expr.clone(),
766-
&mut InferGuard::new(),
780+
&infer_guard.fork(),
767781
deep_guard,
768782
);
769783
}
@@ -807,7 +821,7 @@ fn find_member_by_index_generic(
807821
cache,
808822
&instantiate_type_generic(db, &super_type, &substitutor),
809823
index_expr.clone(),
810-
&mut InferGuard::new(),
824+
&infer_guard.fork(),
811825
deep_guard,
812826
);
813827
match result {

crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/resolve_closure.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ pub fn try_resolve_closure_parent_params(
271271
closure_params,
272272
&member_type,
273273
self_type,
274-
&mut InferGuard::new(),
274+
&InferGuard::new(),
275275
)
276276
}
277277

@@ -280,7 +280,7 @@ fn resolve_closure_member_type(
280280
closure_params: &UnResolveParentClosureParams,
281281
member_type: &LuaType,
282282
self_type: Option<LuaType>,
283-
infer_guard: &mut InferGuard,
283+
infer_guard: &Arc<InferGuard>,
284284
) -> ResolveResult {
285285
match &member_type {
286286
LuaType::DocFunction(doc_func) => {
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
use std::{cell::RefCell, collections::HashSet, sync::Arc};
2+
3+
use crate::{InferFailReason, LuaTypeDeclId};
4+
5+
/// Guard to prevent infinite recursion
6+
/// Some type may reference itself, so we need to check if we have already inferred this type
7+
///
8+
/// This guard supports inheritance through Rc parent chain, allowing child guards to see
9+
/// parent's visited types while maintaining their own independent tracking for branch protection.
10+
#[derive(Debug, Clone)]
11+
pub struct InferGuard {
12+
/// Current level's visited types
13+
current: RefCell<HashSet<LuaTypeDeclId>>,
14+
/// Parent guard (shared reference)
15+
parent: Option<Arc<InferGuard>>,
16+
}
17+
18+
impl InferGuard {
19+
pub fn new() -> Arc<Self> {
20+
Arc::new(Self {
21+
current: RefCell::new(HashSet::default()),
22+
parent: None,
23+
})
24+
}
25+
26+
/// Create a child guard that inherits from parent
27+
/// This allows branching while preventing infinite recursion across the entire call stack
28+
pub fn fork(self: &Arc<Self>) -> Arc<Self> {
29+
Arc::new(Self {
30+
current: RefCell::new(HashSet::default()),
31+
parent: Some(Arc::clone(self)),
32+
})
33+
}
34+
35+
/// Create a child guard from a non-Arc guard
36+
/// This is a convenience method for when you have a stack-allocated guard
37+
pub fn fork_owned(&self) -> Self {
38+
Self {
39+
current: RefCell::new(HashSet::default()),
40+
parent: self.parent.clone(),
41+
}
42+
}
43+
44+
/// Check if a type has been visited in current branch or any parent
45+
pub fn check(&self, type_id: &LuaTypeDeclId) -> Result<(), InferFailReason> {
46+
// Check in all parent levels first
47+
if self.contains_in_parents(type_id) {
48+
return Err(InferFailReason::RecursiveInfer);
49+
}
50+
51+
// Check in current level
52+
let mut current = self.current.borrow_mut();
53+
if current.contains(type_id) {
54+
return Err(InferFailReason::RecursiveInfer);
55+
}
56+
57+
// Mark as visited in current level
58+
current.insert(type_id.clone());
59+
Ok(())
60+
}
61+
62+
/// Check if a type has been visited in parent chain
63+
fn contains_in_parents(&self, type_id: &LuaTypeDeclId) -> bool {
64+
let mut current_parent = self.parent.as_ref();
65+
while let Some(parent) = current_parent {
66+
if parent.current.borrow().contains(type_id) {
67+
return true;
68+
}
69+
current_parent = parent.parent.as_ref();
70+
}
71+
false
72+
}
73+
74+
/// Check if a type has been visited (without modifying the guard)
75+
pub fn contains(&self, type_id: &LuaTypeDeclId) -> bool {
76+
self.current.borrow().contains(type_id) || self.contains_in_parents(type_id)
77+
}
78+
79+
/// Get the depth of current level
80+
pub fn current_depth(&self) -> usize {
81+
self.current.borrow().len()
82+
}
83+
84+
/// Get the total depth of the entire guard chain
85+
pub fn total_depth(&self) -> usize {
86+
let mut depth = self.current.borrow().len();
87+
let mut current_parent = self.parent.as_ref();
88+
while let Some(parent) = current_parent {
89+
depth += parent.current.borrow().len();
90+
current_parent = parent.parent.as_ref();
91+
}
92+
depth
93+
}
94+
95+
/// Get the level of the guard chain (how many parents)
96+
pub fn level(&self) -> usize {
97+
let mut level = 0;
98+
let mut current_parent = self.parent.as_ref();
99+
while let Some(parent) = current_parent {
100+
level += 1;
101+
current_parent = parent.parent.as_ref();
102+
}
103+
level
104+
}
105+
}

0 commit comments

Comments
 (0)