@@ -118,6 +118,69 @@ function addSubtrace(
118
118
return nothing
119
119
end
120
120
121
+ function addWeightToTrace (trace_ptr_ptr:: Ptr{Ptr{Any}} , weight_ptr:: Ptr{Any} )
122
+ trace = unsafe_pointer_to_objref (unsafe_load (trace_ptr_ptr)):: ProbProgTrace
123
+ trace. weight = unsafe_load (Ptr {Float64} (weight_ptr))
124
+ return nothing
125
+ end
126
+
127
+ function addRetvalToTrace (
128
+ trace_ptr_ptr:: Ptr{Ptr{Any}} ,
129
+ retval_ptr_array:: Ptr{Ptr{Any}} ,
130
+ num_results_ptr:: Ptr{UInt64} ,
131
+ ndims_array:: Ptr{UInt64} ,
132
+ shape_ptr_array:: Ptr{Ptr{UInt64}} ,
133
+ width_array:: Ptr{UInt64} ,
134
+ )
135
+ trace = unsafe_pointer_to_objref (unsafe_load (trace_ptr_ptr)):: ProbProgTrace
136
+
137
+ num_results = unsafe_load (num_results_ptr)
138
+
139
+ if num_results == 0
140
+ return nothing
141
+ end
142
+
143
+ ndims_array = unsafe_wrap (Array, ndims_array, num_results)
144
+ width_array = unsafe_wrap (Array, width_array, num_results)
145
+ shape_ptr_array = unsafe_wrap (Array, shape_ptr_array, num_results)
146
+ retval_ptr_array = unsafe_wrap (Array, retval_ptr_array, num_results)
147
+
148
+ vals = Any[]
149
+ for i in 1 : num_results
150
+ ndims = ndims_array[i]
151
+ width = width_array[i]
152
+ shape_ptr = shape_ptr_array[i]
153
+ retval_ptr = retval_ptr_array[i]
154
+
155
+ julia_type = if width == 32
156
+ Float32
157
+ elseif width == 64
158
+ Float64
159
+ elseif width == 1
160
+ Bool
161
+ else
162
+ nothing
163
+ end
164
+
165
+ if julia_type === nothing
166
+ @ccall printf (
167
+ " Unsupported datatype width: %lld\n " :: Cstring , width:: Int64
168
+ ):: Cvoid
169
+ return nothing
170
+ end
171
+
172
+ if ndims == 0
173
+ push! (vals, unsafe_load (Ptr {julia_type} (retval_ptr)))
174
+ else
175
+ shape = unsafe_wrap (Array, shape_ptr, ndims)
176
+ push! (vals, copy (unsafe_wrap (Array, Ptr {julia_type} (retval_ptr), Tuple (shape))))
177
+ end
178
+ end
179
+
180
+ trace. retval = length (vals) == 1 ? vals[1 ] : vals
181
+ return nothing
182
+ end
183
+
121
184
function __init__ ()
122
185
init_trace_ptr = @cfunction (initTrace, Cvoid, (Ptr{Ptr{Any}},))
123
186
@ccall MLIR. API. mlir_c. EnzymeJaXMapSymbol (
@@ -148,6 +211,27 @@ function __init__()
148
211
:enzyme_probprog_add_subtrace :: Cstring , add_subtrace_ptr:: Ptr{Cvoid}
149
212
):: Cvoid
150
213
214
+ add_weight_to_trace_ptr = @cfunction (addWeightToTrace, Cvoid, (Ptr{Ptr{Any}}, Ptr{Any}))
215
+ @ccall MLIR. API. mlir_c. EnzymeJaXMapSymbol (
216
+ :enzyme_probprog_add_weight_to_trace :: Cstring , add_weight_to_trace_ptr:: Ptr{Cvoid}
217
+ ):: Cvoid
218
+
219
+ add_retval_to_trace_ptr = @cfunction (
220
+ addRetvalToTrace,
221
+ Cvoid,
222
+ (
223
+ Ptr{Ptr{Any}},
224
+ Ptr{Ptr{Any}},
225
+ Ptr{UInt64},
226
+ Ptr{UInt64},
227
+ Ptr{Ptr{UInt64}},
228
+ Ptr{UInt64},
229
+ ),
230
+ )
231
+ @ccall MLIR. API. mlir_c. EnzymeJaXMapSymbol (
232
+ :enzyme_probprog_add_retval_to_trace :: Cstring , add_retval_to_trace_ptr:: Ptr{Cvoid}
233
+ ):: Cvoid
234
+
151
235
return nothing
152
236
end
153
237
@@ -392,8 +476,6 @@ function simulate(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
392
476
end
393
477
394
478
trace = unsafe_pointer_to_objref (Ptr {Any} (Array (trace)[1 ]))
395
- trace. retval = res isa AbstractConcreteArray ? Array (res) : res
396
- trace. weight = Array (weight)[1 ]
397
479
398
480
return trace, trace. weight
399
481
end
0 commit comments