Skip to content

Commit 91a0850

Browse files
committed
API change
1 parent d707053 commit 91a0850

File tree

6 files changed

+106
-109
lines changed

6 files changed

+106
-109
lines changed

src/ProbProg.jl

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,18 @@
11
module ProbProg
22

3-
using ..Reactant: MLIR, TracedUtils, AbstractConcreteArray
3+
using ..Reactant: MLIR, TracedUtils, AbstractConcreteArray, AbstractConcreteNumber
4+
using ..Compiler: @jit
45
using Enzyme
56

7+
mutable struct ProbProgTrace
8+
choices::Dict{Symbol,Any}
9+
retval::Any
10+
11+
function ProbProgTrace()
12+
return new(Dict{Symbol,Any}(), nothing)
13+
end
14+
end
15+
616
function addSampleToTraceLowered(
717
trace_ptr_ptr::Ptr{Ptr{Any}},
818
symbol_ptr_ptr::Ptr{Ptr{Any}},
@@ -31,9 +41,9 @@ function addSampleToTraceLowered(
3141

3242
typed_ptr = Ptr{julia_type}(sample_ptr)
3343
if num_dims == 0
34-
trace[symbol] = unsafe_load(typed_ptr)
44+
trace.choices[symbol] = unsafe_load(typed_ptr)
3545
else
36-
trace[symbol] = copy(unsafe_wrap(Array, typed_ptr, Tuple(shape_array)))
46+
trace.choices[symbol] = copy(unsafe_wrap(Array, typed_ptr, Tuple(shape_array)))
3747
end
3848

3949
return nothing
@@ -52,7 +62,7 @@ function __init__()
5262
return nothing
5363
end
5464

55-
@noinline function sample!(
65+
function sample(
5666
f::Function, args::Vararg{Any,Nargs}; symbol::Symbol=gensym("sample")
5767
) where {Nargs}
5868
argprefix::Symbol = gensym("samplearg")
@@ -132,7 +142,12 @@ end
132142
return result
133143
end
134144

135-
@noinline function generate!(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
145+
function generate(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
146+
res = @jit optimize = :probprog generate_internal(f, args...)
147+
return res isa AbstractConcreteArray ? Array(res) : res
148+
end
149+
150+
function generate_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
136151
argprefix::Symbol = gensym("generatearg")
137152
resprefix::Symbol = gensym("generateresult")
138153
resargprefix::Symbol = gensym("generateresarg")
@@ -196,8 +211,18 @@ end
196211
return result
197212
end
198213

199-
@noinline function simulate!(
200-
f::Function, args::Vararg{Any,Nargs}; trace::Dict{Symbol,Any}
214+
function simulate(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
215+
trace = ProbProgTrace()
216+
217+
res = @jit optimize = :probprog sync = true simulate_internal(f, args...; trace)
218+
219+
trace.retval = res isa AbstractConcreteArray ? Array(res) : res
220+
221+
return trace
222+
end
223+
224+
function simulate_internal(
225+
f::Function, args::Vararg{Any,Nargs}; trace::ProbProgTrace
201226
) where {Nargs}
202227
argprefix::Symbol = gensym("simulatearg")
203228
resprefix::Symbol = gensym("simulateresult")
@@ -266,16 +291,5 @@ end
266291
return result
267292
end
268293

269-
function create_trace()
270-
return Dict{Symbol,Any}()
271-
end
272294

273-
function print_trace(trace::Dict{Symbol,Any})
274-
println("### Probabilistic Program Trace ###")
275-
for (symbol, sample) in trace
276-
println(" $symbol:")
277-
println(" Sample: $(sample)")
278-
end
279-
return println("### End of Trace ###")
280-
end
281295
end

src/Reactant.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,6 @@ include("stdlibs/Base.jl")
176176

177177
# Other Integrations
178178
include("Enzyme.jl")
179-
include("ProbProg.jl")
180179

181180
const TracedType = Union{TracedRArray,TracedRNumber,MissingTracedValue}
182181

@@ -189,6 +188,7 @@ export OptimizeCommunicationOptions
189188
include("Compiler.jl")
190189

191190
include("Overlay.jl")
191+
include("ProbProg.jl")
192192

193193
using .Compiler: @compile, @code_hlo, @code_mhlo, @jit, @code_xla, traced_getfield, compile
194194
export ConcreteRArray,

test/probprog/blr.jl

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,18 @@ function blr(seed, N, K)
1414
Random.seed!(rng, seed)
1515

1616
# α ~ Normal(0, 10, size = 1)
17-
α = ProbProg.sample!(normal, rng, 0, 10, (1,); symbol=)
17+
α = ProbProg.sample(normal, rng, 0, 10, (1,); symbol=)
1818

1919
# β ~ Normal(0, 2.5, size = K)
20-
β = ProbProg.sample!(normal, rng, 0, 2.5, (K,); symbol=)
20+
β = ProbProg.sample(normal, rng, 0, 2.5, (K,); symbol=)
2121

2222
# X ~ Normal(0, 10, size = (N, K))
23-
X = ProbProg.sample!(normal, rng, 0, 10, (N, K); symbol=:X)
23+
X = ProbProg.sample(normal, rng, 0, 10, (N, K); symbol=:X)
2424

2525
# μ = α .+ X * β
2626
μ = α .+ X * β
2727

28-
Y = ProbProg.sample!(bernoulli_logit, rng, μ, (N,); symbol=:Y)
28+
Y = ProbProg.sample(bernoulli_logit, rng, μ, (N,); symbol=:Y)
2929

3030
return Y
3131
end
@@ -35,11 +35,9 @@ end
3535
K = 3 # number of features
3636
seed = Reactant.to_rarray(UInt64[1, 4])
3737

38-
trace = ProbProg.create_trace()
38+
trace = ProbProg.simulate(blr, seed, N, K)
3939

40-
@test size(
41-
Array(@jit optimize = :probprog ProbProg.simulate!(blr, seed, N, K; trace))
42-
) == (N,)
40+
@test size(Array(trace.retval)) == (N,)
4341

44-
ProbProg.print_trace(trace)
42+
println(trace)
4543
end

test/probprog/generate.jl

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,14 @@
1-
using Reactant, Test, Random, StableRNGs, Statistics
1+
using Reactant, Test, Random, Statistics
22
using Reactant: ProbProg
33

44
normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape)
55

6-
function generate_model(seed, μ, σ, shape)
7-
function model(seed, μ, σ, shape)
8-
rng = Random.default_rng()
9-
Random.seed!(rng, seed)
10-
s = ProbProg.sample!(normal, rng, μ, σ, shape)
11-
t = ProbProg.sample!(normal, rng, s, σ, shape)
12-
return t
13-
end
14-
15-
return ProbProg.generate!(model, seed, μ, σ, shape)
6+
function model(seed, μ, σ, shape)
7+
rng = Random.default_rng()
8+
Random.seed!(rng, seed)
9+
s = ProbProg.sample(normal, rng, μ, σ, shape)
10+
t = ProbProg.sample(normal, rng, s, σ, shape)
11+
return t
1612
end
1713

1814
@testset "Generate" begin
@@ -25,6 +21,9 @@ end
2521
σ1 = Reactant.ConcreteRNumber(1.0)
2622
σ2 = Reactant.ConcreteRNumber(1.0)
2723

24+
generate_model(seed, μ, σ, shape) =
25+
ProbProg.generate_internal(model, seed, μ, σ, shape)
26+
2827
model_compiled = @compile optimize = :probprog generate_model(seed1, μ1, σ1, shape)
2928

3029
@test Array(model_compiled(seed1, μ1, σ1, shape))
@@ -44,11 +43,15 @@ end
4443
μ = Reactant.ConcreteRNumber(0.0)
4544
σ = Reactant.ConcreteRNumber(1.0)
4645

47-
before = @code_hlo optimize = :no_enzyme generate_model(seed, μ, σ, shape)
46+
before = @code_hlo optimize = :no_enzyme ProbProg.generate_internal(
47+
model, seed, μ, σ, shape
48+
)
4849
@test contains(repr(before), "enzyme.generate")
4950
@test contains(repr(before), "enzyme.sample")
5051

51-
after = @code_hlo optimize = :probprog generate_model(seed, μ, σ, shape)
52+
after = @code_hlo optimize = :probprog ProbProg.generate_internal(
53+
model, seed, μ, σ, shape
54+
)
5255
@test !contains(repr(after), "enzyme.generate")
5356
@test !contains(repr(after), "enzyme.sample")
5457
end
@@ -58,23 +61,22 @@ end
5861
seed = Reactant.to_rarray(UInt64[1, 4])
5962
μ = Reactant.ConcreteRNumber(0.0)
6063
σ = Reactant.ConcreteRNumber(1.0)
61-
X = Array(@jit optimize = :probprog generate_model(seed, μ, σ, shape))
64+
X = ProbProg.generate(model, seed, μ, σ, shape)
6265
@test mean(X) 0.0 atol = 0.05 rtol = 0.05
6366
end
6467

6568
@testset "correctness" begin
6669
op(x, y) = x * y'
6770

6871
function fake_model(x, y)
69-
return ProbProg.sample!(op, x, y)
72+
return ProbProg.sample(op, x, y)
7073
end
7174

7275
x = reshape(collect(Float64, 1:12), (4, 3))
7376
y = reshape(collect(Float64, 1:12), (4, 3))
7477
x_ra = Reactant.to_rarray(x)
7578
y_ra = Reactant.to_rarray(y)
7679

77-
@test Array(@jit optimize = :probprog ProbProg.generate!(fake_model, x_ra, y_ra)) ==
78-
op(x, y)
80+
@test ProbProg.generate(fake_model, x_ra, y_ra) == op(x, y)
7981
end
8082
end

test/probprog/sample.jl

Lines changed: 18 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,21 @@
1-
using Reactant, Test, Random, StableRNGs, Statistics
1+
using Reactant, Test, Random
22
using Reactant: ProbProg
33

4-
@noinline normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape)
4+
normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape)
55

6-
function sample1(seed, μ, σ, shape)
7-
function model(seed, μ, σ, shape)
8-
rng = Random.default_rng()
9-
Random.seed!(rng, seed)
10-
s = ProbProg.sample!(normal, rng, μ, σ, shape)
11-
return s
12-
end
13-
14-
return ProbProg.generate!(model, seed, μ, σ, shape)
6+
function one_sample(seed, μ, σ, shape)
7+
rng = Random.default_rng()
8+
Random.seed!(rng, seed)
9+
s = ProbProg.sample(normal, rng, μ, σ, shape)
10+
return s
1511
end
1612

17-
function sample2(seed, μ, σ, shape)
18-
function model(seed, μ, σ, shape)
19-
rng = Random.default_rng()
20-
Random.seed!(rng, seed)
21-
_ = ProbProg.sample!(normal, rng, μ, σ, shape)
22-
t = ProbProg.sample!(normal, rng, μ, σ, shape)
23-
return t
24-
end
25-
26-
return ProbProg.generate!(model, seed, μ, σ, shape)
13+
function two_samples(seed, μ, σ, shape)
14+
rng = Random.default_rng()
15+
Random.seed!(rng, seed)
16+
_ = ProbProg.sample(normal, rng, μ, σ, shape)
17+
t = ProbProg.sample(normal, rng, μ, σ, shape)
18+
return t
2719
end
2820

2921
@testset "test" begin
@@ -32,19 +24,19 @@ end
3224
seed = Reactant.to_rarray(UInt64[1, 4])
3325
μ = Reactant.ConcreteRNumber(0.0)
3426
σ = Reactant.ConcreteRNumber(1.0)
35-
before = @code_hlo optimize = false sample2(seed, μ, σ, shape)
27+
before = @code_hlo optimize = false ProbProg.generate_internal(one_sample, seed, μ, σ, shape)
3628
@test contains(repr(before), "enzyme.sample")
37-
after = @code_hlo optimize = :probprog sample2(seed, μ, σ, shape)
29+
after = @code_hlo optimize = :probprog ProbProg.generate_internal(two_samples, seed, μ, σ, shape)
3830
@test !contains(repr(after), "enzyme.sample")
3931
end
4032

41-
@testset "sample_normal" begin
33+
@testset "rng_state" begin
4234
shape = (10,)
4335
seed = Reactant.to_rarray(UInt64[1, 4])
4436
μ = Reactant.ConcreteRNumber(0.0)
4537
σ = Reactant.ConcreteRNumber(1.0)
46-
X = Array(@jit optimize = :probprog sample1(seed, μ, σ, shape))
47-
Y = Array(@jit optimize = :probprog sample2(seed, μ, σ, shape))
38+
X = ProbProg.generate(one_sample, seed, μ, σ, shape)
39+
Y = ProbProg.generate(two_samples, seed, μ, σ, shape)
4840
@test !all(X .≈ Y)
4941
end
5042
end

test/probprog/simulate.jl

Lines changed: 29 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,32 @@
1-
using Reactant, Test, Random, StableRNGs, Statistics
1+
using Reactant, Test, Random
22
using Reactant: ProbProg
33

4-
@testset "Simulate" begin
5-
normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape)
4+
normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape)
65

7-
function simulate_model(trace, seed, μ, σ, shape)
8-
function model(seed, μ, σ, shape)
9-
rng = Random.default_rng()
10-
Random.seed!(rng, seed)
11-
s = ProbProg.sample!(normal, rng, μ, σ, shape; symbol=:s)
12-
t = ProbProg.sample!(normal, rng, s, σ, shape; symbol=:t)
13-
return t
14-
end
6+
function model(seed, μ, σ, shape)
7+
rng = Random.default_rng()
8+
Random.seed!(rng, seed)
9+
s = ProbProg.sample(normal, rng, μ, σ, shape; symbol=:s)
10+
t = ProbProg.sample(normal, rng, s, σ, shape; symbol=:t)
11+
return t
12+
end
1513

16-
result = ProbProg.simulate!(model, seed, μ, σ, shape; trace)
17-
return result
18-
end
19-
@testset "normal_hlo" begin
20-
shape = (10000,)
14+
@testset "Simulate" begin
15+
@testset "simulate_hlo" begin
16+
shape = (3, 3, 3)
2117
seed = Reactant.to_rarray(UInt64[1, 4])
2218
μ = Reactant.ConcreteRNumber(0.0)
2319
σ = Reactant.ConcreteRNumber(1.0)
2420

25-
trace = ProbProg.create_trace()
26-
27-
before = @code_hlo optimize = :no_enzyme simulate_model(trace, seed, μ, σ, shape)
21+
before = @code_hlo optimize = false ProbProg.simulate_internal(
22+
model, seed, μ, σ, shape; trace = ProbProg.ProbProgTrace()
23+
)
2824
@test contains(repr(before), "enzyme.simulate")
29-
@test contains(repr(before), "enzyme.sample")
3025

31-
after = @code_hlo optimize = :probprog simulate_model(trace, seed, μ, σ, shape)
26+
after = @code_hlo optimize = :probprog ProbProg.simulate_internal(
27+
model, seed, μ, σ, shape; trace = ProbProg.ProbProgTrace()
28+
)
3229
@test !contains(repr(after), "enzyme.simulate")
33-
@test !contains(repr(after), "enzyme.sample")
34-
@test contains(repr(after), "enzyme_probprog_add_sample_to_trace")
3530
end
3631

3732
@testset "normal_simulate" begin
@@ -40,34 +35,30 @@ using Reactant: ProbProg
4035
μ = Reactant.ConcreteRNumber(0.0)
4136
σ = Reactant.ConcreteRNumber(1.0)
4237

43-
trace = ProbProg.create_trace()
44-
45-
result = Array(@jit optimize = :probprog simulate_model(trace, seed, μ, σ, shape))
38+
trace = ProbProg.simulate(model, seed, μ, σ, shape)
4639

47-
@test size(result) == shape
48-
@test haskey(trace, :s)
49-
@test haskey(trace, :t)
50-
@test size(trace[:s]) == shape
51-
@test size(trace[:t]) == shape
40+
@test size(trace.retval) == shape
41+
@test haskey(trace.choices, :s)
42+
@test haskey(trace.choices, :t)
43+
@test size(trace.choices[:s]) == shape
44+
@test size(trace.choices[:t]) == shape
5245
end
5346

5447
@testset "correctness" begin
5548
op(x, y) = x * y'
5649
function fake_model(x, y)
57-
return ProbProg.sample!(op, x, y; symbol=:matmul)
50+
return ProbProg.sample(op, x, y; symbol=:matmul)
5851
end
5952

60-
trace = ProbProg.create_trace()
6153
x = reshape(collect(Float64, 1:12), (4, 3))
6254
y = reshape(collect(Float64, 1:12), (4, 3))
6355
x_ra = Reactant.to_rarray(x)
6456
y_ra = Reactant.to_rarray(y)
6557

66-
@test Array(
67-
@jit optimize = :probprog ProbProg.simulate!(fake_model, x_ra, y_ra; trace)
68-
) == op(x, y)
58+
trace = ProbProg.simulate(fake_model, x_ra, y_ra)
6959

70-
@test haskey(trace, :matmul)
71-
@test trace[:matmul] == op(x, y)
60+
@test Array(trace.retval) == op(x, y)
61+
@test haskey(trace.choices, :matmul)
62+
@test trace.choices[:matmul] == op(x, y)
7263
end
7364
end

0 commit comments

Comments
 (0)