Skip to content

Commit 7743986

Browse files
authored
Merge pull request #288 from ReactiveBayes/270-error-for-iteration-within-constraint-function
Make conditional walk to only convert ranges inside indexed statements
2 parents 727b632 + 47397fd commit 7743986

File tree

3 files changed

+63
-3
lines changed

3 files changed

+63
-3
lines changed

src/model_macro.jl

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@ function (w::guarded_walk)(f, x)
1414
return w.guard(x) ? x : walk(x, x -> w(f, x), f)
1515
end
1616

17+
not_enter_indexed_walk = guarded_walk((x) -> (x isa Expr && x.head == :ref) || (x isa Expr && x.head == :call && x.args[1] == :new))
18+
not_created_by = guarded_walk((x) -> (x isa Expr && !isempty(x.args) && x.args[1] == :created_by))
19+
1720
struct walk_until_occurrence{E}
1821
patterns::E
1922
end
2023

21-
not_enter_indexed_walk = guarded_walk((x) -> (x isa Expr && x.head == :ref) || (x isa Expr && x.head == :call && x.args[1] == :new))
22-
not_created_by = guarded_walk((x) -> (x isa Expr && !isempty(x.args) && x.args[1] == :created_by))
23-
2424
function (w::walk_until_occurrence{E})(f, x) where {E <: Tuple}
2525
return walk(x, z -> any(pattern -> @capture(x, $(pattern)), w.patterns) ? z : w(f, z), f)
2626
end
@@ -29,6 +29,55 @@ function (w::walk_until_occurrence{E})(f, x) where {E <: Expr}
2929
return walk(x, z -> @capture(x, $(w.patterns)) ? z : w(f, z), f)
3030
end
3131

32+
"""
33+
conditional_walk{C}
34+
35+
A walking strategy that applies a function only when a condition is met.
36+
The condition is checked at each level, and once met, the function is applied
37+
to all subsequent expressions in that branch.
38+
39+
# Fields
40+
- `condition`: Function that determines when to start applying the transformation
41+
42+
# Example
43+
```julia
44+
# Only convert : to CombinedRange when inside a :ref expression
45+
is_inside_indexing = conditional_walk(x -> x isa Expr && x.head == :ref)
46+
what_walk(::typeof(convert_to_combined_range)) = is_inside_indexing
47+
```
48+
"""
49+
struct conditional_walk{C}
50+
condition::C
51+
end
52+
53+
function (w::conditional_walk{C})(f, x, should_apply::Bool = false) where {C}
54+
# Check if we should flip the flag
55+
new_should_apply = should_apply || w.condition(x)
56+
57+
if new_should_apply
58+
# Apply the function
59+
return walk(x, z -> w(f, z, new_should_apply), f)
60+
else
61+
# Continue walking recursively, passing down the flag
62+
return walk(x, z -> w(f, z, new_should_apply), identity)
63+
end
64+
end
65+
66+
"""
67+
is_inside_indexing
68+
69+
A walking strategy that only applies transformations when inside an indexing expression.
70+
This is useful for converting ranged statements (like `1:10`) to `CombinedRange` only
71+
when they appear in array indexing contexts like `x[1:10]`.
72+
73+
# Example
74+
```julia
75+
# Only convert : to CombinedRange when inside a :ref expression
76+
what_walk(::typeof(convert_to_combined_range)) = is_inside_indexing
77+
```
78+
"""
79+
is_inside_indexing = conditional_walk(x -> x isa Expr && x.head == :ref)
80+
3281
what_walk(anything) = postwalk
3382

3483
"""

src/plugins/variational_constraints/variational_constraints_macro.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@ function create_factorization_combinedrange(e::Expr)
9999
return e
100100
end
101101

102+
what_walk(::typeof(create_factorization_combinedrange)) = is_inside_indexing
103+
102104
__convert_to_indexed_statement(e::Symbol) = :(GraphPPL.IndexedVariable($(QuoteNode(e)), nothing))
103105
function __convert_to_indexed_statement(e::Expr)
104106
if @capture(e, (var_[index_]))

test/plugins/variational_constraints/variational_constraints_macro_tests.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,15 @@ end
452452
q(x) = q(x[GraphPPL.CombinedRange(begin, end)])
453453
end
454454
@test_expression_generating apply_pipeline(input, create_factorization_combinedrange) output
455+
456+
# Test 2: Make sure create_factorization_combinedrange only works within indexed statements
457+
input = quote
458+
for i in 1:10
459+
q(state[i]) = q(state[i])
460+
end
461+
end
462+
output = input
463+
@test_expression_generating apply_pipeline(input, create_factorization_combinedrange) output
455464
end
456465

457466
@testitem "convert_variable_statements" begin

0 commit comments

Comments
 (0)