Skip to content

Commit b545a93

Browse files
committed
Fix for VarNames with non-identity lenses
1 parent 8c3bff4 commit b545a93

File tree

2 files changed

+31
-1
lines changed

2 files changed

+31
-1
lines changed

src/context_implementations.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ function tilde_assume!!(context, right, vn, vi)
131131
# change in the future.
132132
if should_auto_prefix(right)
133133
dppl_model = right.model.model # This isa DynamicPPL.Model
134-
prefixed_submodel_context = PrefixContext{getsym(vn)}(dppl_model.context)
134+
prefixed_submodel_context = PrefixContext{Symbol(vn)}(dppl_model.context)
135135
new_dppl_model = contextualize(dppl_model, prefixed_submodel_context)
136136
right = to_submodel(new_dppl_model, true)
137137
end

test/submodels.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,36 @@ using Test
105105
end
106106
end
107107

108+
@testset "Complex prefixes" begin
109+
mutable struct P
110+
a::Float64
111+
b::Float64
112+
end
113+
@model function f()
114+
x = Vector{Float64}(undef, 1)
115+
x[1] ~ Normal()
116+
y ~ Normal()
117+
return x[1]
118+
end
119+
@model function g()
120+
p = P(1.0, 2.0)
121+
p.a ~ to_submodel(f())
122+
p.b ~ Normal()
123+
return (p.a, p.b)
124+
end
125+
expected_vns = Set([
126+
@varname(var"p.a".x[1]), @varname(var"p.a".y), @varname(p.b)
127+
])
128+
@test Set(keys(VarInfo(g()))) == expected_vns
129+
130+
# Check that we can condition/fix on any of them from the outside
131+
for vn in expected_vns
132+
op_g = op(g(), (vn => 1.0))
133+
vi = VarInfo(op_g)
134+
@test Set(keys(vi)) == symdiff(expected_vns, Set([vn]))
135+
end
136+
end
137+
108138
@testset "Nested submodels" begin
109139
@model function f()
110140
x ~ Normal()

0 commit comments

Comments
 (0)