|
1 |
| -using Reactant, Test, Random, StableRNGs, Statistics |
| 1 | +using Reactant, Test, Random |
2 | 2 | using Reactant: ProbProg
|
3 |
| -using Libdl: Libdl |
4 | 3 |
|
5 |
| -normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) |
6 |
| -bernoulli_logit(rng, logit, shape) = rand(rng, shape...) .< (1 ./ (1 .+ exp.(-logit))) |
| 4 | +function normal(rng, μ, σ, shape) |
| 5 | + return μ .+ σ .* randn(rng, shape) |
| 6 | +end |
7 | 7 |
|
8 |
| -function blr(seed, N, K) |
9 |
| - function model(seed, N, K) |
10 |
| - rng = Random.default_rng() |
11 |
| - Random.seed!(rng, seed) |
| 8 | +function bernoulli_logit(rng, logit, shape) |
| 9 | + return rand(rng, shape...) .< (1 ./ (1 .+ exp.(-logit))) |
| 10 | +end |
12 | 11 |
|
13 |
| - # α ~ Normal(0, 10, size = 1) |
14 |
| - α = ProbProg.sample!(normal, rng, 0, 10, (1,); symbol=:α) |
| 12 | +function blr(seed, N, K) |
| 13 | + rng = Random.default_rng() |
| 14 | + Random.seed!(rng, seed) |
15 | 15 |
|
16 |
| - # β ~ Normal(0, 2.5, size = K) |
17 |
| - β = ProbProg.sample!(normal, rng, 0, 2.5, (K,); symbol=:β) |
| 16 | + # α ~ Normal(0, 10, size = 1) |
| 17 | + α = ProbProg.sample!(normal, rng, 0, 10, (1,); symbol=:α) |
18 | 18 |
|
19 |
| - # X ~ Normal(0, 10, size = (N, K)) |
20 |
| - X = ProbProg.sample!(normal, rng, 0, 10, (N, K); symbol=:X) # TODO: double check transpose |
| 19 | + # β ~ Normal(0, 2.5, size = K) |
| 20 | + β = ProbProg.sample!(normal, rng, 0, 2.5, (K,); symbol=:β) |
21 | 21 |
|
22 |
| - # μ = α .+ X * β |
23 |
| - μ = α .+ X * β |
| 22 | + # X ~ Normal(0, 10, size = (N, K)) |
| 23 | + X = ProbProg.sample!(normal, rng, 0, 10, (N, K); symbol=:X) |
24 | 24 |
|
25 |
| - ProbProg.sample!(bernoulli_logit, rng, μ, (N,); symbol=:Y) |
| 25 | + # μ = α .+ X * β |
| 26 | + μ = α .+ X * β |
26 | 27 |
|
27 |
| - return μ |
28 |
| - end |
| 28 | + Y = ProbProg.sample!(bernoulli_logit, rng, μ, (N,); symbol=:Y) |
29 | 29 |
|
30 |
| - return ProbProg.simulate!(model, seed, N, K) |
| 30 | + return Y |
31 | 31 | end
|
32 | 32 |
|
33 | 33 | @testset "BLR" begin
|
34 | 34 | N = 5 # number of observations
|
35 | 35 | K = 3 # number of features
|
36 | 36 | seed = Reactant.to_rarray(UInt64[1, 4])
|
37 | 37 |
|
38 |
| - X = ProbProg.getTrace(@jit optimize = :probprog blr(seed, N, K)) |
39 |
| - ProbProg.print_trace(X) |
| 38 | + trace = ProbProg.create_trace() |
| 39 | + |
| 40 | + @test size( |
| 41 | + Array(@jit optimize = :probprog ProbProg.simulate!(blr, seed, N, K; trace)) |
| 42 | + ) == (N,) |
| 43 | + |
| 44 | + ProbProg.print_trace(trace) |
40 | 45 | end
|
0 commit comments