Skip to content

Commit 38e33de

Browse files
committed
adding cfunction mapping for AddWeightToTrace and AddRetvalToTrace ops
1 parent f6ee849 commit 38e33de

File tree

1 file changed

+84
-2
lines changed

1 file changed

+84
-2
lines changed

src/ProbProg.jl

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,69 @@ function addSubtrace(
118118
return nothing
119119
end
120120

121+
function addWeightToTrace(trace_ptr_ptr::Ptr{Ptr{Any}}, weight_ptr::Ptr{Any})
122+
trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr))::ProbProgTrace
123+
trace.weight = unsafe_load(Ptr{Float64}(weight_ptr))
124+
return nothing
125+
end
126+
127+
function addRetvalToTrace(
128+
trace_ptr_ptr::Ptr{Ptr{Any}},
129+
retval_ptr_array::Ptr{Ptr{Any}},
130+
num_results_ptr::Ptr{UInt64},
131+
ndims_array::Ptr{UInt64},
132+
shape_ptr_array::Ptr{Ptr{UInt64}},
133+
width_array::Ptr{UInt64},
134+
)
135+
trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr))::ProbProgTrace
136+
137+
num_results = unsafe_load(num_results_ptr)
138+
139+
if num_results == 0
140+
return nothing
141+
end
142+
143+
ndims_array = unsafe_wrap(Array, ndims_array, num_results)
144+
width_array = unsafe_wrap(Array, width_array, num_results)
145+
shape_ptr_array = unsafe_wrap(Array, shape_ptr_array, num_results)
146+
retval_ptr_array = unsafe_wrap(Array, retval_ptr_array, num_results)
147+
148+
vals = Any[]
149+
for i in 1:num_results
150+
ndims = ndims_array[i]
151+
width = width_array[i]
152+
shape_ptr = shape_ptr_array[i]
153+
retval_ptr = retval_ptr_array[i]
154+
155+
julia_type = if width == 32
156+
Float32
157+
elseif width == 64
158+
Float64
159+
elseif width == 1
160+
Bool
161+
else
162+
nothing
163+
end
164+
165+
if julia_type === nothing
166+
@ccall printf(
167+
"Unsupported datatype width: %lld\n"::Cstring, width::Int64
168+
)::Cvoid
169+
return nothing
170+
end
171+
172+
if ndims == 0
173+
push!(vals, unsafe_load(Ptr{julia_type}(retval_ptr)))
174+
else
175+
shape = unsafe_wrap(Array, shape_ptr, ndims)
176+
push!(vals, copy(unsafe_wrap(Array, Ptr{julia_type}(retval_ptr), Tuple(shape))))
177+
end
178+
end
179+
180+
trace.retval = length(vals) == 1 ? vals[1] : vals
181+
return nothing
182+
end
183+
121184
function __init__()
122185
init_trace_ptr = @cfunction(initTrace, Cvoid, (Ptr{Ptr{Any}},))
123186
@ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol(
@@ -148,6 +211,27 @@ function __init__()
148211
:enzyme_probprog_add_subtrace::Cstring, add_subtrace_ptr::Ptr{Cvoid}
149212
)::Cvoid
150213

214+
add_weight_to_trace_ptr = @cfunction(addWeightToTrace, Cvoid, (Ptr{Ptr{Any}}, Ptr{Any}))
215+
@ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol(
216+
:enzyme_probprog_add_weight_to_trace::Cstring, add_weight_to_trace_ptr::Ptr{Cvoid}
217+
)::Cvoid
218+
219+
add_retval_to_trace_ptr = @cfunction(
220+
addRetvalToTrace,
221+
Cvoid,
222+
(
223+
Ptr{Ptr{Any}},
224+
Ptr{Ptr{Any}},
225+
Ptr{UInt64},
226+
Ptr{UInt64},
227+
Ptr{Ptr{UInt64}},
228+
Ptr{UInt64},
229+
),
230+
)
231+
@ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol(
232+
:enzyme_probprog_add_retval_to_trace::Cstring, add_retval_to_trace_ptr::Ptr{Cvoid}
233+
)::Cvoid
234+
151235
return nothing
152236
end
153237

@@ -392,8 +476,6 @@ function simulate(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
392476
end
393477

394478
trace = unsafe_pointer_to_objref(Ptr{Any}(Array(trace)[1]))
395-
trace.retval = res isa AbstractConcreteArray ? Array(res) : res
396-
trace.weight = Array(weight)[1]
397479

398480
return trace, trace.weight
399481
end

0 commit comments

Comments
 (0)