Skip to content

Commit ef2e770

Browse files
committed
fixed tracing infra
1 parent ebeceb8 commit ef2e770

File tree

3 files changed

+165
-114
lines changed

3 files changed

+165
-114
lines changed

src/Compiler.jl

Lines changed: 81 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1740,21 +1740,91 @@ function compile_mlir!(
17401740
),
17411741
"only_enzyme",
17421742
)
1743+
elseif optimize === :probprog_no_lowering
1744+
run_pass_pipeline!(
1745+
mod,
1746+
join(
1747+
if raise_first
1748+
[
1749+
"mark-func-memory-effects",
1750+
opt_passes,
1751+
kern,
1752+
raise_passes,
1753+
"enzyme-batch",
1754+
opt_passes2,
1755+
enzyme_pass,
1756+
"probprog",
1757+
opt_passes2,
1758+
"canonicalize",
1759+
"remove-unnecessary-enzyme-ops",
1760+
"enzyme-simplify-math",
1761+
opt_passes2,
1762+
]
1763+
else
1764+
[
1765+
"mark-func-memory-effects",
1766+
opt_passes,
1767+
"enzyme-batch",
1768+
opt_passes2,
1769+
enzyme_pass,
1770+
"probprog",
1771+
opt_passes2,
1772+
"canonicalize",
1773+
"remove-unnecessary-enzyme-ops",
1774+
"enzyme-simplify-math",
1775+
opt_passes2,
1776+
kern,
1777+
raise_passes,
1778+
]
1779+
end,
1780+
",",
1781+
),
1782+
"probprog_no_lowering",
1783+
)
17431784
elseif optimize === :probprog
17441785
run_pass_pipeline!(
17451786
mod,
17461787
join(
1747-
[
1748-
"mark-func-memory-effects",
1749-
"enzyme-batch",
1750-
"probprog",
1751-
"canonicalize",
1752-
"remove-unnecessary-enzyme-ops",
1753-
"enzyme-simplify-math",
1754-
lower_enzyme_probprog_pass,
1755-
jit
1756-
],
1757-
',',
1788+
if raise_first
1789+
[
1790+
"mark-func-memory-effects",
1791+
opt_passes,
1792+
kern,
1793+
raise_passes,
1794+
"enzyme-batch",
1795+
opt_passes2,
1796+
enzyme_pass,
1797+
"probprog",
1798+
opt_passes2,
1799+
"canonicalize",
1800+
"remove-unnecessary-enzyme-ops",
1801+
"enzyme-simplify-math",
1802+
opt_passes2,
1803+
lower_enzymexla_linalg_pass,
1804+
lower_enzyme_probprog_pass,
1805+
jit,
1806+
]
1807+
else
1808+
[
1809+
"mark-func-memory-effects",
1810+
opt_passes,
1811+
"enzyme-batch",
1812+
opt_passes2,
1813+
enzyme_pass,
1814+
"probprog",
1815+
opt_passes2,
1816+
"canonicalize",
1817+
"remove-unnecessary-enzyme-ops",
1818+
"enzyme-simplify-math",
1819+
opt_passes2,
1820+
kern,
1821+
raise_passes,
1822+
lower_enzymexla_linalg_pass,
1823+
lower_enzyme_probprog_pass,
1824+
jit,
1825+
]
1826+
end,
1827+
",",
17581828
),
17591829
"probprog",
17601830
)

src/ProbProg.jl

Lines changed: 54 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -3,58 +3,48 @@ module ProbProg
33
using ..Reactant: MLIR, TracedUtils, AbstractConcreteArray
44
using Enzyme
55

6-
struct SampleMetadata
7-
shape::NTuple{N,Int} where {N}
8-
element_type::Type
9-
is_scalar::Bool
10-
11-
function SampleMetadata(
12-
shape::NTuple{N,Int}, element_type::Type, is_scalar::Bool
13-
) where {N}
14-
return new(shape, element_type, is_scalar)
15-
end
16-
end
17-
18-
const SAMPLE_METADATA_CACHE = Dict{Symbol,SampleMetadata}()
19-
206
function createTrace()
21-
return Dict{Symbol,Any}(:_integrity_check => 0x123456789abcdef)
7+
return Dict{Symbol,Any}()
228
end
239

2410
function addSampleToTraceLowered(
25-
trace_ptr_ptr::Ptr{Ptr{Cvoid}}, symbol_ptr_ptr::Ptr{Ptr{Cvoid}}, sample_ptr::Ptr{Cvoid}
11+
trace_ptr_ptr::Ptr{Ptr{Any}},
12+
symbol_ptr_ptr::Ptr{Ptr{Any}},
13+
sample_ptr::Ptr{Any},
14+
num_dims_ptr::Ptr{Int64},
15+
shape_array_ptr::Ptr{Int64},
16+
datatype_width_ptr::Ptr{Int64},
2617
)
2718
trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr))
2819
symbol = unsafe_pointer_to_objref(unsafe_load(symbol_ptr_ptr))
2920

30-
@assert haskey(SAMPLE_METADATA_CACHE, symbol) "Symbol $symbol not found in metadata cache"
21+
num_dims = unsafe_load(num_dims_ptr)
22+
shape_array = unsafe_wrap(Array, shape_array_ptr, num_dims)
23+
datatype_width = unsafe_load(datatype_width_ptr)
3124

32-
metadata = SAMPLE_METADATA_CACHE[symbol]
33-
shape = metadata.shape
34-
element_type = metadata.element_type
35-
is_scalar = metadata.is_scalar
25+
julia_type = if datatype_width == 32
26+
Float32
27+
elseif datatype_width == 64
28+
Float64
29+
else
30+
error("Unsupported datatype width: $datatype_width")
31+
end
3632

37-
if is_scalar
38-
trace[symbol] = unsafe_load(reinterpret(Ptr{element_type}, sample_ptr))
33+
typed_ptr = Ptr{julia_type}(sample_ptr)
34+
if num_dims == 0
35+
trace[symbol] = unsafe_load(typed_ptr)
3936
else
40-
trace[symbol] = copy(
41-
reshape(
42-
unsafe_wrap(
43-
Array{element_type},
44-
reinterpret(Ptr{element_type}, sample_ptr),
45-
prod(shape),
46-
),
47-
shape,
48-
),
49-
)
37+
trace[symbol] = copy(unsafe_wrap(Array, typed_ptr, Tuple(shape_array)))
5038
end
5139

5240
return nothing
5341
end
5442

5543
function __init__()
5644
add_sample_to_trace_ptr = @cfunction(
57-
addSampleToTraceLowered, Cvoid, (Ptr{Ptr{Cvoid}}, Ptr{Ptr{Cvoid}}, Ptr{Cvoid})
45+
addSampleToTraceLowered,
46+
Cvoid,
47+
(Ptr{Ptr{Any}}, Ptr{Ptr{Any}}, Ptr{Any}, Ptr{Int64}, Ptr{Int64}, Ptr{Int64})
5848
)
5949
@ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol(
6050
:enzyme_probprog_add_sample_to_trace::Cstring, add_sample_to_trace_ptr::Ptr{Cvoid}
@@ -105,32 +95,29 @@ end
10595

10696
for (i, res) in enumerate(linear_results)
10797
resv = MLIR.IR.result(gen_op, i)
108-
for path in res.paths
109-
isempty(path) && continue
110-
if path[1] == resprefix
111-
TracedUtils.set!(result, path[2:end], resv)
112-
elseif path[1] == argprefix
113-
idx = path[2]::Int
114-
if idx == 1 && fnwrap
115-
TracedUtils.set!(f, path[3:end], resv)
116-
else
117-
if fnwrap
118-
idx -= 1
119-
end
120-
TracedUtils.set!(args[idx], path[3:end], resv)
98+
if TracedUtils.has_idx(res, resprefix)
99+
path = TracedUtils.get_idx(res, resprefix)
100+
TracedUtils.set!(result, path[2:end], TracedUtils.transpose_val(resv))
101+
elseif TracedUtils.has_idx(res, argprefix)
102+
idx, path = TracedUtils.get_argidx(res, argprefix)
103+
if idx == 1 && fnwrap
104+
TracedUtils.set!(f, path[3:end], TracedUtils.transpose_val(resv))
105+
else
106+
if fnwrap
107+
idx -= 1
121108
end
109+
TracedUtils.set!(args[idx], path[3:end], TracedUtils.transpose_val(resv))
122110
end
111+
else
112+
TracedUtils.set!(res, (), TracedUtils.transpose_val(resv))
123113
end
124114
end
125115

126116
return result
127117
end
128118

129119
@noinline function sample!(
130-
f::Function,
131-
args::Vararg{Any,Nargs};
132-
symbol::Symbol=gensym("sample"),
133-
trace::Union{Dict,Nothing}=nothing,
120+
f::Function, args::Vararg{Any,Nargs}; symbol::Symbol=gensym("sample")
134121
) where {Nargs}
135122
argprefix::Symbol = gensym("samplearg")
136123
resprefix::Symbol = gensym("sampleresult")
@@ -169,24 +156,21 @@ end
169156
sym = TracedUtils.get_attribute_by_name(func2, "sym_name")
170157
fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(sym))
171158

172-
if !isempty(linear_results)
173-
sample_result = linear_results[1] # TODO: consider multiple results
174-
sample_mlir_data = TracedUtils.get_mlir_data(sample_result)
175-
@assert sample_mlir_data isa MLIR.IR.Value "Sample $sample_result is not a MLIR.IR.Value"
176-
177-
sample_type = MLIR.IR.type(sample_mlir_data)
178-
sample_shape = size(sample_type)
179-
sample_element_type = MLIR.IR.julia_type(eltype(sample_type))
180-
181-
SAMPLE_METADATA_CACHE[symbol] = SampleMetadata(
182-
sample_shape, sample_element_type, length(sample_shape) == 0
183-
)
159+
traced_output_indices = Int[]
160+
for (i, res) in enumerate(linear_results)
161+
if TracedUtils.has_idx(res, resprefix)
162+
push!(traced_output_indices, i - 1)
163+
end
184164
end
185165

186166
symbol_addr = reinterpret(UInt64, pointer_from_objref(symbol))
187167

188168
sample_op = MLIR.Dialects.enzyme.sample(
189-
batch_inputs; outputs=out_tys, fn=fn_attr, symbol=symbol_addr
169+
batch_inputs;
170+
outputs=out_tys,
171+
fn=fn_attr,
172+
symbol=symbol_addr,
173+
traced_output_indices=traced_output_indices,
190174
)
191175

192176
for (i, res) in enumerate(linear_results)
@@ -213,7 +197,7 @@ end
213197
end
214198

215199
@noinline function simulate!(
216-
f::Function, args::Vararg{Any,Nargs}; trace::Dict
200+
f::Function, args::Vararg{Any,Nargs}; trace::Dict{Symbol,Any}
217201
) where {Nargs}
218202
argprefix::Symbol = gensym("simulatearg")
219203
resprefix::Symbol = gensym("simulateresult")
@@ -278,25 +262,16 @@ end
278262
end
279263
end
280264

281-
return trace, result
265+
return result
282266
end
283267

284-
function print_trace(trace::Dict)
285-
println("Probabilistic Program Trace:")
268+
function print_trace(trace::Dict{Symbol,Any})
269+
println("### Probabilistic Program Trace ###")
286270
for (symbol, sample) in trace
287-
symbol == :_integrity_check && continue
288-
metadata = SAMPLE_METADATA_CACHE[symbol]
289-
290271
println(" $symbol:")
291272
println(" Sample: $(sample)")
292-
println(" Shape: $(metadata.shape)")
293-
println(" Element Type: $(metadata.element_type)")
294273
end
274+
println("### End of Trace ###")
295275
end
296276

297-
function clear_sample_metadata_cache!()
298-
empty!(SAMPLE_METADATA_CACHE)
299-
return nothing
300-
end
301-
302-
end
277+
end

test/probprog/simulate.jl

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,52 @@
11
using Reactant, Test, Random, StableRNGs, Statistics
22
using Reactant: ProbProg
3-
using Libdl: Libdl
43

5-
normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape)
4+
@testset "Simulate" begin
5+
normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape)
66

7-
function simulate_model(seed, μ, σ, shape)
8-
function model(seed, μ, σ, shape)
9-
rng = Random.default_rng()
10-
Random.seed!(rng, seed)
11-
s = ProbProg.sample!(normal, rng, μ, σ, shape; symbol=:s)
12-
t = ProbProg.sample!(normal, rng, s, σ, shape; symbol=:t)
13-
return t
14-
end
7+
function simulate_model(trace, seed, μ, σ, shape)
8+
function model(seed, μ, σ, shape)
9+
rng = Random.default_rng()
10+
Random.seed!(rng, seed)
11+
s = ProbProg.sample!(normal, rng, μ, σ, shape; symbol=:s)
12+
t = ProbProg.sample!(normal, rng, s, σ, shape; symbol=:t)
13+
return t
14+
end
1515

16-
return ProbProg.simulate!(model, seed, μ, σ, shape)
17-
end
18-
19-
@testset "Simulate" begin
16+
result = ProbProg.simulate!(model, seed, μ, σ, shape; trace)
17+
return result
18+
end
2019
@testset "normal_hlo" begin
2120
shape = (10000,)
2221
seed = Reactant.to_rarray(UInt64[1, 4])
23-
μ = Reactant.ConcreteRArray(0.0)
24-
σ = Reactant.ConcreteRArray(1.0)
22+
μ = Reactant.ConcreteRNumber(0.0)
23+
σ = Reactant.ConcreteRNumber(1.0)
24+
25+
trace = ProbProg.createTrace()
2526

26-
before = @code_hlo optimize = :no_enzyme simulate_model(seed, μ, σ, shape)
27+
before = @code_hlo optimize = :no_enzyme simulate_model(trace, seed, μ, σ, shape)
2728
@test contains(repr(before), "enzyme.simulate")
2829
@test contains(repr(before), "enzyme.sample")
2930

30-
after = @code_hlo optimize = :probprog simulate_model(seed, μ, σ, shape)
31+
after = @code_hlo optimize = :probprog simulate_model(trace, seed, μ, σ, shape)
3132
@test !contains(repr(after), "enzyme.simulate")
3233
@test !contains(repr(after), "enzyme.sample")
3334
@test contains(repr(after), "enzyme_probprog_add_sample_to_trace")
34-
@test contains(repr(after), "enzyme_probprog_init_trace")
3535
end
3636

3737
@testset "normal_simulate" begin
3838
shape = (3, 3, 3)
3939
seed = Reactant.to_rarray(UInt64[1, 4])
40-
μ = Reactant.ConcreteRArray(0.0)
41-
σ = Reactant.ConcreteRArray(1.0)
42-
X = ProbProg.getTrace(@jit optimize = :probprog simulate_model(seed, μ, σ, shape))
43-
@test X[:_integrity_check] == 0x123456789abcdef
44-
ProbProg.print_trace(X)
40+
μ = Reactant.ConcreteRNumber(0.0)
41+
σ = Reactant.ConcreteRNumber(1.0)
42+
43+
trace = ProbProg.createTrace()
44+
45+
result = Array(
46+
@jit optimize = :probprog sync = true simulate_model(trace, seed, μ, σ, shape)
47+
)
48+
49+
ProbProg.print_trace(trace)
50+
@test size(result) == shape
4551
end
4652
end

0 commit comments

Comments
 (0)