Skip to content

Commit 99d7608

Browse files
committed
unconstrained real generate op
1 parent 561b051 commit 99d7608

File tree

2 files changed

+105
-61
lines changed

2 files changed

+105
-61
lines changed

src/ProbProg.jl

Lines changed: 96 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,22 @@
11
module ProbProg
22

3-
using ..Reactant: MLIR, TracedUtils, AbstractConcreteArray, AbstractConcreteNumber
3+
using ..Reactant:
4+
MLIR,
5+
TracedUtils,
6+
AbstractConcreteArray,
7+
AbstractConcreteNumber,
8+
AbstractRNG,
9+
TracedRArray
410
using ..Compiler: @jit
511
using Enzyme
612

713
mutable struct ProbProgTrace
814
choices::Dict{Symbol,Any}
915
retval::Any
16+
weight::Any
1017

1118
function ProbProgTrace()
12-
return new(Dict{Symbol,Any}(), nothing)
19+
return new(Dict{Symbol,Any}(), nothing, nothing)
1320
end
1421
end
1522

@@ -63,7 +70,10 @@ function __init__()
6370
end
6471

6572
function sample(
66-
f::Function, args::Vararg{Any,Nargs}; symbol::Symbol=gensym("sample")
73+
f::Function,
74+
args::Vararg{Any,Nargs};
75+
symbol::Symbol=gensym("sample"),
76+
logpdf::Union{Nothing,Function}=nothing,
6777
) where {Nargs}
6878
argprefix::Symbol = gensym("samplearg")
6979
resprefix::Symbol = gensym("sampleresult")
@@ -102,20 +112,68 @@ function sample(
102112
sym = TracedUtils.get_attribute_by_name(func2, "sym_name")
103113
fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(sym))
104114

115+
# Specify which outputs to add to the trace.
105116
traced_output_indices = Int[]
106117
for (i, res) in enumerate(linear_results)
107118
if TracedUtils.has_idx(res, resprefix)
108119
push!(traced_output_indices, i - 1)
109120
end
110121
end
111122

123+
# Specify which inputs to pass to logpdf.
124+
traced_input_indices = Int[]
125+
for (i, a) in enumerate(linear_args)
126+
idx, _ = TracedUtils.get_argidx(a, argprefix)
127+
if fnwrap && idx == 1 # TODO: add test for fnwrap
128+
continue
129+
end
130+
131+
if fnwrap
132+
idx -= 1
133+
end
134+
135+
if !(args[idx] isa AbstractRNG)
136+
push!(traced_input_indices, i - 1)
137+
end
138+
end
139+
112140
symbol_addr = reinterpret(UInt64, pointer_from_objref(symbol))
113141

142+
# Construct MLIR attribute if Julia logpdf function is provided.
143+
logpdf_attr = nothing
144+
if logpdf !== nothing
145+
# Just to get static information about the sample. TODO: kwargs?
146+
example_sample = f(args...)
147+
148+
# Remove AbstractRNG from `f`'s argument list if present, assuming that
149+
# logpdf parameters follows `(sample, args...)` convention.
150+
logpdf_args = (example_sample,)
151+
if !isempty(args) && args[1] isa AbstractRNG
152+
logpdf_args = (example_sample, Base.tail(args)...) # TODO: kwargs?
153+
end
154+
155+
logpdf_mlir = invokelatest(
156+
TracedUtils.make_mlir_fn,
157+
logpdf,
158+
logpdf_args,
159+
(),
160+
string(logpdf),
161+
false;
162+
do_transpose=false,
163+
args_in_result=:all,
164+
)
165+
166+
logpdf_sym = TracedUtils.get_attribute_by_name(logpdf_mlir.f, "sym_name")
167+
logpdf_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(logpdf_sym))
168+
end
169+
114170
sample_op = MLIR.Dialects.enzyme.sample(
115171
batch_inputs;
116172
outputs=out_tys,
117173
fn=fn_attr,
174+
logpdf=logpdf_attr,
118175
symbol=symbol_addr,
176+
traced_input_indices=traced_input_indices,
119177
traced_output_indices=traced_output_indices,
120178
)
121179

@@ -143,11 +201,19 @@ function sample(
143201
end
144202

145203
function generate(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
146-
res = @jit optimize = :probprog generate_internal(f, args...)
147-
return res isa AbstractConcreteArray ? Array(res) : res
204+
trace = ProbProgTrace()
205+
206+
weight, res = @jit optimize = :probprog generate_internal(f, args...; trace)
207+
208+
trace.retval = res isa AbstractConcreteArray ? Array(res) : res
209+
trace.weight = Array(weight)[1]
210+
211+
return trace, trace.weight
148212
end
149213

150-
function generate_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
214+
function generate_internal(
215+
f::Function, args::Vararg{Any,Nargs}; trace::ProbProgTrace
216+
) where {Nargs}
151217
argprefix::Symbol = gensym("generatearg")
152218
resprefix::Symbol = gensym("generateresult")
153219
resargprefix::Symbol = gensym("generateresarg")
@@ -169,7 +235,8 @@ function generate_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
169235
fnwrap = mlir_fn_res.fnwrapped
170236
func2 = mlir_fn_res.f
171237

172-
out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results]
238+
f_out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results]
239+
out_tys = [MLIR.IR.TensorType(Int64[], MLIR.IR.Type(Float64)); f_out_tys]
173240
fname = TracedUtils.get_attribute_by_name(func2, "sym_name")
174241
fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname))
175242

@@ -186,10 +253,17 @@ function generate_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
186253
end
187254
end
188255

189-
gen_op = MLIR.Dialects.enzyme.generate(batch_inputs; outputs=out_tys, fn=fname)
256+
trace_addr = reinterpret(UInt64, pointer_from_objref(trace))
257+
258+
# Output: (weight, f's outputs...)
259+
gen_op = MLIR.Dialects.enzyme.generate(
260+
batch_inputs; outputs=out_tys, fn=fname, trace=trace_addr
261+
)
262+
263+
weight = TracedRArray(MLIR.IR.result(gen_op, 1))
190264

191265
for (i, res) in enumerate(linear_results)
192-
resv = MLIR.IR.result(gen_op, i)
266+
resv = MLIR.IR.result(gen_op, i + 1) # to skip weight
193267
if TracedUtils.has_idx(res, resprefix)
194268
path = TracedUtils.get_idx(res, resprefix)
195269
TracedUtils.set!(result, path[2:end], resv)
@@ -208,7 +282,7 @@ function generate_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
208282
end
209283
end
210284

211-
return result
285+
return weight, result
212286
end
213287

214288
function simulate(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
@@ -299,7 +373,6 @@ function _show_pretty(io::IO, trace::ProbProgTrace, pre::Int, vert_bars::Tuple)
299373
LAST = '\u2514'
300374

301375
indent_vert = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n'])
302-
indent_vert_last = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n'])
303376
indent = vcat(Char[' ' for _ in 1:pre], Char[PLUS, HORZ, HORZ, ' '])
304377
indent_last = vcat(Char[' ' for _ in 1:pre], Char[LAST, HORZ, HORZ, ' '])
305378

@@ -320,6 +393,10 @@ function _show_pretty(io::IO, trace::ProbProgTrace, pre::Int, vert_bars::Tuple)
320393
n += 1
321394
end
322395

396+
if trace.weight !== nothing
397+
n += 1
398+
end
399+
323400
cur = 1
324401

325402
if trace.retval !== nothing
@@ -328,6 +405,12 @@ function _show_pretty(io::IO, trace::ProbProgTrace, pre::Int, vert_bars::Tuple)
328405
cur += 1
329406
end
330407

408+
if trace.weight !== nothing
409+
print(io, indent_vert_str)
410+
print(io, (cur == n ? indent_last_str : indent_str) * "weight : $(trace.weight)\n")
411+
cur += 1
412+
end
413+
331414
for (key, value) in sorted_choices
332415
print(io, indent_vert_str)
333416
print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key)) : $value\n")
@@ -337,7 +420,7 @@ end
337420

338421
function Base.show(io::IO, ::MIME"text/plain", trace::ProbProgTrace)
339422
println(io, "ProbProgTrace:")
340-
if isempty(trace.choices) && trace.retval === nothing
423+
if isempty(trace.choices) && trace.retval === nothing && trace.weight === nothing
341424
println(io, " (empty)")
342425
else
343426
_show_pretty(io, trace, 0, ())
@@ -350,7 +433,7 @@ function Base.show(io::IO, trace::ProbProgTrace)
350433
has_retval = trace.retval !== nothing
351434
print(io, "ProbProgTrace($(choices_count) choices")
352435
if has_retval
353-
print(io, ", retval=$(trace.retval)")
436+
print(io, ", retval=$(trace.retval), weight=$(trace.weight)")
354437
end
355438
print(io, ")")
356439
else

test/probprog/generate.jl

Lines changed: 9 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -2,81 +2,42 @@ using Reactant, Test, Random, Statistics
22
using Reactant: ProbProg
33

44
normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape)
5+
normal_logpdf(x, μ, σ, _) = -sum(log.(σ)) - sum((μ .- x) .^ 2) / (2 * σ^2)
56

67
function model(seed, μ, σ, shape)
78
rng = Random.default_rng()
89
Random.seed!(rng, seed)
9-
s = ProbProg.sample(normal, rng, μ, σ, shape)
10-
t = ProbProg.sample(normal, rng, s, σ, shape)
10+
s = ProbProg.sample(normal, rng, μ, σ, shape; symbol=:s, logpdf=normal_logpdf)
11+
t = ProbProg.sample(normal, rng, s, σ, shape; symbol=:t, logpdf=normal_logpdf)
1112
return t
1213
end
1314

1415
@testset "Generate" begin
15-
@testset "deterministic" begin
16-
shape = (10000,)
17-
seed1 = Reactant.to_rarray(UInt64[1, 4])
18-
seed2 = Reactant.to_rarray(UInt64[1, 4])
19-
μ1 = Reactant.ConcreteRNumber(0.0)
20-
μ2 = Reactant.ConcreteRNumber(1000.0)
21-
σ1 = Reactant.ConcreteRNumber(1.0)
22-
σ2 = Reactant.ConcreteRNumber(1.0)
23-
24-
generate_model(seed, μ, σ, shape) =
25-
ProbProg.generate_internal(model, seed, μ, σ, shape)
26-
27-
model_compiled = @compile optimize = :probprog generate_model(seed1, μ1, σ1, shape)
28-
29-
@test Array(model_compiled(seed1, μ1, σ1, shape))
30-
Array(model_compiled(seed1, μ1, σ1, shape))
31-
@test mean(Array(model_compiled(seed1, μ1, σ1, shape))) 0.0 atol = 0.05 rtol =
32-
0.05
33-
@test mean(Array(model_compiled(seed2, μ2, σ2, shape))) 1000.0 atol = 0.05 rtol =
34-
0.05
35-
@test !(all(
36-
Array(model_compiled(seed1, μ1, σ1, shape)) .≈
37-
Array(model_compiled(seed2, μ2, σ2, shape)),
38-
))
39-
end
4016
@testset "hlo" begin
41-
shape = (10000,)
17+
shape = (10,)
4218
seed = Reactant.to_rarray(UInt64[1, 4])
4319
μ = Reactant.ConcreteRNumber(0.0)
4420
σ = Reactant.ConcreteRNumber(1.0)
4521

4622
before = @code_hlo optimize = :no_enzyme ProbProg.generate_internal(
47-
model, seed, μ, σ, shape
23+
model, seed, μ, σ, shape; trace=ProbProg.ProbProgTrace()
4824
)
4925
@test contains(repr(before), "enzyme.generate")
5026
@test contains(repr(before), "enzyme.sample")
5127

5228
after = @code_hlo optimize = :probprog ProbProg.generate_internal(
53-
model, seed, μ, σ, shape
29+
model, seed, μ, σ, shape; trace=ProbProg.ProbProgTrace()
5430
)
5531
@test !contains(repr(after), "enzyme.generate")
5632
@test !contains(repr(after), "enzyme.sample")
5733
end
5834

5935
@testset "normal" begin
60-
shape = (10000,)
36+
shape = (1000,)
6137
seed = Reactant.to_rarray(UInt64[1, 4])
6238
μ = Reactant.ConcreteRNumber(0.0)
6339
σ = Reactant.ConcreteRNumber(1.0)
64-
X = ProbProg.generate(model, seed, μ, σ, shape)
65-
@test mean(X) 0.0 atol = 0.05 rtol = 0.05
66-
end
67-
68-
@testset "correctness" begin
69-
op(x, y) = x * y'
70-
71-
function fake_model(x, y)
72-
return ProbProg.sample(op, x, y)
73-
end
74-
75-
x = reshape(collect(Float64, 1:12), (4, 3))
76-
y = reshape(collect(Float64, 1:12), (4, 3))
77-
x_ra = Reactant.to_rarray(x)
78-
y_ra = Reactant.to_rarray(y)
79-
80-
@test ProbProg.generate(fake_model, x_ra, y_ra) == op(x, y)
40+
trace, weight = ProbProg.generate(model, seed, μ, σ, shape)
41+
@test mean(trace.retval) 0.0 atol = 0.05 rtol = 0.05
8142
end
8243
end

0 commit comments

Comments
 (0)