1
1
module ProbProg
2
2
3
- using .. Reactant: Reactant, XLA, MLIR, TracedUtils, TracedRArray, ConcretePJRTArray
4
- using ReactantCore: ReactantCore
5
- using Libdl: Libdl
6
-
3
+ using .. Reactant: MLIR, TracedUtils, AbstractConcreteArray
7
4
using Enzyme
8
5
9
6
struct SampleMetadata
@@ -18,16 +15,10 @@ struct SampleMetadata
18
15
end
19
16
end
20
17
21
- const SAMPLE_METADATA_CACHE = IdDict {Symbol,SampleMetadata} ()
22
- const Trace = IdDict {Symbol,Any} (:_integrity_check => 0x123456789abcdef )
23
-
24
- function initTraceLowered (trace_ptr_ptr:: Ptr{Ptr{Cvoid}} )
25
- trace_ptr = unsafe_load (trace_ptr_ptr)
26
- @assert reinterpret (UInt64, trace_ptr) == 42
27
-
28
- unsafe_store! (trace_ptr_ptr, pointer_from_objref (Trace))
18
+ const SAMPLE_METADATA_CACHE = Dict {Symbol,SampleMetadata} ()
29
19
30
- return nothing
20
+ function createTrace ()
21
+ return Dict {Symbol,Any} (:_integrity_check => 0x123456789abcdef )
31
22
end
32
23
33
24
function addSampleToTraceLowered (
@@ -46,7 +37,7 @@ function addSampleToTraceLowered(
46
37
if is_scalar
47
38
trace[symbol] = unsafe_load (reinterpret (Ptr{element_type}, sample_ptr))
48
39
else
49
- trace[symbol] = Base . deepcopy (
40
+ trace[symbol] = copy (
50
41
reshape (
51
42
unsafe_wrap (
52
43
Array{element_type},
@@ -62,10 +53,6 @@ function addSampleToTraceLowered(
62
53
end
63
54
64
55
function __init__ ()
65
- init_trace_ptr = @cfunction (initTraceLowered, Cvoid, (Ptr{Ptr{Cvoid}},))
66
- @ccall MLIR. API. mlir_c. EnzymeJaXMapSymbol (
67
- :enzyme_probprog_init_trace :: Cstring , init_trace_ptr:: Ptr{Cvoid}
68
- ):: Cvoid
69
56
add_sample_to_trace_ptr = @cfunction (
70
57
addSampleToTraceLowered, Cvoid, (Ptr{Ptr{Cvoid}}, Ptr{Ptr{Cvoid}}, Ptr{Cvoid})
71
58
)
81
68
resprefix:: Symbol = gensym (" generateresult" )
82
69
resargprefix:: Symbol = gensym (" generateresarg" )
83
70
84
- mlir_fn_res = invokelatest (TracedUtils. make_mlir_fn,
71
+ mlir_fn_res = invokelatest (
72
+ TracedUtils. make_mlir_fn,
85
73
f,
86
74
args,
87
75
(),
@@ -139,13 +127,17 @@ end
139
127
end
140
128
141
129
@noinline function sample! (
142
- f:: Function , args:: Vararg{Any,Nargs} ; symbol:: Symbol = gensym (" sample" )
130
+ f:: Function ,
131
+ args:: Vararg{Any,Nargs} ;
132
+ symbol:: Symbol = gensym (" sample" ),
133
+ trace:: Union{Dict,Nothing} = nothing ,
143
134
) where {Nargs}
144
135
argprefix:: Symbol = gensym (" samplearg" )
145
136
resprefix:: Symbol = gensym (" sampleresult" )
146
137
resargprefix:: Symbol = gensym (" sampleresarg" )
147
138
148
- mlir_fn_res = invokelatest (TracedUtils. make_mlir_fn,
139
+ mlir_fn_res = invokelatest (
140
+ TracedUtils. make_mlir_fn,
149
141
f,
150
142
args,
151
143
(),
@@ -191,47 +183,44 @@ end
191
183
)
192
184
end
193
185
194
- symbol_ptr = pointer_from_objref (symbol)
195
- symbol_addr = reinterpret (UInt64, symbol_ptr)
196
- addr_attr = MLIR. IR. DenseElementsAttribute ([symbol_addr])
186
+ symbol_addr = reinterpret (UInt64, pointer_from_objref (symbol))
197
187
198
188
sample_op = MLIR. Dialects. enzyme. sample (
199
- MLIR. IR. result (MLIR. Dialects. stablehlo. constant (; value= addr_attr), 1 ),
200
- batch_inputs;
201
- outputs= out_tys,
202
- fn= fn_attr,
189
+ batch_inputs; outputs= out_tys, fn= fn_attr, symbol= symbol_addr
203
190
)
204
191
205
192
for (i, res) in enumerate (linear_results)
206
193
resv = MLIR. IR. result (sample_op, i)
207
-
208
- for path in res. paths
209
- isempty (path) && continue
210
- if path[1 ] == resprefix
211
- TracedUtils. set! (result, path[2 : end ], resv)
212
- elseif path[1 ] == argprefix
213
- idx = path[2 ]:: Int
214
- if idx == 1 && fnwrap
215
- TracedUtils. set! (f, path[3 : end ], resv)
216
- else
217
- if fnwrap
218
- idx -= 1
219
- end
220
- TracedUtils. set! (args[idx], path[3 : end ], resv)
194
+ if TracedUtils. has_idx (res, resprefix)
195
+ path = TracedUtils. get_idx (res, resprefix)
196
+ TracedUtils. set! (result, path[2 : end ], TracedUtils. transpose_val (resv))
197
+ elseif TracedUtils. has_idx (res, argprefix)
198
+ idx, path = TracedUtils. get_argidx (res, argprefix)
199
+ if idx == 1 && fnwrap
200
+ TracedUtils. set! (f, path[3 : end ], TracedUtils. transpose_val (resv))
201
+ else
202
+ if fnwrap
203
+ idx -= 1
221
204
end
205
+ TracedUtils. set! (args[idx], path[3 : end ], TracedUtils. transpose_val (resv))
222
206
end
207
+ else
208
+ TracedUtils. set! (res, (), TracedUtils. transpose_val (resv))
223
209
end
224
210
end
225
211
226
212
return result
227
213
end
228
214
229
- @noinline function simulate! (f:: Function , args:: Vararg{Any,Nargs} ) where {Nargs}
215
+ @noinline function simulate! (
216
+ f:: Function , args:: Vararg{Any,Nargs} ; trace:: Dict
217
+ ) where {Nargs}
230
218
argprefix:: Symbol = gensym (" simulatearg" )
231
219
resprefix:: Symbol = gensym (" simulateresult" )
232
220
resargprefix:: Symbol = gensym (" simulateresarg" )
233
221
234
- mlir_fn_res = TracedUtils. make_mlir_fn (
222
+ mlir_fn_res = invokelatest (
223
+ TracedUtils. make_mlir_fn,
235
224
f,
236
225
args,
237
226
(),
@@ -242,10 +231,14 @@ end
242
231
resprefix,
243
232
resargprefix,
244
233
)
245
- (; linear_args, linear_results) = mlir_fn_res
234
+ (; result, linear_args, in_tys , linear_results) = mlir_fn_res
246
235
fnwrap = mlir_fn_res. fnwrapped
247
236
func2 = mlir_fn_res. f
248
237
238
+ out_tys = [MLIR. IR. type (TracedUtils. get_mlir_data (res)) for res in linear_results]
239
+ fname = TracedUtils. get_attribute_by_name (func2, " sym_name" )
240
+ fname = MLIR. IR. FlatSymbolRefAttribute (Base. String (fname))
241
+
249
242
batch_inputs = MLIR. IR. Value[]
250
243
for a in linear_args
251
244
idx, path = TracedUtils. get_argidx (a, argprefix)
@@ -259,63 +252,36 @@ end
259
252
end
260
253
end
261
254
262
- out_tys = MLIR. IR. Type[]
263
- supress_rest = false
264
- for res in linear_results
265
- if TracedUtils. has_idx (res, resprefix) && ! supress_rest
266
- push! (out_tys, MLIR. IR. TensorType ([1 ], MLIR. IR. Type (UInt64)))
267
- supress_rest = true
268
- else
269
- # push!(out_tys, MLIR.IR.type(TracedUtils.get_mlir_data(res)))
270
- end
271
- end
255
+ trace_addr = reinterpret (UInt64, pointer_from_objref (trace))
272
256
273
- fname = TracedUtils. get_attribute_by_name (func2, " sym_name" )
274
- fname = MLIR. IR. FlatSymbolRefAttribute (Base. String (fname))
275
-
276
- simulate_op = MLIR. Dialects. enzyme. simulate (batch_inputs; outputs= out_tys, fn= fname)
257
+ simulate_op = MLIR. Dialects. enzyme. simulate (
258
+ batch_inputs; outputs= out_tys, fn= fname, trace= trace_addr
259
+ )
277
260
278
- result = nothing
279
261
for (i, res) in enumerate (linear_results)
280
262
resv = MLIR. IR. result (simulate_op, i)
281
-
282
263
if TracedUtils. has_idx (res, resprefix)
283
- # casted = MLIR.IR.result(
284
- # MLIR.Dialects.builtin.unrealized_conversion_cast(
285
- # resv; to=MLIR.IR.TensorType([1], MLIR.IR.Type(UInt64))
286
- # ),
287
- # 1,
288
- # )
289
- # result = TracedRArray(casted)
290
- result = TracedRArray (resv)
291
- break
292
- # continue
264
+ path = TracedUtils. get_idx (res, resprefix)
265
+ TracedUtils. set! (result, path[2 : end ], TracedUtils. transpose_val (resv))
266
+ elseif TracedUtils. has_idx (res, argprefix)
267
+ idx, path = TracedUtils. get_argidx (res, argprefix)
268
+ if idx == 1 && fnwrap
269
+ TracedUtils. set! (f, path[3 : end ], TracedUtils. transpose_val (resv))
270
+ else
271
+ if fnwrap
272
+ idx -= 1
273
+ end
274
+ TracedUtils. set! (args[idx], path[3 : end ], TracedUtils. transpose_val (resv))
275
+ end
276
+ else
277
+ TracedUtils. set! (res, (), TracedUtils. transpose_val (resv))
293
278
end
294
-
295
- # for path in res.paths
296
- # isempty(path) && continue
297
- # if path[1] == argprefix
298
- # idx = path[2]::Int
299
- # if idx == 1 && fnwrap
300
- # TracedUtils.set!(f, path[3:end], resv)
301
- # else
302
- # if fnwrap
303
- # idx -= 1
304
- # end
305
- # TracedUtils.set!(args[idx], path[3:end], resv)
306
- # end
307
- # end
308
- # end
309
279
end
310
280
311
- return result
312
- end
313
-
314
- function getTrace (t:: ConcretePJRTArray )
315
- return unsafe_pointer_to_objref (reinterpret (Ptr{Cvoid}, Array {UInt64,1} (t)[1 ]))
281
+ return trace, result
316
282
end
317
283
318
- function print_trace (trace:: IdDict )
284
+ function print_trace (trace:: Dict )
319
285
println (" Probabilistic Program Trace:" )
320
286
for (symbol, sample) in trace
321
287
symbol == :_integrity_check && continue
0 commit comments