Skip to content

Commit 65d3595

Browse files
committed
fix up tests
1 parent b2d583a commit 65d3595

File tree

4 files changed

+32
-19
lines changed

4 files changed

+32
-19
lines changed

test/probprog/blr.jl

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,32 @@
11
using Reactant, Test, Random
2-
using Reactant: ProbProg
2+
using Reactant: ProbProg, ReactantRNG
33

4-
function normal(rng, μ, σ, shape)
5-
return μ .+ σ .* randn(rng, shape)
6-
end
4+
normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape)
75

8-
function bernoulli_logit(rng, logit, shape)
9-
return rand(rng, shape...) .< (1 ./ (1 .+ exp.(-logit)))
6+
function normal_logpdf(x, μ, σ, _)
7+
return -sum(log.(σ)) - length(x) / 2 * log(2π) - sum((x .- μ) .^ 2 ./ (2 .*.^ 2)))
108
end
119

12-
function blr(seed, N, K)
13-
rng = Random.default_rng()
14-
Random.seed!(rng, seed)
10+
bernoulli_logit(rng, logit, shape) = rand(rng, shape...) .< (1 ./ (1 .+ exp.(-logit)))
11+
bernoulli_logit_logpdf(x, logit, _) = sum(x .* logit .- log1p.(exp.(logit)))
1512

13+
# https://github.com/facebookresearch/pplbench/blob/main/pplbench/models/logistic_regression.py
14+
function blr(rng, N, K)
1615
# α ~ Normal(0, 10, size = 1)
17-
α = ProbProg.sample(normal, rng, 0, 10, (1,); symbol=)
16+
α = ProbProg.sample(rng, normal, 0, 10, (1,); symbol=, logpdf=normal_logpdf)
1817

1918
# β ~ Normal(0, 2.5, size = K)
20-
β = ProbProg.sample(normal, rng, 0, 2.5, (K,); symbol=)
19+
β = ProbProg.sample(rng, normal, 0, 2.5, (K,); symbol=, logpdf=normal_logpdf)
2120

2221
# X ~ Normal(0, 10, size = (N, K))
23-
X = ProbProg.sample(normal, rng, 0, 10, (N, K); symbol=:X)
22+
X = ProbProg.sample(rng, normal, 0, 10, (N, K); symbol=:X, logpdf=normal_logpdf)
2423

2524
# μ = α .+ X * β
2625
μ = α .+ X * β
2726

28-
Y = ProbProg.sample(bernoulli_logit, rng, μ, (N,); symbol=:Y)
27+
Y = ProbProg.sample(
28+
rng, bernoulli_logit, μ, (N,); symbol=:Y, logpdf=bernoulli_logit_logpdf
29+
)
2930

3031
return Y
3132
end
@@ -35,7 +36,10 @@ end
3536
K = 3 # number of features
3637
seed = Reactant.to_rarray(UInt64[1, 4])
3738

38-
trace = ProbProg.simulate(blr, seed, N, K)
39+
rng = ReactantRNG(seed)
40+
41+
trace, _ = ProbProg.simulate(rng, blr, N, K)
42+
println(trace)
3943

40-
@test size(Array(trace.retval)) == (N,)
44+
@test size(trace.retval[1]) == (N,)
4145
end

test/probprog/generate.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@ using Reactant, Test, Random, Statistics
22
using Reactant: ProbProg, ReactantRNG
33

44
normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape)
5-
normal_logpdf(x, μ, σ, _) = -sum(log.(σ)) - sum((μ .- x) .^ 2) / (2 * σ^2)
5+
6+
function normal_logpdf(x, μ, σ, _)
7+
return -sum(log.(σ)) - length(x) / 2 * log(2π) - sum((x .- μ) .^ 2 ./ (2 .*.^ 2)))
8+
end
69

710
function model(rng, μ, σ, shape)
811
s = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:s, logpdf=normal_logpdf)
@@ -49,7 +52,7 @@ end
4952

5053
constraint1 = Dict{Symbol,Any}(:s => (fill(0.1, shape),))
5154

52-
constrained_symbols = collect(keys(constraint1)) # This doesn't change
55+
constrained_symbols = Set(keys(constraint1))
5356

5457
constraint_ptr1 = Reactant.ConcreteRNumber(
5558
reinterpret(UInt64, pointer_from_objref(constraint1))

test/probprog/linear_regression.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@ using Reactant: ProbProg, ReactantRNG
44
# Reference: https://www.gen.dev/docs/stable/getting_started/linear_regression/
55

66
normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape)
7-
normal_logpdf(x, μ, σ, _) = -sum(log.(σ)) - sum((μ .- x) .^ 2) / (2 * σ^2)
7+
8+
function normal_logpdf(x, μ, σ, _)
9+
return -sum(log.(σ)) - length(x) / 2 * log(2π) - sum((x .- μ) .^ 2 ./ (2 .*.^ 2)))
10+
end
811

912
function my_model(rng, xs)
1013
slope = ProbProg.sample(

test/probprog/simulate.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@ using Reactant, Test, Random
22
using Reactant: ProbProg, ReactantRNG
33

44
normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape)
5-
normal_logpdf(x, μ, σ, _) = -sum(log.(σ)) - sum((μ .- x) .^ 2) / (2 * σ^2)
5+
6+
function normal_logpdf(x, μ, σ, _)
7+
return -sum(log.(σ)) - length(x) / 2 * log(2π) - sum((x .- μ) .^ 2 ./ (2 .*.^ 2)))
8+
end
69

710
function product_two_normals(rng, μ, σ, shape)
811
a = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:a, logpdf=normal_logpdf)

0 commit comments

Comments
 (0)