Skip to content

Commit e53fc7c

Browse files
committed
working vectorized blr test
1 parent 4e017d0 commit e53fc7c

File tree

2 files changed

+26
-13
lines changed

2 files changed

+26
-13
lines changed

src/ProbProg.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ end
152152
string(f),
153153
false;
154154
args_in_result=:all,
155+
do_transpose=false, # TODO: double check transpose
155156
argprefix,
156157
resprefix,
157158
resargprefix,

test/probprog/blr.jl

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,26 +3,38 @@ using Reactant: ProbProg
33
using Libdl: Libdl
44

55
normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape)
6+
bernoulli_logit(rng, logit, shape) = rand(rng, shape...) .< (1 ./ (1 .+ exp.(-logit)))
67

7-
function blr(seed, xs)
8-
function model(seed, xs)
8+
function blr(seed, N, K)
9+
function model(seed, N, K)
910
rng = Random.default_rng()
1011
Random.seed!(rng, seed)
11-
slope = ProbProg.sample!(normal, rng, 0, 2, (1,); symbol=:slope)
12-
intercept = ProbProg.sample!(normal, rng, 0, 10, (1,); symbol=:intercept)
13-
for (i, x) in enumerate(xs)
14-
ProbProg.sample!(normal, rng, slope * x + intercept, 1, (1,); symbol=Symbol("y-$i"))
15-
end
16-
return intercept
12+
13+
# α ~ Normal(0, 10, size = 1)
14+
α = ProbProg.sample!(normal, rng, 0, 10, (1,); symbol=)
15+
16+
# β ~ Normal(0, 2.5, size = K)
17+
β = ProbProg.sample!(normal, rng, 0, 2.5, (K,); symbol=)
18+
19+
# X ~ Normal(0, 10, size = (N, K))
20+
X = ProbProg.sample!(normal, rng, 0, 10, (N, K); symbol=:X) # TODO: double check transpose
21+
22+
# μ = α .+ X * β
23+
μ = α .+ X * β
24+
25+
ProbProg.sample!(bernoulli_logit, rng, μ, (N,); symbol=:Y)
26+
27+
return μ
1728
end
1829

19-
return ProbProg.simulate(model, seed, xs)
30+
return ProbProg.simulate!(model, seed, N, K)
2031
end
2132

2233
@testset "BLR" begin
23-
xs = [1, 2, 3, 4, 5]
34+
N = 5 # number of observations
35+
K = 3 # number of features
2436
seed = Reactant.to_rarray(UInt64[1, 4])
25-
X = ProbProg.getTrace(@jit optimize = :probprog blr(seed, xs))
26-
@test X[:_integrity_check] == 0x123456789abcdef
27-
@show X
37+
38+
X = ProbProg.getTrace(@jit optimize = :probprog blr(seed, N, K))
39+
ProbProg.print_trace(X)
2840
end

0 commit comments

Comments
 (0)