|
39 | 39 | normal_logpdf(trace.choices[:t][1], constraint[:s][1], 1.0, shape)
|
40 | 40 | @test weight ≈ expected_weight atol = 1e-6
|
41 | 41 | end
|
| 42 | + |
| 43 | + @testset "compiled" begin |
| 44 | + shape = (10,) |
| 45 | + seed = Reactant.to_rarray(UInt64[1, 4]) |
| 46 | + μ = Reactant.ConcreteRNumber(0.0) |
| 47 | + σ = Reactant.ConcreteRNumber(1.0) |
| 48 | + |
| 49 | + constraint1 = Dict{Symbol,Any}(:s => (fill(0.1, shape),)) |
| 50 | + |
| 51 | + constrained_symbols = collect(keys(constraint1)) # This doesn't change |
| 52 | + |
| 53 | + constraint_ptr1 = Reactant.ConcreteRNumber( |
| 54 | + reinterpret(UInt64, pointer_from_objref(constraint1)) |
| 55 | + ) |
| 56 | + |
| 57 | + wrapper_fn(constraint_ptr, seed, μ, σ) = ProbProg.generate_internal( |
| 58 | + model, seed, μ, σ, shape; constraint_ptr, constrained_symbols |
| 59 | + ) |
| 60 | + |
| 61 | + compiled_fn = @compile optimize = :probprog wrapper_fn(constraint_ptr1, seed, μ, σ) |
| 62 | + |
| 63 | + trace1, weight = compiled_fn(constraint_ptr1, seed, μ, σ) |
| 64 | + trace1 = unsafe_pointer_to_objref(Ptr{Any}(Array(trace1)[1])) |
| 65 | + |
| 66 | + constraint2 = Dict{Symbol,Any}(:s => (fill(0.2, shape),)) |
| 67 | + constraint_ptr2 = Reactant.ConcreteRNumber( |
| 68 | + reinterpret(UInt64, pointer_from_objref(constraint2)) |
| 69 | + ) |
| 70 | + |
| 71 | + trace2, _ = compiled_fn(constraint_ptr2, seed, μ, σ) |
| 72 | + trace2 = unsafe_pointer_to_objref(Ptr{Any}(Array(trace2)[1])) |
| 73 | + |
| 74 | + @test trace1.choices[:s] != trace2.choices[:s] |
| 75 | + end |
42 | 76 | end
|
0 commit comments