Skip to content

Commit 1c5297c

Browse files
committed
minor changes
1 parent 46e0f6b commit 1c5297c

File tree

3 files changed

+36
-29
lines changed

3 files changed

+36
-29
lines changed

src/ProbProg.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,6 @@ module ProbProg
33
using ..Reactant: MLIR, TracedUtils, AbstractConcreteArray
44
using Enzyme
55

6-
function createTrace()
7-
return Dict{Symbol,Any}()
8-
end
9-
106
function addSampleToTraceLowered(
117
trace_ptr_ptr::Ptr{Ptr{Any}},
128
symbol_ptr_ptr::Ptr{Ptr{Any}},
@@ -26,6 +22,8 @@ function addSampleToTraceLowered(
2622
Float32
2723
elseif datatype_width == 64
2824
Float64
25+
elseif datatype_width == 1
26+
Bool
2927
else
3028
@ccall printf("Unsupported datatype width: %d\n"::Cstring, datatype_width::Cint)::Cvoid
3129
return nothing
@@ -268,6 +266,10 @@ end
268266
return result
269267
end
270268

269+
function create_trace()
270+
return Dict{Symbol,Any}()
271+
end
272+
271273
function print_trace(trace::Dict{Symbol,Any})
272274
println("### Probabilistic Program Trace ###")
273275
for (symbol, sample) in trace

test/probprog/blr.jl

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,45 @@
1-
using Reactant, Test, Random, StableRNGs, Statistics
1+
using Reactant, Test, Random
22
using Reactant: ProbProg
3-
using Libdl: Libdl
43

5-
normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape)
6-
bernoulli_logit(rng, logit, shape) = rand(rng, shape...) .< (1 ./ (1 .+ exp.(-logit)))
4+
function normal(rng, μ, σ, shape)
5+
return μ .+ σ .* randn(rng, shape)
6+
end
77

8-
function blr(seed, N, K)
9-
function model(seed, N, K)
10-
rng = Random.default_rng()
11-
Random.seed!(rng, seed)
8+
function bernoulli_logit(rng, logit, shape)
9+
return rand(rng, shape...) .< (1 ./ (1 .+ exp.(-logit)))
10+
end
1211

13-
# α ~ Normal(0, 10, size = 1)
14-
α = ProbProg.sample!(normal, rng, 0, 10, (1,); symbol=)
12+
function blr(seed, N, K)
13+
rng = Random.default_rng()
14+
Random.seed!(rng, seed)
1515

16-
# β ~ Normal(0, 2.5, size = K)
17-
β = ProbProg.sample!(normal, rng, 0, 2.5, (K,); symbol=:β)
16+
# α ~ Normal(0, 10, size = 1)
17+
α = ProbProg.sample!(normal, rng, 0, 10, (1,); symbol=:α)
1818

19-
# X ~ Normal(0, 10, size = (N, K))
20-
X = ProbProg.sample!(normal, rng, 0, 10, (N, K); symbol=:X) # TODO: double check transpose
19+
# β ~ Normal(0, 2.5, size = K)
20+
β = ProbProg.sample!(normal, rng, 0, 2.5, (K,); symbol=:β)
2121

22-
# μ = α .+ X * β
23-
μ = α .+ X * β
22+
# X ~ Normal(0, 10, size = (N, K))
23+
X = ProbProg.sample!(normal, rng, 0, 10, (N, K); symbol=:X)
2424

25-
ProbProg.sample!(bernoulli_logit, rng, μ, (N,); symbol=:Y)
25+
# μ = α .+ X * β
26+
μ = α .+ X * β
2627

27-
return μ
28-
end
28+
Y = ProbProg.sample!(bernoulli_logit, rng, μ, (N,); symbol=:Y)
2929

30-
return ProbProg.simulate!(model, seed, N, K)
30+
return Y
3131
end
3232

3333
@testset "BLR" begin
3434
N = 5 # number of observations
3535
K = 3 # number of features
3636
seed = Reactant.to_rarray(UInt64[1, 4])
3737

38-
X = ProbProg.getTrace(@jit optimize = :probprog blr(seed, N, K))
39-
ProbProg.print_trace(X)
38+
trace = ProbProg.create_trace()
39+
40+
@test size(
41+
Array(@jit optimize = :probprog ProbProg.simulate!(blr, seed, N, K; trace))
42+
) == (N,)
43+
44+
ProbProg.print_trace(trace)
4045
end

test/probprog/simulate.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ using Reactant: ProbProg
2222
μ = Reactant.ConcreteRNumber(0.0)
2323
σ = Reactant.ConcreteRNumber(1.0)
2424

25-
trace = ProbProg.createTrace()
25+
trace = ProbProg.create_trace()
2626

2727
before = @code_hlo optimize = :no_enzyme simulate_model(trace, seed, μ, σ, shape)
2828
@test contains(repr(before), "enzyme.simulate")
@@ -40,7 +40,7 @@ using Reactant: ProbProg
4040
μ = Reactant.ConcreteRNumber(0.0)
4141
σ = Reactant.ConcreteRNumber(1.0)
4242

43-
trace = ProbProg.createTrace()
43+
trace = ProbProg.create_trace()
4444

4545
result = Array(@jit optimize = :probprog simulate_model(trace, seed, μ, σ, shape))
4646

@@ -57,7 +57,7 @@ using Reactant: ProbProg
5757
return ProbProg.sample!(op, x, y; symbol=:matmul)
5858
end
5959

60-
trace = ProbProg.createTrace()
60+
trace = ProbProg.create_trace()
6161
x = reshape(collect(Float64, 1:12), (4, 3))
6262
y = reshape(collect(Float64, 1:12), (4, 3))
6363
x_ra = Reactant.to_rarray(x)

0 commit comments

Comments
 (0)