Skip to content

Commit f4a422b

Browse files
authored
Merge pull request #324 from probcomp/update-no-change
Add shorthand variants of `update` and `regenerate` that assume the arguments are unchanged
2 parents a4746a1 + 58ddffd commit f4a422b

File tree

3 files changed

+62
-0
lines changed

3 files changed

+62
-0
lines changed

src/gen_fn_interface.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,19 @@ function update(trace, args::Tuple, argdiffs::Tuple, ::ChoiceMap)
276276
error("Not implemented")
277277
end
278278

279+
"""
280+
(new_trace, weight, retdiff, discard) = update(trace, constraints::ChoiceMap)
281+
282+
Shorthand variant of
283+
[`update`](@ref update(::Any, ::Tuple, ::Tuple, ::ChoiceMap))
284+
which assumes the arguments are unchanged.
285+
"""
286+
function update(trace, constraints::ChoiceMap)
287+
args = get_args(trace)
288+
argdiffs = Tuple(NoChange() for _ in args)
289+
return update(trace, args, argdiffs, constraints)
290+
end
291+
279292
"""
280293
(new_trace, weight, retdiff) = regenerate(trace, args::Tuple, argdiffs::Tuple,
281294
selection::Selection)
@@ -307,6 +320,19 @@ function regenerate(trace, args::Tuple, argdiffs::Tuple, selection::Selection)
307320
error("Not implemented")
308321
end
309322

323+
"""
324+
(new_trace, weight, retdiff) = regenerate(trace, selection::Selection)
325+
326+
Shorthand variant of
327+
[`regenerate`](@ref regenerate(::Any, ::Tuple, ::Tuple, ::Selection))
328+
which assumes the arguments are unchanged.
329+
"""
330+
function regenerate(trace, selection::Selection)
331+
args = get_args(trace)
332+
argdiffs = Tuple(NoChange() for _ in args)
333+
return regenerate(trace, args, argdiffs, selection)
334+
end
335+
310336
"""
311337
arg_grads = accumulate_param_gradients!(trace, retgrad=nothing, scale_factor=1.)
312338

test/gen_fn_interface.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
@gen function f_dynamic(x, y)
2+
z ~ normal(0, 1)
3+
return x + y + z
4+
end
5+
@gen (static) function f_static(x, y)
6+
z ~ normal(0, 1)
7+
return x + y + z
8+
end
9+
@load_generated_functions()
10+
11+
for (lang, f) in [:dynamic => f_dynamic,
12+
:static => f_static]
13+
@testset "update(...) shorthand assuming unchanged args ($lang modeling lang)" begin
14+
trace0 = simulate(f, (5, 6))
15+
16+
constraints = choicemap((:z, 0))
17+
trace1, _, _, discard = update(trace0, constraints)
18+
# The main test is that the shorthand version runs without crashing,
19+
# which is already shown by the time we get here. Beyond that, let's
20+
# sanity-check that `update` did what it's supposed to.
21+
@test get_args(trace1) == (5, 6)
22+
@test trace1[:z] == 0
23+
@test :z in keys(get_values_shallow(discard))
24+
end
25+
26+
@testset "regenerate(...) shorthand assuming unchanged args ($lang modeling lang)" begin
27+
trace0 = simulate(f, (5, 6))
28+
trace1, _, _ = regenerate(trace0, select(:z))
29+
# The main test is that the shorthand version runs without crashing,
30+
# which is already shown by the time we get here. Beyond that, let's
31+
# sanity-check that `regenerate` did what it's supposed to.
32+
@test get_args(trace1) == (5, 6)
33+
@test trace1[:z] != trace0[:z]
34+
end
35+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ include("autodiff.jl")
6868
include("diff.jl")
6969
include("selection.jl")
7070
include("assignment.jl")
71+
include("gen_fn_interface.jl")
7172
include("dsl/dsl.jl")
7273
include("optional_args.jl")
7374
include("static_ir/static_ir.jl")

0 commit comments

Comments
 (0)