1
1
module ProbProg
2
2
3
- using .. Reactant: MLIR, TracedUtils, AbstractConcreteArray, AbstractConcreteNumber
3
+ using .. Reactant:
4
+ MLIR,
5
+ TracedUtils,
6
+ AbstractConcreteArray,
7
+ AbstractConcreteNumber,
8
+ AbstractRNG,
9
+ TracedRArray
4
10
using .. Compiler: @jit
5
11
using Enzyme
6
12
7
13
mutable struct ProbProgTrace
8
14
choices:: Dict{Symbol,Any}
9
15
retval:: Any
16
+ weight:: Any
10
17
11
18
function ProbProgTrace ()
12
- return new (Dict {Symbol,Any} (), nothing )
19
+ return new (Dict {Symbol,Any} (), nothing , nothing )
13
20
end
14
21
end
15
22
@@ -63,7 +70,10 @@ function __init__()
63
70
end
64
71
65
72
function sample (
66
- f:: Function , args:: Vararg{Any,Nargs} ; symbol:: Symbol = gensym (" sample" )
73
+ f:: Function ,
74
+ args:: Vararg{Any,Nargs} ;
75
+ symbol:: Symbol = gensym (" sample" ),
76
+ logpdf:: Union{Nothing,Function} = nothing ,
67
77
) where {Nargs}
68
78
argprefix:: Symbol = gensym (" samplearg" )
69
79
resprefix:: Symbol = gensym (" sampleresult" )
@@ -102,20 +112,68 @@ function sample(
102
112
sym = TracedUtils. get_attribute_by_name (func2, " sym_name" )
103
113
fn_attr = MLIR. IR. FlatSymbolRefAttribute (Base. String (sym))
104
114
115
+ # Specify which outputs to add to the trace.
105
116
traced_output_indices = Int[]
106
117
for (i, res) in enumerate (linear_results)
107
118
if TracedUtils. has_idx (res, resprefix)
108
119
push! (traced_output_indices, i - 1 )
109
120
end
110
121
end
111
122
123
+ # Specify which inputs to pass to logpdf.
124
+ traced_input_indices = Int[]
125
+ for (i, a) in enumerate (linear_args)
126
+ idx, _ = TracedUtils. get_argidx (a, argprefix)
127
+ if fnwrap && idx == 1 # TODO : add test for fnwrap
128
+ continue
129
+ end
130
+
131
+ if fnwrap
132
+ idx -= 1
133
+ end
134
+
135
+ if ! (args[idx] isa AbstractRNG)
136
+ push! (traced_input_indices, i - 1 )
137
+ end
138
+ end
139
+
112
140
symbol_addr = reinterpret (UInt64, pointer_from_objref (symbol))
113
141
142
+ # Construct MLIR attribute if Julia logpdf function is provided.
143
+ logpdf_attr = nothing
144
+ if logpdf != = nothing
145
+ # Just to get static information about the sample. TODO : kwargs?
146
+ example_sample = f (args... )
147
+
148
+ # Remove AbstractRNG from `f`'s argument list if present, assuming that
149
+ # logpdf parameters follows `(sample, args...)` convention.
150
+ logpdf_args = (example_sample,)
151
+ if ! isempty (args) && args[1 ] isa AbstractRNG
152
+ logpdf_args = (example_sample, Base. tail (args)... ) # TODO : kwargs?
153
+ end
154
+
155
+ logpdf_mlir = invokelatest (
156
+ TracedUtils. make_mlir_fn,
157
+ logpdf,
158
+ logpdf_args,
159
+ (),
160
+ string (logpdf),
161
+ false ;
162
+ do_transpose= false ,
163
+ args_in_result= :all ,
164
+ )
165
+
166
+ logpdf_sym = TracedUtils. get_attribute_by_name (logpdf_mlir. f, " sym_name" )
167
+ logpdf_attr = MLIR. IR. FlatSymbolRefAttribute (Base. String (logpdf_sym))
168
+ end
169
+
114
170
sample_op = MLIR. Dialects. enzyme. sample (
115
171
batch_inputs;
116
172
outputs= out_tys,
117
173
fn= fn_attr,
174
+ logpdf= logpdf_attr,
118
175
symbol= symbol_addr,
176
+ traced_input_indices= traced_input_indices,
119
177
traced_output_indices= traced_output_indices,
120
178
)
121
179
@@ -143,11 +201,19 @@ function sample(
143
201
end
144
202
145
203
function generate (f:: Function , args:: Vararg{Any,Nargs} ) where {Nargs}
146
- res = @jit optimize = :probprog generate_internal (f, args... )
147
- return res isa AbstractConcreteArray ? Array (res) : res
204
+ trace = ProbProgTrace ()
205
+
206
+ weight, res = @jit optimize = :probprog generate_internal (f, args... ; trace)
207
+
208
+ trace. retval = res isa AbstractConcreteArray ? Array (res) : res
209
+ trace. weight = Array (weight)[1 ]
210
+
211
+ return trace, trace. weight
148
212
end
149
213
150
- function generate_internal (f:: Function , args:: Vararg{Any,Nargs} ) where {Nargs}
214
+ function generate_internal (
215
+ f:: Function , args:: Vararg{Any,Nargs} ; trace:: ProbProgTrace
216
+ ) where {Nargs}
151
217
argprefix:: Symbol = gensym (" generatearg" )
152
218
resprefix:: Symbol = gensym (" generateresult" )
153
219
resargprefix:: Symbol = gensym (" generateresarg" )
@@ -169,7 +235,8 @@ function generate_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
169
235
fnwrap = mlir_fn_res. fnwrapped
170
236
func2 = mlir_fn_res. f
171
237
172
- out_tys = [MLIR. IR. type (TracedUtils. get_mlir_data (res)) for res in linear_results]
238
+ f_out_tys = [MLIR. IR. type (TracedUtils. get_mlir_data (res)) for res in linear_results]
239
+ out_tys = [MLIR. IR. TensorType (Int64[], MLIR. IR. Type (Float64)); f_out_tys]
173
240
fname = TracedUtils. get_attribute_by_name (func2, " sym_name" )
174
241
fname = MLIR. IR. FlatSymbolRefAttribute (Base. String (fname))
175
242
@@ -186,10 +253,17 @@ function generate_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
186
253
end
187
254
end
188
255
189
- gen_op = MLIR. Dialects. enzyme. generate (batch_inputs; outputs= out_tys, fn= fname)
256
+ trace_addr = reinterpret (UInt64, pointer_from_objref (trace))
257
+
258
+ # Output: (weight, f's outputs...)
259
+ gen_op = MLIR. Dialects. enzyme. generate (
260
+ batch_inputs; outputs= out_tys, fn= fname, trace= trace_addr
261
+ )
262
+
263
+ weight = TracedRArray (MLIR. IR. result (gen_op, 1 ))
190
264
191
265
for (i, res) in enumerate (linear_results)
192
- resv = MLIR. IR. result (gen_op, i)
266
+ resv = MLIR. IR. result (gen_op, i + 1 ) # to skip weight
193
267
if TracedUtils. has_idx (res, resprefix)
194
268
path = TracedUtils. get_idx (res, resprefix)
195
269
TracedUtils. set! (result, path[2 : end ], resv)
@@ -208,7 +282,7 @@ function generate_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
208
282
end
209
283
end
210
284
211
- return result
285
+ return weight, result
212
286
end
213
287
214
288
function simulate (f:: Function , args:: Vararg{Any,Nargs} ) where {Nargs}
@@ -299,7 +373,6 @@ function _show_pretty(io::IO, trace::ProbProgTrace, pre::Int, vert_bars::Tuple)
299
373
LAST = ' \u 2514'
300
374
301
375
indent_vert = vcat (Char[' ' for _ in 1 : pre], Char[VERT, ' \n ' ])
302
- indent_vert_last = vcat (Char[' ' for _ in 1 : pre], Char[VERT, ' \n ' ])
303
376
indent = vcat (Char[' ' for _ in 1 : pre], Char[PLUS, HORZ, HORZ, ' ' ])
304
377
indent_last = vcat (Char[' ' for _ in 1 : pre], Char[LAST, HORZ, HORZ, ' ' ])
305
378
@@ -320,6 +393,10 @@ function _show_pretty(io::IO, trace::ProbProgTrace, pre::Int, vert_bars::Tuple)
320
393
n += 1
321
394
end
322
395
396
+ if trace. weight != = nothing
397
+ n += 1
398
+ end
399
+
323
400
cur = 1
324
401
325
402
if trace. retval != = nothing
@@ -328,6 +405,12 @@ function _show_pretty(io::IO, trace::ProbProgTrace, pre::Int, vert_bars::Tuple)
328
405
cur += 1
329
406
end
330
407
408
+ if trace. weight != = nothing
409
+ print (io, indent_vert_str)
410
+ print (io, (cur == n ? indent_last_str : indent_str) * " weight : $(trace. weight) \n " )
411
+ cur += 1
412
+ end
413
+
331
414
for (key, value) in sorted_choices
332
415
print (io, indent_vert_str)
333
416
print (io, (cur == n ? indent_last_str : indent_str) * " $(repr (key)) : $value \n " )
337
420
338
421
function Base. show (io:: IO , :: MIME"text/plain" , trace:: ProbProgTrace )
339
422
println (io, " ProbProgTrace:" )
340
- if isempty (trace. choices) && trace. retval === nothing
423
+ if isempty (trace. choices) && trace. retval === nothing && trace . weight === nothing
341
424
println (io, " (empty)" )
342
425
else
343
426
_show_pretty (io, trace, 0 , ())
@@ -350,7 +433,7 @@ function Base.show(io::IO, trace::ProbProgTrace)
350
433
has_retval = trace. retval != = nothing
351
434
print (io, " ProbProgTrace($(choices_count) choices" )
352
435
if has_retval
353
- print (io, " , retval=$(trace. retval) " )
436
+ print (io, " , retval=$(trace. retval) , weight= $(trace . weight) " )
354
437
end
355
438
print (io, " )" )
356
439
else
0 commit comments