Skip to content

Commit d5ae280

Browse files
authored
:= to keep track of generated quantities (#594)
* added assignemnt operator * use special construct to hijack assume for `:=` * forgot to include change in previous commit * test the assignment operator * remove syntax incompat with older Julia versions * improved existing test
1 parent 4cf395b commit d5ae280

File tree

4 files changed

+94
-3
lines changed

4 files changed

+94
-3
lines changed

src/compiler.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,9 +371,33 @@ function generate_mainbody!(mod, found, expr::Expr, warn)
371371
)
372372
end
373373

374+
# Modify the assignment operators.
375+
args_assign = getargs_coloneq(expr)
376+
if args_assign !== nothing
377+
L, R = args_assign
378+
return Base.remove_linenums!(
379+
generate_assign(
380+
generate_mainbody!(mod, found, L, warn),
381+
generate_mainbody!(mod, found, R, warn),
382+
),
383+
)
384+
end
385+
374386
return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, warn), expr.args)...)
375387
end
376388

389+
function generate_assign(left, right)
390+
right_expr = :($(TrackedValue)($right))
391+
tilde_expr = generate_tilde(left, right_expr)
392+
return quote
393+
if $(is_extracting_values)(__context__)
394+
$tilde_expr
395+
else
396+
$left = $right
397+
end
398+
end
399+
end
400+
377401
function generate_tilde_literal(left, right)
378402
# If the LHS is a literal, it is always an observation
379403
@gensym value

src/utils.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,20 @@ function getargs_assignment(expr::Expr)
163163
end
164164
end
165165

166+
"""
167+
getargs_coloneq(x)
168+
169+
Return the arguments `L` and `R`, if `x` is an expression of the form `L := R`, or `nothing`
170+
otherwise.
171+
"""
172+
getargs_coloneq(x) = nothing
173+
function getargs_coloneq(expr::Expr)
174+
return MacroTools.@match expr begin
175+
(L_ := R_) => (L, R)
176+
x_ => nothing
177+
end
178+
end
179+
166180
function to_namedtuple_expr(syms)
167181
length(syms) == 0 && return :(NamedTuple())
168182

src/values_as_in_model.jl

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
struct TrackedValue{T}
2+
value::T
3+
end
4+
5+
is_tracked_value(::TrackedValue) = true
6+
is_tracked_value(::Any) = false
7+
8+
check_tilde_rhs(x::TrackedValue) = x
19

210
"""
311
ValuesAsInModelContext
@@ -29,6 +37,13 @@ function setchildcontext(context::ValuesAsInModelContext, child)
2937
return ValuesAsInModelContext(context.values, child)
3038
end
3139

40+
is_extracting_values(context::ValuesAsInModelContext) = true
41+
function is_extracting_values(context::AbstractContext)
42+
return is_extracting_values(NodeTrait(context), context)
43+
end
44+
is_extracting_values(::IsParent, ::AbstractContext) = false
45+
is_extracting_values(::IsLeaf, ::AbstractContext) = false
46+
3247
function Base.push!(context::ValuesAsInModelContext, vn::VarName, value)
3348
return setindex!(context.values, copy(value), vn)
3449
end
@@ -48,7 +63,12 @@ end
4863

4964
# `tilde_asssume`
5065
function tilde_assume(context::ValuesAsInModelContext, right, vn, vi)
51-
value, logp, vi = tilde_assume(childcontext(context), right, vn, vi)
66+
if is_tracked_value(right)
67+
value = right.value
68+
logp = zero(getlogp(vi))
69+
else
70+
value, logp, vi = tilde_assume(childcontext(context), right, vn, vi)
71+
end
5272
# Save the value.
5373
push!(context, vn, value)
5474
# Save the value.
@@ -58,7 +78,12 @@ end
5878
function tilde_assume(
5979
rng::Random.AbstractRNG, context::ValuesAsInModelContext, sampler, right, vn, vi
6080
)
61-
value, logp, vi = tilde_assume(rng, childcontext(context), sampler, right, vn, vi)
81+
if is_tracked_value(right)
82+
value = right.value
83+
logp = zero(getlogp(vi))
84+
else
85+
value, logp, vi = tilde_assume(rng, childcontext(context), sampler, right, vn, vi)
86+
end
6287
# Save the value.
6388
push!(context, vn, value)
6489
# Pass on.

test/compiler.jl

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -687,6 +687,34 @@ module Issue537 end
687687
# And one explicit test for logging so know that is working.
688688
@model demo_with_logging() = @info "hi"
689689
model = demo_with_logging()
690-
@test model() == nothing
690+
@test model() === nothing
691+
# Make sure that the log message is present.
692+
@test_logs (:info, "hi") model()
693+
end
694+
695+
@testset ":= (tracked values)" begin
696+
@model function demo_tracked()
697+
x ~ Normal()
698+
y := 100 + x
699+
return (; x, y)
700+
end
701+
@model function demo_tracked_submodel()
702+
@submodel (x, y) = demo_tracked()
703+
return (; x, y)
704+
end
705+
for model in [demo_tracked(), demo_tracked_submodel()]
706+
# Make sure it's runnable and `y` is present in the return-value.
707+
@test model() isa NamedTuple{(:x, :y)}
708+
709+
# `VarInfo` should only contain `x`.
710+
varinfo = VarInfo(model)
711+
@test haskey(varinfo, @varname(x))
712+
@test !haskey(varinfo, @varname(y))
713+
714+
# While `values_as_in_model` should contain both `x` and `y`.
715+
values = values_as_in_model(model, deepcopy(varinfo))
716+
@test haskey(values, @varname(x))
717+
@test haskey(values, @varname(y))
718+
end
691719
end
692720
end

0 commit comments

Comments
 (0)