Skip to content

Commit a344726

Browse files
committed
partial refactoring
1 parent dd9dcab commit a344726

File tree

1 file changed

+58
-92
lines changed

1 file changed

+58
-92
lines changed

src/ProbProg.jl

Lines changed: 58 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
module ProbProg
22

3-
using ..Reactant: Reactant, XLA, MLIR, TracedUtils, TracedRArray, ConcretePJRTArray
4-
using ReactantCore: ReactantCore
5-
using Libdl: Libdl
6-
3+
using ..Reactant: MLIR, TracedUtils, AbstractConcreteArray
74
using Enzyme
85

96
struct SampleMetadata
@@ -18,16 +15,10 @@ struct SampleMetadata
1815
end
1916
end
2017

21-
const SAMPLE_METADATA_CACHE = IdDict{Symbol,SampleMetadata}()
22-
const Trace = IdDict{Symbol,Any}(:_integrity_check => 0x123456789abcdef)
23-
24-
function initTraceLowered(trace_ptr_ptr::Ptr{Ptr{Cvoid}})
25-
trace_ptr = unsafe_load(trace_ptr_ptr)
26-
@assert reinterpret(UInt64, trace_ptr) == 42
27-
28-
unsafe_store!(trace_ptr_ptr, pointer_from_objref(Trace))
18+
const SAMPLE_METADATA_CACHE = Dict{Symbol,SampleMetadata}()
2919

30-
return nothing
20+
function createTrace()
21+
return Dict{Symbol,Any}(:_integrity_check => 0x123456789abcdef)
3122
end
3223

3324
function addSampleToTraceLowered(
@@ -46,7 +37,7 @@ function addSampleToTraceLowered(
4637
if is_scalar
4738
trace[symbol] = unsafe_load(reinterpret(Ptr{element_type}, sample_ptr))
4839
else
49-
trace[symbol] = Base.deepcopy(
40+
trace[symbol] = copy(
5041
reshape(
5142
unsafe_wrap(
5243
Array{element_type},
@@ -62,10 +53,6 @@ function addSampleToTraceLowered(
6253
end
6354

6455
function __init__()
65-
init_trace_ptr = @cfunction(initTraceLowered, Cvoid, (Ptr{Ptr{Cvoid}},))
66-
@ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol(
67-
:enzyme_probprog_init_trace::Cstring, init_trace_ptr::Ptr{Cvoid}
68-
)::Cvoid
6956
add_sample_to_trace_ptr = @cfunction(
7057
addSampleToTraceLowered, Cvoid, (Ptr{Ptr{Cvoid}}, Ptr{Ptr{Cvoid}}, Ptr{Cvoid})
7158
)
@@ -81,7 +68,8 @@ end
8168
resprefix::Symbol = gensym("generateresult")
8269
resargprefix::Symbol = gensym("generateresarg")
8370

84-
mlir_fn_res = invokelatest(TracedUtils.make_mlir_fn,
71+
mlir_fn_res = invokelatest(
72+
TracedUtils.make_mlir_fn,
8573
f,
8674
args,
8775
(),
@@ -139,13 +127,17 @@ end
139127
end
140128

141129
@noinline function sample!(
142-
f::Function, args::Vararg{Any,Nargs}; symbol::Symbol=gensym("sample")
130+
f::Function,
131+
args::Vararg{Any,Nargs};
132+
symbol::Symbol=gensym("sample"),
133+
trace::Union{Dict,Nothing}=nothing,
143134
) where {Nargs}
144135
argprefix::Symbol = gensym("samplearg")
145136
resprefix::Symbol = gensym("sampleresult")
146137
resargprefix::Symbol = gensym("sampleresarg")
147138

148-
mlir_fn_res = invokelatest(TracedUtils.make_mlir_fn,
139+
mlir_fn_res = invokelatest(
140+
TracedUtils.make_mlir_fn,
149141
f,
150142
args,
151143
(),
@@ -191,47 +183,44 @@ end
191183
)
192184
end
193185

194-
symbol_ptr = pointer_from_objref(symbol)
195-
symbol_addr = reinterpret(UInt64, symbol_ptr)
196-
addr_attr = MLIR.IR.DenseElementsAttribute([symbol_addr])
186+
symbol_addr = reinterpret(UInt64, pointer_from_objref(symbol))
197187

198188
sample_op = MLIR.Dialects.enzyme.sample(
199-
MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=addr_attr), 1),
200-
batch_inputs;
201-
outputs=out_tys,
202-
fn=fn_attr,
189+
batch_inputs; outputs=out_tys, fn=fn_attr, symbol=symbol_addr
203190
)
204191

205192
for (i, res) in enumerate(linear_results)
206193
resv = MLIR.IR.result(sample_op, i)
207-
208-
for path in res.paths
209-
isempty(path) && continue
210-
if path[1] == resprefix
211-
TracedUtils.set!(result, path[2:end], resv)
212-
elseif path[1] == argprefix
213-
idx = path[2]::Int
214-
if idx == 1 && fnwrap
215-
TracedUtils.set!(f, path[3:end], resv)
216-
else
217-
if fnwrap
218-
idx -= 1
219-
end
220-
TracedUtils.set!(args[idx], path[3:end], resv)
194+
if TracedUtils.has_idx(res, resprefix)
195+
path = TracedUtils.get_idx(res, resprefix)
196+
TracedUtils.set!(result, path[2:end], TracedUtils.transpose_val(resv))
197+
elseif TracedUtils.has_idx(res, argprefix)
198+
idx, path = TracedUtils.get_argidx(res, argprefix)
199+
if idx == 1 && fnwrap
200+
TracedUtils.set!(f, path[3:end], TracedUtils.transpose_val(resv))
201+
else
202+
if fnwrap
203+
idx -= 1
221204
end
205+
TracedUtils.set!(args[idx], path[3:end], TracedUtils.transpose_val(resv))
222206
end
207+
else
208+
TracedUtils.set!(res, (), TracedUtils.transpose_val(resv))
223209
end
224210
end
225211

226212
return result
227213
end
228214

229-
@noinline function simulate!(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
215+
@noinline function simulate!(
216+
f::Function, args::Vararg{Any,Nargs}; trace::Dict
217+
) where {Nargs}
230218
argprefix::Symbol = gensym("simulatearg")
231219
resprefix::Symbol = gensym("simulateresult")
232220
resargprefix::Symbol = gensym("simulateresarg")
233221

234-
mlir_fn_res = TracedUtils.make_mlir_fn(
222+
mlir_fn_res = invokelatest(
223+
TracedUtils.make_mlir_fn,
235224
f,
236225
args,
237226
(),
@@ -242,10 +231,14 @@ end
242231
resprefix,
243232
resargprefix,
244233
)
245-
(; linear_args, linear_results) = mlir_fn_res
234+
(; result, linear_args, in_tys, linear_results) = mlir_fn_res
246235
fnwrap = mlir_fn_res.fnwrapped
247236
func2 = mlir_fn_res.f
248237

238+
out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results]
239+
fname = TracedUtils.get_attribute_by_name(func2, "sym_name")
240+
fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname))
241+
249242
batch_inputs = MLIR.IR.Value[]
250243
for a in linear_args
251244
idx, path = TracedUtils.get_argidx(a, argprefix)
@@ -259,63 +252,36 @@ end
259252
end
260253
end
261254

262-
out_tys = MLIR.IR.Type[]
263-
supress_rest = false
264-
for res in linear_results
265-
if TracedUtils.has_idx(res, resprefix) && !supress_rest
266-
push!(out_tys, MLIR.IR.TensorType([1], MLIR.IR.Type(UInt64)))
267-
supress_rest = true
268-
else
269-
# push!(out_tys, MLIR.IR.type(TracedUtils.get_mlir_data(res)))
270-
end
271-
end
255+
trace_addr = reinterpret(UInt64, pointer_from_objref(trace))
272256

273-
fname = TracedUtils.get_attribute_by_name(func2, "sym_name")
274-
fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname))
275-
276-
simulate_op = MLIR.Dialects.enzyme.simulate(batch_inputs; outputs=out_tys, fn=fname)
257+
simulate_op = MLIR.Dialects.enzyme.simulate(
258+
batch_inputs; outputs=out_tys, fn=fname, trace=trace_addr
259+
)
277260

278-
result = nothing
279261
for (i, res) in enumerate(linear_results)
280262
resv = MLIR.IR.result(simulate_op, i)
281-
282263
if TracedUtils.has_idx(res, resprefix)
283-
# casted = MLIR.IR.result(
284-
# MLIR.Dialects.builtin.unrealized_conversion_cast(
285-
# resv; to=MLIR.IR.TensorType([1], MLIR.IR.Type(UInt64))
286-
# ),
287-
# 1,
288-
# )
289-
# result = TracedRArray(casted)
290-
result = TracedRArray(resv)
291-
break
292-
# continue
264+
path = TracedUtils.get_idx(res, resprefix)
265+
TracedUtils.set!(result, path[2:end], TracedUtils.transpose_val(resv))
266+
elseif TracedUtils.has_idx(res, argprefix)
267+
idx, path = TracedUtils.get_argidx(res, argprefix)
268+
if idx == 1 && fnwrap
269+
TracedUtils.set!(f, path[3:end], TracedUtils.transpose_val(resv))
270+
else
271+
if fnwrap
272+
idx -= 1
273+
end
274+
TracedUtils.set!(args[idx], path[3:end], TracedUtils.transpose_val(resv))
275+
end
276+
else
277+
TracedUtils.set!(res, (), TracedUtils.transpose_val(resv))
293278
end
294-
295-
# for path in res.paths
296-
# isempty(path) && continue
297-
# if path[1] == argprefix
298-
# idx = path[2]::Int
299-
# if idx == 1 && fnwrap
300-
# TracedUtils.set!(f, path[3:end], resv)
301-
# else
302-
# if fnwrap
303-
# idx -= 1
304-
# end
305-
# TracedUtils.set!(args[idx], path[3:end], resv)
306-
# end
307-
# end
308-
# end
309279
end
310280

311-
return result
312-
end
313-
314-
function getTrace(t::ConcretePJRTArray)
315-
return unsafe_pointer_to_objref(reinterpret(Ptr{Cvoid}, Array{UInt64,1}(t)[1]))
281+
return trace, result
316282
end
317283

318-
function print_trace(trace::IdDict)
284+
function print_trace(trace::Dict)
319285
println("Probabilistic Program Trace:")
320286
for (symbol, sample) in trace
321287
symbol == :_integrity_check && continue

0 commit comments

Comments
 (0)