@@ -3,58 +3,48 @@ module ProbProg
3
3
using .. Reactant: MLIR, TracedUtils, AbstractConcreteArray
4
4
using Enzyme
5
5
6
- struct SampleMetadata
7
- shape:: NTuple{N,Int} where {N}
8
- element_type:: Type
9
- is_scalar:: Bool
10
-
11
- function SampleMetadata (
12
- shape:: NTuple{N,Int} , element_type:: Type , is_scalar:: Bool
13
- ) where {N}
14
- return new (shape, element_type, is_scalar)
15
- end
16
- end
17
-
18
- const SAMPLE_METADATA_CACHE = Dict {Symbol,SampleMetadata} ()
19
-
20
6
function createTrace ()
21
- return Dict {Symbol,Any} (:_integrity_check => 0x123456789abcdef )
7
+ return Dict {Symbol,Any} ()
22
8
end
23
9
24
10
function addSampleToTraceLowered (
25
- trace_ptr_ptr:: Ptr{Ptr{Cvoid}} , symbol_ptr_ptr:: Ptr{Ptr{Cvoid}} , sample_ptr:: Ptr{Cvoid}
11
+ trace_ptr_ptr:: Ptr{Ptr{Any}} ,
12
+ symbol_ptr_ptr:: Ptr{Ptr{Any}} ,
13
+ sample_ptr:: Ptr{Any} ,
14
+ num_dims_ptr:: Ptr{Int64} ,
15
+ shape_array_ptr:: Ptr{Int64} ,
16
+ datatype_width_ptr:: Ptr{Int64} ,
26
17
)
27
18
trace = unsafe_pointer_to_objref (unsafe_load (trace_ptr_ptr))
28
19
symbol = unsafe_pointer_to_objref (unsafe_load (symbol_ptr_ptr))
29
20
30
- @assert haskey (SAMPLE_METADATA_CACHE, symbol) " Symbol $symbol not found in metadata cache"
21
+ num_dims = unsafe_load (num_dims_ptr)
22
+ shape_array = unsafe_wrap (Array, shape_array_ptr, num_dims)
23
+ datatype_width = unsafe_load (datatype_width_ptr)
31
24
32
- metadata = SAMPLE_METADATA_CACHE[symbol]
33
- shape = metadata. shape
34
- element_type = metadata. element_type
35
- is_scalar = metadata. is_scalar
25
+ julia_type = if datatype_width == 32
26
+ Float32
27
+ elseif datatype_width == 64
28
+ Float64
29
+ else
30
+ error (" Unsupported datatype width: $datatype_width " )
31
+ end
36
32
37
- if is_scalar
38
- trace[symbol] = unsafe_load (reinterpret (Ptr{element_type}, sample_ptr))
33
+ typed_ptr = Ptr {julia_type} (sample_ptr)
34
+ if num_dims == 0
35
+ trace[symbol] = unsafe_load (typed_ptr)
39
36
else
40
- trace[symbol] = copy (
41
- reshape (
42
- unsafe_wrap (
43
- Array{element_type},
44
- reinterpret (Ptr{element_type}, sample_ptr),
45
- prod (shape),
46
- ),
47
- shape,
48
- ),
49
- )
37
+ trace[symbol] = copy (unsafe_wrap (Array, typed_ptr, Tuple (shape_array)))
50
38
end
51
39
52
40
return nothing
53
41
end
54
42
55
43
function __init__ ()
56
44
add_sample_to_trace_ptr = @cfunction (
57
- addSampleToTraceLowered, Cvoid, (Ptr{Ptr{Cvoid}}, Ptr{Ptr{Cvoid}}, Ptr{Cvoid})
45
+ addSampleToTraceLowered,
46
+ Cvoid,
47
+ (Ptr{Ptr{Any}}, Ptr{Ptr{Any}}, Ptr{Any}, Ptr{Int64}, Ptr{Int64}, Ptr{Int64})
58
48
)
59
49
@ccall MLIR. API. mlir_c. EnzymeJaXMapSymbol (
60
50
:enzyme_probprog_add_sample_to_trace :: Cstring , add_sample_to_trace_ptr:: Ptr{Cvoid}
105
95
106
96
for (i, res) in enumerate (linear_results)
107
97
resv = MLIR. IR. result (gen_op, i)
108
- for path in res. paths
109
- isempty (path) && continue
110
- if path[1 ] == resprefix
111
- TracedUtils. set! (result, path[2 : end ], resv)
112
- elseif path[1 ] == argprefix
113
- idx = path[2 ]:: Int
114
- if idx == 1 && fnwrap
115
- TracedUtils. set! (f, path[3 : end ], resv)
116
- else
117
- if fnwrap
118
- idx -= 1
119
- end
120
- TracedUtils. set! (args[idx], path[3 : end ], resv)
98
+ if TracedUtils. has_idx (res, resprefix)
99
+ path = TracedUtils. get_idx (res, resprefix)
100
+ TracedUtils. set! (result, path[2 : end ], TracedUtils. transpose_val (resv))
101
+ elseif TracedUtils. has_idx (res, argprefix)
102
+ idx, path = TracedUtils. get_argidx (res, argprefix)
103
+ if idx == 1 && fnwrap
104
+ TracedUtils. set! (f, path[3 : end ], TracedUtils. transpose_val (resv))
105
+ else
106
+ if fnwrap
107
+ idx -= 1
121
108
end
109
+ TracedUtils. set! (args[idx], path[3 : end ], TracedUtils. transpose_val (resv))
122
110
end
111
+ else
112
+ TracedUtils. set! (res, (), TracedUtils. transpose_val (resv))
123
113
end
124
114
end
125
115
126
116
return result
127
117
end
128
118
129
119
@noinline function sample! (
130
- f:: Function ,
131
- args:: Vararg{Any,Nargs} ;
132
- symbol:: Symbol = gensym (" sample" ),
133
- trace:: Union{Dict,Nothing} = nothing ,
120
+ f:: Function , args:: Vararg{Any,Nargs} ; symbol:: Symbol = gensym (" sample" )
134
121
) where {Nargs}
135
122
argprefix:: Symbol = gensym (" samplearg" )
136
123
resprefix:: Symbol = gensym (" sampleresult" )
@@ -169,24 +156,21 @@ end
169
156
sym = TracedUtils. get_attribute_by_name (func2, " sym_name" )
170
157
fn_attr = MLIR. IR. FlatSymbolRefAttribute (Base. String (sym))
171
158
172
- if ! isempty (linear_results)
173
- sample_result = linear_results[1 ] # TODO : consider multiple results
174
- sample_mlir_data = TracedUtils. get_mlir_data (sample_result)
175
- @assert sample_mlir_data isa MLIR. IR. Value " Sample $sample_result is not a MLIR.IR.Value"
176
-
177
- sample_type = MLIR. IR. type (sample_mlir_data)
178
- sample_shape = size (sample_type)
179
- sample_element_type = MLIR. IR. julia_type (eltype (sample_type))
180
-
181
- SAMPLE_METADATA_CACHE[symbol] = SampleMetadata (
182
- sample_shape, sample_element_type, length (sample_shape) == 0
183
- )
159
+ traced_output_indices = Int[]
160
+ for (i, res) in enumerate (linear_results)
161
+ if TracedUtils. has_idx (res, resprefix)
162
+ push! (traced_output_indices, i - 1 )
163
+ end
184
164
end
185
165
186
166
symbol_addr = reinterpret (UInt64, pointer_from_objref (symbol))
187
167
188
168
sample_op = MLIR. Dialects. enzyme. sample (
189
- batch_inputs; outputs= out_tys, fn= fn_attr, symbol= symbol_addr
169
+ batch_inputs;
170
+ outputs= out_tys,
171
+ fn= fn_attr,
172
+ symbol= symbol_addr,
173
+ traced_output_indices= traced_output_indices,
190
174
)
191
175
192
176
for (i, res) in enumerate (linear_results)
213
197
end
214
198
215
199
@noinline function simulate! (
216
- f:: Function , args:: Vararg{Any,Nargs} ; trace:: Dict
200
+ f:: Function , args:: Vararg{Any,Nargs} ; trace:: Dict{Symbol,Any}
217
201
) where {Nargs}
218
202
argprefix:: Symbol = gensym (" simulatearg" )
219
203
resprefix:: Symbol = gensym (" simulateresult" )
@@ -278,25 +262,16 @@ end
278
262
end
279
263
end
280
264
281
- return trace, result
265
+ return result
282
266
end
283
267
284
- function print_trace (trace:: Dict )
285
- println (" Probabilistic Program Trace: " )
268
+ function print_trace (trace:: Dict{Symbol,Any} )
269
+ println (" ### Probabilistic Program Trace ### " )
286
270
for (symbol, sample) in trace
287
- symbol == :_integrity_check && continue
288
- metadata = SAMPLE_METADATA_CACHE[symbol]
289
-
290
271
println (" $symbol :" )
291
272
println (" Sample: $(sample) " )
292
- println (" Shape: $(metadata. shape) " )
293
- println (" Element Type: $(metadata. element_type) " )
294
273
end
274
+ println (" ### End of Trace ###" )
295
275
end
296
276
297
- function clear_sample_metadata_cache! ()
298
- empty! (SAMPLE_METADATA_CACHE)
299
- return nothing
300
- end
301
-
302
- end
277
+ end
0 commit comments