@@ -3,26 +3,38 @@ using Reactant: ProbProg
3
3
using Libdl: Libdl
4
4
5
5
normal (rng, μ, σ, shape) = μ .+ σ .* randn (rng, shape)
6
+ bernoulli_logit (rng, logit, shape) = rand (rng, shape... ) .< (1 ./ (1 .+ exp .(- logit)))
6
7
7
- function blr (seed, xs )
8
- function model (seed, xs )
8
+ function blr (seed, N, K )
9
+ function model (seed, N, K )
9
10
rng = Random. default_rng ()
10
11
Random. seed! (rng, seed)
11
- slope = ProbProg. sample! (normal, rng, 0 , 2 , (1 ,); symbol= :slope )
12
- intercept = ProbProg. sample! (normal, rng, 0 , 10 , (1 ,); symbol= :intercept )
13
- for (i, x) in enumerate (xs)
14
- ProbProg. sample! (normal, rng, slope * x + intercept, 1 , (1 ,); symbol= Symbol (" y-$i " ))
15
- end
16
- return intercept
12
+
13
+ # α ~ Normal(0, 10, size = 1)
14
+ α = ProbProg. sample! (normal, rng, 0 , 10 , (1 ,); symbol= :α )
15
+
16
+ # β ~ Normal(0, 2.5, size = K)
17
+ β = ProbProg. sample! (normal, rng, 0 , 2.5 , (K,); symbol= :β )
18
+
19
+ # X ~ Normal(0, 10, size = (N, K))
20
+ X = ProbProg. sample! (normal, rng, 0 , 10 , (N, K); symbol= :X ) # TODO : double check transpose
21
+
22
+ # μ = α .+ X * β
23
+ μ = α .+ X * β
24
+
25
+ ProbProg. sample! (bernoulli_logit, rng, μ, (N,); symbol= :Y )
26
+
27
+ return μ
17
28
end
18
29
19
- return ProbProg. simulate (model, seed, xs )
30
+ return ProbProg. simulate! (model, seed, N, K )
20
31
end
21
32
22
33
@testset " BLR" begin
23
- xs = [1 , 2 , 3 , 4 , 5 ]
34
+ N = 5 # number of observations
35
+ K = 3 # number of features
24
36
seed = Reactant. to_rarray (UInt64[1 , 4 ])
25
- X = ProbProg . getTrace ( @jit optimize = :probprog blr (seed, xs))
26
- @test X[ :_integrity_check ] == 0x123456789abcdef
27
- @show X
37
+
38
+ X = ProbProg . getTrace ( @jit optimize = :probprog blr (seed, N, K))
39
+ ProbProg . print_trace (X)
28
40
end
0 commit comments