Skip to content

Commit b92a733

Browse files
committed
compiled generate test
1 parent f4a6415 commit b92a733

File tree

1 file changed

+34
-0
lines changed

1 file changed

+34
-0
lines changed

test/probprog/generate.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,38 @@ end
3939
normal_logpdf(trace.choices[:t][1], constraint[:s][1], 1.0, shape)
4040
@test weight expected_weight atol = 1e-6
4141
end
42+
43+
@testset "compiled" begin
44+
shape = (10,)
45+
seed = Reactant.to_rarray(UInt64[1, 4])
46+
μ = Reactant.ConcreteRNumber(0.0)
47+
σ = Reactant.ConcreteRNumber(1.0)
48+
49+
constraint1 = Dict{Symbol,Any}(:s => (fill(0.1, shape),))
50+
51+
constrained_symbols = collect(keys(constraint1)) # This doesn't change
52+
53+
constraint_ptr1 = Reactant.ConcreteRNumber(
54+
reinterpret(UInt64, pointer_from_objref(constraint1))
55+
)
56+
57+
wrapper_fn(constraint_ptr, seed, μ, σ) = ProbProg.generate_internal(
58+
model, seed, μ, σ, shape; constraint_ptr, constrained_symbols
59+
)
60+
61+
compiled_fn = @compile optimize = :probprog wrapper_fn(constraint_ptr1, seed, μ, σ)
62+
63+
trace1, weight = compiled_fn(constraint_ptr1, seed, μ, σ)
64+
trace1 = unsafe_pointer_to_objref(Ptr{Any}(Array(trace1)[1]))
65+
66+
constraint2 = Dict{Symbol,Any}(:s => (fill(0.2, shape),))
67+
constraint_ptr2 = Reactant.ConcreteRNumber(
68+
reinterpret(UInt64, pointer_from_objref(constraint2))
69+
)
70+
71+
trace2, _ = compiled_fn(constraint_ptr2, seed, μ, σ)
72+
trace2 = unsafe_pointer_to_objref(Ptr{Any}(Array(trace2)[1]))
73+
74+
@test trace1.choices[:s] != trace2.choices[:s]
75+
end
4276
end

0 commit comments

Comments
 (0)