diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 5369bed1ef..6c09a53079 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -388,6 +388,18 @@ enzymeSymbolAttrGet(MlirContext ctx, uint64_t symbol) { return wrap(attr); } +extern "C" MLIR_CAPI_EXPORTED MlirAttribute +enzymeRngDistributionAttrGet(MlirContext ctx, int32_t val) { + return wrap(mlir::enzyme::RngDistributionAttr::get( + unwrap(ctx), (mlir::enzyme::RngDistribution)val)); +} + +extern "C" MLIR_CAPI_EXPORTED MlirAttribute +enzymeMCMCAlgorithmAttrGet(MlirContext ctx, int32_t val) { + return wrap(mlir::enzyme::MCMCAlgorithmAttr::get( + unwrap(ctx), (mlir::enzyme::MCMCAlgorithm)val)); +} + // Create profiler session and start profiling REACTANT_ABI tsl::ProfilerSession * CreateProfilerSession(uint32_t device_tracer_level, diff --git a/deps/ReactantExtra/tblgen/jl-generators.cc b/deps/ReactantExtra/tblgen/jl-generators.cc index a3e871935c..0823ac9077 100644 --- a/deps/ReactantExtra/tblgen/jl-generators.cc +++ b/deps/ReactantExtra/tblgen/jl-generators.cc @@ -293,7 +293,7 @@ end operandname = "operand_" + std::to_string(i); } if (named_operand.isOptional()) { - operandsegmentsizes += "(" + operandname + "==nothing) ? 0 : 1"; + operandsegmentsizes += "(" + operandname + "==nothing) ? 0 : 1, "; continue; } operandsegmentsizes += named_operand.isVariadic() diff --git a/src/CompileOptions.jl b/src/CompileOptions.jl index e8cac78be6..925c357e1a 100644 --- a/src/CompileOptions.jl +++ b/src/CompileOptions.jl @@ -229,6 +229,7 @@ function CompileOptions(; :canonicalize, :just_batch, :none, + :probprog, ] end diff --git a/src/Compiler.jl b/src/Compiler.jl index d0356b3fb0..5bf8f55e59 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1302,6 +1302,7 @@ end # TODO we want to be able to run the more advanced passes via transform dialect as an enzyme intermediate # However, this errs as we cannot attach the transform with to the funcop itself [as we run a functionpass]. const enzyme_pass::String = "enzyme{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize\"}" +const probprog_pass::String = "probprog{postpasses=\"arith-raise{stablehlo=true}\"}" function run_pass_pipeline!(mod, pass_pipeline, key=""; enable_verifier=true) pm = MLIR.IR.PassManager() @@ -1885,6 +1886,71 @@ function compile_mlir!( ), "no_enzyme", ) + elseif compile_options.optimization_passes === :probprog + run_pass_pipeline!( + mod, + join( + if compile_options.raise_first + [ + "mark-func-memory-effects", + opt_passes, + kern, + raise_passes, + "enzyme-batch", + opt_passes2, + probprog_pass, + "lower-probprog-to-stablehlo{backend=$backend}", + "outline-enzyme-regions", + enzyme_pass, + opt_passes2, + "canonicalize", + "remove-unnecessary-enzyme-ops", + "enzyme-simplify-math", + ( + if compile_options.legalize_chlo_to_stablehlo + ["func.func(chlo-legalize-to-stablehlo)"] + else + [] + end + )..., + opt_passes2, + lower_enzymexla_linalg_pass, + "lower-probprog-trace-ops{backend=$backend}", + jit, + ] + else + [ + "mark-func-memory-effects", + opt_passes, + "enzyme-batch", + opt_passes2, + probprog_pass, + "lower-probprog-to-stablehlo{backend=$backend}", + "outline-enzyme-regions", + enzyme_pass, + opt_passes2, + "canonicalize", + "remove-unnecessary-enzyme-ops", + "enzyme-simplify-math", + ( + if compile_options.legalize_chlo_to_stablehlo + ["func.func(chlo-legalize-to-stablehlo)"] + else + [] + end + )..., + opt_passes2, + kern, + raise_passes, + lower_enzymexla_linalg_pass, + "lower-probprog-trace-ops{backend=$backend}", + jit, + ] + end, + ",", + ), + "probprog", + ) elseif compile_options.optimization_passes === :only_enzyme run_pass_pipeline!( mod, diff --git a/src/Reactant.jl b/src/Reactant.jl index df78eccae2..97c733a58b 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -242,6 +242,7 @@ include("Tracing.jl") include("Compiler.jl") include("Overlay.jl") +include("probprog/ProbProg.jl") # Serialization include("serialization/Serialization.jl") diff --git a/src/Types.jl b/src/Types.jl index cc257c4ebf..df1813d1b5 100644 --- a/src/Types.jl +++ b/src/Types.jl @@ -229,6 +229,7 @@ function ConcretePJRTArray( end Base.wait(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = foreach(wait, x.data) +Base.isready(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = all(isready, x.data) XLA.client(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = XLA.client(x.data) function XLA.device(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) x.sharding isa Sharding.NoShardInfo && return XLA.device(only(x.data)) @@ -405,6 +406,7 @@ function ConcreteIFRTArray( end Base.wait(x::Union{ConcreteIFRTArray,ConcreteIFRTNumber}) = wait(x.data) +Base.isready(x::Union{ConcreteIFRTArray,ConcreteIFRTNumber}) = isready(x.data) XLA.client(x::Union{ConcreteIFRTArray,ConcreteIFRTNumber}) = XLA.client(x.data) function XLA.device(x::Union{ConcreteIFRTArray,ConcreteIFRTNumber}) return XLA.device(x.data) diff --git a/src/probprog/Display.jl b/src/probprog/Display.jl new file mode 100644 index 0000000000..a81992eb71 --- /dev/null +++ b/src/probprog/Display.jl @@ -0,0 +1,87 @@ +# Reference: https://github.com/probcomp/Gen.jl/blob/91d798f2d2f0c175b1be3dc6daf3a10a8acf5da3/src/choice_map.jl#L104 +function _show_pretty(io::IO, trace::ProbProgTrace, pre::Int, vert_bars::Tuple) + VERT = '\u2502' + PLUS = '\u251C' + HORZ = '\u2500' + LAST = '\u2514' + + indent_vert = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n']) + indent = vcat(Char[' ' for _ in 1:pre], Char[PLUS, HORZ, HORZ, ' ']) + indent_last = vcat(Char[' ' for _ in 1:pre], Char[LAST, HORZ, HORZ, ' ']) + + for i in vert_bars + indent_vert[i] = VERT + indent[i] = VERT + indent_last[i] = VERT + end + + indent_vert_str = join(indent_vert) + indent_str = join(indent) + indent_last_str = join(indent_last) + + sorted_choices = sort(collect(trace.choices); by=x -> x[1]) + n = length(sorted_choices) + + if trace.retval !== nothing + n += 1 + end + + if trace.weight !== nothing + n += 1 + end + + cur = 1 + + if trace.retval !== nothing + print(io, indent_vert_str) + print(io, (cur == n ? indent_last_str : indent_str) * "retval : $(trace.retval)\n") + cur += 1 + end + + if trace.weight !== nothing + print(io, indent_vert_str) + print(io, (cur == n ? indent_last_str : indent_str) * "weight : $(trace.weight)\n") + cur += 1 + end + + for (key, value) in sorted_choices + print(io, indent_vert_str) + print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key)) : $value\n") + cur += 1 + end + + sorted_subtraces = sort(collect(trace.subtraces); by=x -> x[1]) + n += length(sorted_subtraces) + + for (key, subtrace) in sorted_subtraces + print(io, indent_vert_str) + print(io, (cur == n ? indent_last_str : indent_str) * "subtrace on $(repr(key))\n") + _show_pretty( + io, subtrace, pre + 4, cur == n ? (vert_bars...,) : (vert_bars..., pre + 1) + ) + cur += 1 + end +end + +function Base.show(io::IO, ::MIME"text/plain", trace::ProbProgTrace) + println(io, "ProbProgTrace:") + if isempty(trace.choices) && trace.retval === nothing && trace.weight === nothing + println(io, " (empty)") + else + _show_pretty(io, trace, 0, ()) + end +end + +function Base.show(io::IO, trace::ProbProgTrace) + if get(io, :compact, false) + choices_count = length(trace.choices) + has_retval = trace.retval !== nothing + print(io, "ProbProgTrace($(choices_count) choices") + if has_retval + print(io, ", retval=$(trace.retval), weight=$(trace.weight)") + end + print(io, ")") + else + show(io, MIME"text/plain"(), trace) + end +end diff --git a/src/probprog/FFI.jl b/src/probprog/FFI.jl new file mode 100644 index 0000000000..2f40390727 --- /dev/null +++ b/src/probprog/FFI.jl @@ -0,0 +1,768 @@ +using ..Reactant: MLIR, Profiler + +function initTrace(trace_ptr_ptr::Ptr{Ptr{Any}}) + activity_id = @ccall MLIR.API.mlir_c.ProfilerActivityStart( + "ProbProg.initTrace"::Cstring, Profiler.TRACE_ME_LEVEL_CRITICAL::Cint + )::Int64 + + tr = ProbProgTrace() + _keepalive!(tr) + + unsafe_store!(trace_ptr_ptr, pointer_from_objref(tr)) + + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing +end + +function addSampleToTrace( + trace_ptr_ptr::Ptr{Ptr{Any}}, + symbol_ptr_ptr::Ptr{Ptr{Any}}, + sample_ptr_array::Ptr{Ptr{Any}}, + num_outputs_ptr::Ptr{UInt64}, + ndims_array::Ptr{UInt64}, + shape_ptr_array::Ptr{Ptr{UInt64}}, + width_array::Ptr{UInt64}, +) + activity_id = @ccall MLIR.API.mlir_c.ProfilerActivityStart( + "ProbProg.addSampleToTrace"::Cstring, Profiler.TRACE_ME_LEVEL_CRITICAL::Cint + )::Int64 + + trace = nothing + try + trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr))::ProbProgTrace + catch + @ccall printf("Trace dereference failure\n"::Cstring)::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + symbol = unsafe_pointer_to_objref(unsafe_load(symbol_ptr_ptr))::Symbol + num_outputs = unsafe_load(num_outputs_ptr) + ndims_array = unsafe_wrap(Array, ndims_array, num_outputs) + width_array = unsafe_wrap(Array, width_array, num_outputs) + shape_ptr_array = unsafe_wrap(Array, shape_ptr_array, num_outputs) + sample_ptr_array = unsafe_wrap(Array, sample_ptr_array, num_outputs) + + vals = Any[] + for i in 1:num_outputs + ndims = ndims_array[i] + width = width_array[i] + shape_ptr = shape_ptr_array[i] + sample_ptr = sample_ptr_array[i] + + julia_type = if width == 32 + Float32 + elseif width == 64 + Float64 + elseif width == 1 + Bool + else + nothing + end + + if julia_type === nothing + @ccall printf( + "Unsupported datatype width: %lld\n"::Cstring, width::Int64 + )::Cvoid + return nothing + end + + if ndims == 0 + push!(vals, unsafe_load(Ptr{julia_type}(sample_ptr))) + else + shape = unsafe_wrap(Array, shape_ptr, ndims) + push!(vals, copy(unsafe_wrap(Array, Ptr{julia_type}(sample_ptr), Tuple(shape)))) + end + end + + trace.choices[symbol] = tuple(vals...) + + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing +end + +function addSubtrace( + trace_ptr_ptr::Ptr{Ptr{Any}}, + symbol_ptr_ptr::Ptr{Ptr{Any}}, + subtrace_ptr_ptr::Ptr{Ptr{Any}}, +) + activity_id = @ccall MLIR.API.mlir_c.ProfilerActivityStart( + "ProbProg.addSubtrace"::Cstring, Profiler.TRACE_ME_LEVEL_CRITICAL::Cint + )::Int64 + + trace = nothing + try + trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr))::ProbProgTrace + catch + @ccall printf("Trace dereference failure\n"::Cstring)::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + symbol = unsafe_pointer_to_objref(unsafe_load(symbol_ptr_ptr))::Symbol + subtrace = unsafe_pointer_to_objref(unsafe_load(subtrace_ptr_ptr))::ProbProgTrace + + trace.subtraces[symbol] = subtrace + + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing +end + +function addWeightToTrace(trace_ptr_ptr::Ptr{Ptr{Any}}, weight_ptr::Ptr{Any}) + activity_id = @ccall MLIR.API.mlir_c.ProfilerActivityStart( + "ProbProg.addWeightToTrace"::Cstring, Profiler.TRACE_ME_LEVEL_CRITICAL::Cint + )::Int64 + + trace = nothing + try + trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr))::ProbProgTrace + catch + @ccall printf("Trace dereference failure\n"::Cstring)::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + trace.weight = unsafe_load(Ptr{Float64}(weight_ptr)) + + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing +end + +function addRetvalToTrace( + trace_ptr_ptr::Ptr{Ptr{Any}}, + retval_ptr_array::Ptr{Ptr{Any}}, + num_results_ptr::Ptr{UInt64}, + ndims_array::Ptr{UInt64}, + shape_ptr_array::Ptr{Ptr{UInt64}}, + width_array::Ptr{UInt64}, +) + activity_id = @ccall MLIR.API.mlir_c.ProfilerActivityStart( + "ProbProg.addRetvalToTrace"::Cstring, Profiler.TRACE_ME_LEVEL_CRITICAL::Cint + )::Int64 + + trace = nothing + try + trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr))::ProbProgTrace + catch + @ccall printf("Trace dereference failure\n"::Cstring)::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + num_results = unsafe_load(num_results_ptr) + + if num_results == 0 + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + ndims_array = unsafe_wrap(Array, ndims_array, num_results) + width_array = unsafe_wrap(Array, width_array, num_results) + shape_ptr_array = unsafe_wrap(Array, shape_ptr_array, num_results) + retval_ptr_array = unsafe_wrap(Array, retval_ptr_array, num_results) + + vals = Any[] + for i in 1:num_results + ndims = ndims_array[i] + width = width_array[i] + shape_ptr = shape_ptr_array[i] + retval_ptr = retval_ptr_array[i] + + julia_type = if width == 32 + Float32 + elseif width == 64 + Float64 + elseif width == 1 + Bool + else + nothing + end + + if julia_type === nothing + @ccall printf( + "Unsupported datatype width: %lld\n"::Cstring, width::Int64 + )::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + if ndims == 0 + push!(vals, unsafe_load(Ptr{julia_type}(retval_ptr))) + else + shape = unsafe_wrap(Array, shape_ptr, ndims) + push!(vals, copy(unsafe_wrap(Array, Ptr{julia_type}(retval_ptr), Tuple(shape)))) + end + end + + trace.retval = tuple(vals...) + + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing +end + +function getSampleFromConstraint( + constraint_ptr_ptr::Ptr{Ptr{Any}}, + symbol_ptr_ptr::Ptr{Ptr{Any}}, + sample_ptr_array::Ptr{Ptr{Any}}, + num_samples_ptr::Ptr{UInt64}, + ndims_array::Ptr{UInt64}, + shape_ptr_array::Ptr{Ptr{UInt64}}, + width_array::Ptr{UInt64}, +) + activity_id = @ccall MLIR.API.mlir_c.ProfilerActivityStart( + "ProbProg.getSampleFromConstraint"::Cstring, Profiler.TRACE_ME_LEVEL_CRITICAL::Cint + )::Int64 + + constraint = unsafe_pointer_to_objref(unsafe_load(constraint_ptr_ptr))::Constraint + symbol = unsafe_pointer_to_objref(unsafe_load(symbol_ptr_ptr))::Symbol + num_samples = unsafe_load(num_samples_ptr) + ndims_array = unsafe_wrap(Array, ndims_array, num_samples) + width_array = unsafe_wrap(Array, width_array, num_samples) + shape_ptr_array = unsafe_wrap(Array, shape_ptr_array, num_samples) + sample_ptr_array = unsafe_wrap(Array, sample_ptr_array, num_samples) + + tostore = get(constraint, Address(symbol), nothing) + + if tostore === nothing + @ccall printf( + "No constraint found for symbol: %s\n"::Cstring, string(symbol)::Cstring + )::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + for i in 1:num_samples + ndims = ndims_array[i] + width = width_array[i] + shape_ptr = shape_ptr_array[i] + sample_ptr = sample_ptr_array[i] + + julia_type = if width == 32 + Float32 + elseif width == 64 + Float64 + elseif width == 1 + Bool + else + nothing + end + + if julia_type === nothing + @ccall printf( + "Unsupported datatype width: %zd\n"::Cstring, width::Csize_t + )::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + if julia_type != eltype(tostore[i]) + @ccall printf( + "Type mismatch in constrained sample: %s != %s\n"::Cstring, + string(julia_type)::Cstring, + string(eltype(tostore[i]))::Cstring, + )::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + if ndims == 0 + unsafe_store!(Ptr{julia_type}(sample_ptr), tostore[i]) + else + shape = unsafe_wrap(Array, shape_ptr, ndims) + dest = unsafe_wrap(Array, Ptr{julia_type}(sample_ptr), Tuple(shape)) + + dest_size = size(dest) + src_size = size(tostore[i]) + + if dest_size != src_size + @ccall printf( + "Shape mismatch in constrained sample: expected %zd dims, got %zd\n"::Cstring, + length(dest_size)::Csize_t, + length(src_size)::Csize_t, + )::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + copyto!(dest, tostore[i]) + end + end + + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing +end + +function getSubconstraint( + constraint_ptr_ptr::Ptr{Ptr{Any}}, + symbol_ptr_ptr::Ptr{Ptr{Any}}, + subconstraint_ptr_ptr::Ptr{Ptr{Any}}, +) + activity_id = @ccall MLIR.API.mlir_c.ProfilerActivityStart( + "ProbProg.getSubconstraint"::Cstring, Profiler.TRACE_ME_LEVEL_CRITICAL::Cint + )::Int64 + + constraint = unsafe_pointer_to_objref(unsafe_load(constraint_ptr_ptr))::Constraint + symbol = unsafe_pointer_to_objref(unsafe_load(symbol_ptr_ptr))::Symbol + + subconstraint = Constraint() + + for (key, value) in constraint + if key.path[1] == symbol + @assert isa(key, Address) "Expected Address type for constraint key" + @assert length(key.path) > 1 "Expected composite address with length > 1" + tail_address = Address(key.path[2:end]) + subconstraint[tail_address] = value + end + end + + if isempty(subconstraint) + @ccall printf( + "No subconstraint found for symbol: %s\n"::Cstring, string(symbol)::Cstring + )::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + _keepalive!(subconstraint) + unsafe_store!(subconstraint_ptr_ptr, pointer_from_objref(subconstraint)) + + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing +end + +function getSampleFromTrace( + trace_ptr_ptr::Ptr{Ptr{Any}}, + symbol_ptr_ptr::Ptr{Ptr{Any}}, + sample_ptr_array::Ptr{Ptr{Any}}, + num_samples_ptr::Ptr{UInt64}, + ndims_array::Ptr{UInt64}, + shape_ptr_array::Ptr{Ptr{UInt64}}, + width_array::Ptr{UInt64}, +) + activity_id = @ccall MLIR.API.mlir_c.ProfilerActivityStart( + "ProbProg.getSampleFromTrace"::Cstring, Profiler.TRACE_ME_LEVEL_CRITICAL::Cint + )::Int64 + + trace = nothing + try + trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr))::ProbProgTrace + catch + @ccall printf("Trace dereference failure\n"::Cstring)::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + symbol = unsafe_pointer_to_objref(unsafe_load(symbol_ptr_ptr))::Symbol + num_samples = unsafe_load(num_samples_ptr) + ndims_array = unsafe_wrap(Array, ndims_array, num_samples) + width_array = unsafe_wrap(Array, width_array, num_samples) + shape_ptr_array = unsafe_wrap(Array, shape_ptr_array, num_samples) + sample_ptr_array = unsafe_wrap(Array, sample_ptr_array, num_samples) + + tostore = get(trace.choices, symbol, nothing) + + if tostore === nothing + @ccall printf( + "No sample found in trace for symbol: %s\n"::Cstring, string(symbol)::Cstring + )::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + for i in 1:num_samples + ndims = ndims_array[i] + width = width_array[i] + shape_ptr = shape_ptr_array[i] + sample_ptr = sample_ptr_array[i] + + julia_type = if width == 32 + Float32 + elseif width == 64 + Float64 + elseif width == 1 + Bool + else + nothing + end + + if julia_type === nothing + @ccall printf( + "Unsupported datatype width: %zd\n"::Cstring, width::Csize_t + )::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + if julia_type != eltype(tostore[i]) + @ccall printf( + "Type mismatch in trace sample: %s != %s\n"::Cstring, + string(julia_type)::Cstring, + string(eltype(tostore[i]))::Cstring, + )::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + if ndims == 0 + unsafe_store!(Ptr{julia_type}(sample_ptr), tostore[i]) + else + shape = unsafe_wrap(Array, shape_ptr, ndims) + dest = unsafe_wrap(Array, Ptr{julia_type}(sample_ptr), Tuple(shape)) + + dest_size = size(dest) + src_size = size(tostore[i]) + + if dest_size != src_size + @ccall printf( + "Shape mismatch in trace sample: expected %zd dims, got %zd\n"::Cstring, + length(dest_size)::Csize_t, + length(src_size)::Csize_t, + )::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + copyto!(dest, tostore[i]) + end + end + + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing +end + +function getSubtrace( + trace_ptr_ptr::Ptr{Ptr{Any}}, + symbol_ptr_ptr::Ptr{Ptr{Any}}, + subtrace_ptr_ptr::Ptr{Ptr{Any}}, +) + activity_id = @ccall MLIR.API.mlir_c.ProfilerActivityStart( + "ProbProg.getSubtrace"::Cstring, Profiler.TRACE_ME_LEVEL_CRITICAL::Cint + )::Int64 + + trace = nothing + try + trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr))::ProbProgTrace + catch + @ccall printf("Trace dereference failure\n"::Cstring)::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + symbol = unsafe_pointer_to_objref(unsafe_load(symbol_ptr_ptr))::Symbol + + subtrace = get(trace.subtraces, symbol, nothing) + + if subtrace === nothing + @ccall printf( + "No subtrace found for symbol: %s\n"::Cstring, string(symbol)::Cstring + )::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + _keepalive!(subtrace) + unsafe_store!(subtrace_ptr_ptr, pointer_from_objref(subtrace)) + + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing +end + +function getWeightFromTrace(trace_ptr_ptr::Ptr{Ptr{Any}}, weight_ptr::Ptr{Any}) + activity_id = @ccall MLIR.API.mlir_c.ProfilerActivityStart( + "ProbProg.getWeightFromTrace"::Cstring, Profiler.TRACE_ME_LEVEL_CRITICAL::Cint + )::Int64 + + trace = nothing + try + trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr))::ProbProgTrace + catch + @ccall printf("Trace dereference failure\n"::Cstring)::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + unsafe_store!(Ptr{Float64}(weight_ptr), trace.weight) + + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing +end + +function getFlattenedSamplesFromTrace( + trace_ptr_ptr::Ptr{Ptr{Any}}, + num_addresses_ptr::Ptr{UInt64}, + total_symbols_ptr::Ptr{UInt64}, + address_lengths_ptr::Ptr{UInt64}, + flattened_symbols_ptr::Ptr{UInt64}, + position_ptr::Ptr{Any}, +) + activity_id = @ccall MLIR.API.mlir_c.ProfilerActivityStart( + "ProbProg.getFlattenedSamplesFromTrace"::Cstring, + Profiler.TRACE_ME_LEVEL_CRITICAL::Cint, + )::Int64 + + trace = nothing + try + trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr))::ProbProgTrace + catch + @ccall printf("No trace found\n"::Cstring)::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + num_addresses = unsafe_load(num_addresses_ptr) + total_symbols = unsafe_load(total_symbols_ptr) + + address_lengths = unsafe_wrap(Array, address_lengths_ptr, num_addresses) + flattened_symbols = unsafe_wrap(Array, flattened_symbols_ptr, total_symbols) + + addresses = Vector{Vector{Symbol}}() + symbol_idx = 1 + for i in 1:num_addresses + addr_len = address_lengths[i] + address = Symbol[] + for j in 1:addr_len + symbol_ptr_value = flattened_symbols[symbol_idx] + symbol = unsafe_pointer_to_objref(Ptr{Any}(symbol_ptr_value))::Symbol + push!(address, symbol) + symbol_idx += 1 + end + push!(addresses, address) + end + + flattened_values = Float64[] + + for address in addresses + current_trace = trace + + for (idx, symbol) in enumerate(address) + if idx < length(address) + if !haskey(current_trace.subtraces, symbol) + @ccall printf( + "No subtrace found for symbol in address path: %s\n"::Cstring, + string(symbol)::Cstring, + )::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + current_trace = current_trace.subtraces[symbol] + else + if !haskey(current_trace.choices, symbol) + @ccall printf( + "No sample found for symbol: %s\n"::Cstring, string(symbol)::Cstring + )::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + sample_tuple = current_trace.choices[symbol] + + for sample_val in sample_tuple + if isa(sample_val, AbstractArray) + for val in sample_val + push!(flattened_values, Float64(val)) + end + else + push!(flattened_values, Float64(sample_val)) + end + end + end + end + end + + position_array = unsafe_wrap( + Array, Ptr{Float64}(position_ptr), length(flattened_values) + ) + copyto!(position_array, flattened_values) + + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing +end + +function dump( + value_ptr::Ptr{Any}, + label_ptr::Ptr{UInt8}, + ndims_ptr::Ptr{UInt64}, + shape_ptr::Ptr{UInt64}, + width_ptr::Ptr{UInt64}, +) + activity_id = @ccall MLIR.API.mlir_c.ProfilerActivityStart( + "ProbProg.dump"::Cstring, Profiler.TRACE_ME_LEVEL_CRITICAL::Cint + )::Int64 + + label = unsafe_string(label_ptr) + ndims = unsafe_load(ndims_ptr) + width = unsafe_load(width_ptr) + + julia_type = if width == 32 + Float32 + elseif width == 64 + Float64 + elseif width == 1 + Bool + else + @ccall printf( + "DUMP ERROR: Unsupported datatype width: %lld\n"::Cstring, width::Int64 + )::Cvoid + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing + end + + println("═══ DUMP: $label ═══") + + if ndims == 0 + value = unsafe_load(Ptr{julia_type}(value_ptr)) + println(" Scalar ($julia_type): $value") + else + shape = unsafe_wrap(Array, shape_ptr, ndims) + value_array = unsafe_wrap(Array, Ptr{julia_type}(value_ptr), Tuple(shape)) + + println(" Shape: $(Tuple(shape))") + println(" Type: Array{$julia_type}") + println(" Values:") + + total_elements = prod(shape) + if total_elements <= 20 + println(" ", value_array) + else + println(" [$(total_elements) elements]") + println(" min: $(minimum(value_array))") + println(" max: $(maximum(value_array))") + println(" mean: $(sum(value_array) / total_elements)") + println(" First 10: $(value_array[1:min(10, total_elements)])") + end + end + + println("═══════════════════════════════════") + + @ccall MLIR.API.mlir_c.ProfilerActivityEnd(activity_id::Int64)::Cvoid + return nothing +end + +function __init__() + init_trace_ptr = @cfunction(initTrace, Cvoid, (Ptr{Ptr{Any}},)) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_init_trace::Cstring, init_trace_ptr::Ptr{Cvoid} + )::Cvoid + + add_sample_to_trace_ptr = @cfunction( + addSampleToTrace, + Cvoid, + ( + Ptr{Ptr{Any}}, + Ptr{Ptr{Any}}, + Ptr{Ptr{Any}}, + Ptr{UInt64}, + Ptr{UInt64}, + Ptr{Ptr{UInt64}}, + Ptr{UInt64}, + ) + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_add_sample_to_trace::Cstring, add_sample_to_trace_ptr::Ptr{Cvoid} + )::Cvoid + + add_subtrace_ptr = @cfunction( + addSubtrace, Cvoid, (Ptr{Ptr{Any}}, Ptr{Ptr{Any}}, Ptr{Ptr{Any}}) + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_add_subtrace::Cstring, add_subtrace_ptr::Ptr{Cvoid} + )::Cvoid + + add_weight_to_trace_ptr = @cfunction(addWeightToTrace, Cvoid, (Ptr{Ptr{Any}}, Ptr{Any})) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_add_weight_to_trace::Cstring, add_weight_to_trace_ptr::Ptr{Cvoid} + )::Cvoid + + add_retval_to_trace_ptr = @cfunction( + addRetvalToTrace, + Cvoid, + ( + Ptr{Ptr{Any}}, + Ptr{Ptr{Any}}, + Ptr{UInt64}, + Ptr{UInt64}, + Ptr{Ptr{UInt64}}, + Ptr{UInt64}, + ), + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_add_retval_to_trace::Cstring, add_retval_to_trace_ptr::Ptr{Cvoid} + )::Cvoid + + get_sample_from_constraint_ptr = @cfunction( + getSampleFromConstraint, + Cvoid, + ( + Ptr{Ptr{Any}}, + Ptr{Ptr{Any}}, + Ptr{Ptr{Any}}, + Ptr{UInt64}, + Ptr{UInt64}, + Ptr{Ptr{UInt64}}, + Ptr{UInt64}, + ) + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_get_sample_from_constraint::Cstring, + get_sample_from_constraint_ptr::Ptr{Cvoid}, + )::Cvoid + + get_subconstraint_ptr = @cfunction( + getSubconstraint, Cvoid, (Ptr{Ptr{Any}}, Ptr{Ptr{Any}}, Ptr{Ptr{Any}}) + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_get_subconstraint::Cstring, get_subconstraint_ptr::Ptr{Cvoid} + )::Cvoid + + get_sample_from_trace_ptr = @cfunction( + getSampleFromTrace, + Cvoid, + ( + Ptr{Ptr{Any}}, + Ptr{Ptr{Any}}, + Ptr{Ptr{Any}}, + Ptr{UInt64}, + Ptr{UInt64}, + Ptr{Ptr{UInt64}}, + Ptr{UInt64}, + ) + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_get_sample_from_trace::Cstring, + get_sample_from_trace_ptr::Ptr{Cvoid}, + )::Cvoid + + get_subtrace_ptr = @cfunction( + getSubtrace, Cvoid, (Ptr{Ptr{Any}}, Ptr{Ptr{Any}}, Ptr{Ptr{Any}}) + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_get_subtrace::Cstring, get_subtrace_ptr::Ptr{Cvoid} + )::Cvoid + + get_weight_from_trace_ptr = @cfunction( + getWeightFromTrace, Cvoid, (Ptr{Ptr{Any}}, Ptr{Any}) + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_get_weight_from_trace::Cstring, + get_weight_from_trace_ptr::Ptr{Cvoid}, + )::Cvoid + + get_flattened_samples_from_trace_ptr = @cfunction( + getFlattenedSamplesFromTrace, + Cvoid, + (Ptr{Ptr{Any}}, Ptr{UInt64}, Ptr{UInt64}, Ptr{UInt64}, Ptr{UInt64}, Ptr{Any}) + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_get_flattened_samples_from_trace::Cstring, + get_flattened_samples_from_trace_ptr::Ptr{Cvoid}, + )::Cvoid + + dump_ptr = @cfunction( + dump, Cvoid, (Ptr{Any}, Ptr{UInt8}, Ptr{UInt64}, Ptr{UInt64}, Ptr{UInt64}) + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_dump::Cstring, dump_ptr::Ptr{Cvoid} + )::Cvoid + + return nothing +end diff --git a/src/probprog/HMC.jl b/src/probprog/HMC.jl new file mode 100644 index 0000000000..134044c3b8 --- /dev/null +++ b/src/probprog/HMC.jl @@ -0,0 +1,126 @@ +using ..Reactant: ConcreteRNumber, TracedRArray + +function hmc( + rng::AbstractRNG, + original_trace::Union{ProbProgTrace,TracedRArray{UInt64,0}}, + f::Function, + args::Vararg{Any,Nargs}; + selection::Selection, + mass=nothing, + step_size=nothing, + num_steps=nothing, + initial_momentum=nothing, +) where {Nargs} + args = (rng, args...) + mlir_fn_res, argprefix, resprefix, _ = process_probprog_function(f, args, "hmc") + + (; result, linear_args, in_tys, linear_results) = mlir_fn_res + fnwrap = mlir_fn_res.fnwrapped + func2 = mlir_fn_res.f + + inputs = process_probprog_inputs(linear_args, f, args, fnwrap, argprefix) + out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + + fname = TracedUtils.get_attribute_by_name(func2, "sym_name") + fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) + + trace_ty = @ccall MLIR.API.mlir_c.enzymeTraceTypeGet( + MLIR.IR.context()::MLIR.API.MlirContext + )::MLIR.IR.Type + + trace_val = if original_trace isa TracedRArray{UInt64,0} + MLIR.IR.result( + MLIR.Dialects.builtin.unrealized_conversion_cast( + [original_trace.mlir_data]; outputs=[trace_ty] + ), + 1, + ) + else + # First iteration: promote a ProbProgTrace to tensor + promoted = to_trace_tensor(original_trace) + MLIR.IR.result( + MLIR.Dialects.builtin.unrealized_conversion_cast( + [TracedUtils.get_mlir_data(promoted)]; outputs=[trace_ty] + ), + 1, + ) + end + + selection_attr = MLIR.IR.Attribute[] + for address in selection + address_attr = MLIR.IR.Attribute[] + for sym in address.path + sym_addr = reinterpret(UInt64, pointer_from_objref(sym)) + push!( + address_attr, + @ccall MLIR.API.mlir_c.enzymeSymbolAttrGet( + MLIR.IR.context()::MLIR.API.MlirContext, sym_addr::UInt64 + )::MLIR.IR.Attribute + ) + end + push!(selection_attr, MLIR.IR.Attribute(address_attr)) + end + + trace_ty = @ccall MLIR.API.mlir_c.enzymeTraceTypeGet( + MLIR.IR.context()::MLIR.API.MlirContext + )::MLIR.IR.Type + accepted_ty = MLIR.IR.TensorType(Int64[], MLIR.IR.Type(Bool)) + + alg_attr = @ccall MLIR.API.mlir_c.enzymeMCMCAlgorithmAttrGet( + MLIR.IR.context()::MLIR.API.MlirContext, + 0::Int32, # 0 = HMC + )::MLIR.IR.Attribute + + mass_val = nothing + if !isnothing(mass) + mass_val = TracedUtils.get_mlir_data(mass) + end + + step_size_val = nothing + if !isnothing(step_size) + step_size_val = TracedUtils.get_mlir_data(step_size) + end + + num_steps_val = nothing + if !isnothing(num_steps) + num_steps_val = TracedUtils.get_mlir_data(num_steps) + end + + initial_momentum_val = nothing + if !isnothing(initial_momentum) + initial_momentum_val = TracedUtils.get_mlir_data(initial_momentum) + end + + hmc_op = MLIR.Dialects.enzyme.mcmc( + inputs, + trace_val, + mass_val; + step_size=step_size_val, + num_steps=num_steps_val, + initial_momentum=initial_momentum_val, + new_trace=trace_ty, + accepted=accepted_ty, + output_rng_state=out_tys[1], # by convention + alg=alg_attr, + fn=fn_attr, + selection=MLIR.IR.Attribute(selection_attr), + ) + + # (new_trace, accepted, output_rng_state) + process_probprog_outputs( + hmc_op, linear_results, result, f, args, fnwrap, resprefix, argprefix, 2, true + ) + + new_trace_val = MLIR.IR.result(hmc_op, 1) + new_trace_ptr = MLIR.IR.result( + MLIR.Dialects.builtin.unrealized_conversion_cast( + [new_trace_val]; outputs=[MLIR.IR.TensorType(Int64[], MLIR.IR.Type(UInt64))] + ), + 1, + ) + + new_trace = TracedRArray{UInt64,0}((), new_trace_ptr, ()) + accepted = TracedRArray{Bool,0}((), MLIR.IR.result(hmc_op, 2), ()) + + return new_trace, accepted, result +end diff --git a/src/probprog/MH.jl b/src/probprog/MH.jl new file mode 100644 index 0000000000..93d5bbca79 --- /dev/null +++ b/src/probprog/MH.jl @@ -0,0 +1,95 @@ +using ..Reactant: ConcreteRNumber, TracedRArray + +function mh( + rng::AbstractRNG, + original_trace::Union{ProbProgTrace,TracedRArray{UInt64,0}}, + f::Function, + args::Vararg{Any,Nargs}; + selection::Selection, +) where {Nargs} + args = (rng, args...) + mlir_fn_res, argprefix, resprefix, _ = process_probprog_function(f, args, "mh") + + (; result, linear_args, in_tys, linear_results) = mlir_fn_res + fnwrap = mlir_fn_res.fnwrapped + func2 = mlir_fn_res.f + + inputs = process_probprog_inputs(linear_args, f, args, fnwrap, argprefix) + out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + + fname = TracedUtils.get_attribute_by_name(func2, "sym_name") + fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) + + trace_ty = @ccall MLIR.API.mlir_c.enzymeTraceTypeGet( + MLIR.IR.context()::MLIR.API.MlirContext + )::MLIR.IR.Type + + if original_trace isa TracedRArray{UInt64,0} + # Use MLIR data from previous iteration + trace_val = MLIR.IR.result( + MLIR.Dialects.builtin.unrealized_conversion_cast( + [original_trace.mlir_data]; outputs=[trace_ty] + ), + 1, + ) + else + # First iteration: create constant from pointer + promoted = to_trace_tensor(original_trace) + trace_val = MLIR.IR.result( + MLIR.Dialects.builtin.unrealized_conversion_cast( + [TracedUtils.get_mlir_data(promoted)]; outputs=[trace_ty] + ), + 1, + ) + end + + selection_attr = MLIR.IR.Attribute[] + for address in selection + address_attr = MLIR.IR.Attribute[] + for sym in address.path + sym_addr = reinterpret(UInt64, pointer_from_objref(sym)) + push!( + address_attr, + @ccall MLIR.API.mlir_c.enzymeSymbolAttrGet( + MLIR.IR.context()::MLIR.API.MlirContext, sym_addr::UInt64 + )::MLIR.IR.Attribute + ) + end + push!(selection_attr, MLIR.IR.Attribute(address_attr)) + end + + trace_ty = @ccall MLIR.API.mlir_c.enzymeTraceTypeGet( + MLIR.IR.context()::MLIR.API.MlirContext + )::MLIR.IR.Type + accepted_ty = MLIR.IR.TensorType(Int64[], MLIR.IR.Type(Bool)) + + mh_op = MLIR.Dialects.enzyme.mh( + inputs, + trace_val; + new_trace=trace_ty, + accepted=accepted_ty, + output_rng_state=out_tys[1], # by convention + fn=fn_attr, + selection=MLIR.IR.Attribute(selection_attr), + ) + + # Return (new_trace, accepted, output_rng_state) + process_probprog_outputs( + mh_op, linear_results, result, f, args, fnwrap, resprefix, argprefix, 2, true + ) + + new_trace_val = MLIR.IR.result(mh_op, 1) + new_trace_ptr = MLIR.IR.result( + MLIR.Dialects.builtin.unrealized_conversion_cast( + [new_trace_val]; outputs=[MLIR.IR.TensorType(Int64[], MLIR.IR.Type(UInt64))] + ), + 1, + ) + + new_trace = TracedRArray{UInt64,0}((), new_trace_ptr, ()) + accepted = TracedRArray{Bool,0}((), MLIR.IR.result(mh_op, 2), ()) + + return new_trace, accepted, result +end + +const metropolis_hastings = mh diff --git a/src/probprog/Modeling.jl b/src/probprog/Modeling.jl new file mode 100644 index 0000000000..b3e1f82e9f --- /dev/null +++ b/src/probprog/Modeling.jl @@ -0,0 +1,257 @@ +using ..Reactant: MLIR, TracedUtils, AbstractRNG, TracedRArray, ConcreteRNumber +using ..Compiler: @jit, @compile + +include("Utils.jl") + +function sample( + rng::AbstractRNG, + f::Function, + args::Vararg{Any,Nargs}; + symbol::Symbol=gensym("sample"), + logpdf::Union{Nothing,Function}=nothing, +) where {Nargs} + args_with_rng = (rng, args...) + mlir_fn_res, argprefix, resprefix, _ = process_probprog_function( + f, args_with_rng, "sample" + ) + + (; result, linear_args, linear_results) = mlir_fn_res + fnwrap = mlir_fn_res.fnwrapped + func2 = mlir_fn_res.f + + inputs = process_probprog_inputs(linear_args, f, args_with_rng, fnwrap, argprefix) + out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + + sym = TracedUtils.get_attribute_by_name(func2, "sym_name") + fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(sym)) + + symbol_addr = reinterpret(UInt64, pointer_from_objref(symbol)) + symbol_attr = @ccall MLIR.API.mlir_c.enzymeSymbolAttrGet( + MLIR.IR.context()::MLIR.API.MlirContext, symbol_addr::UInt64 + )::MLIR.IR.Attribute + + # Construct logpdf attribute if `logpdf` function is provided. + logpdf_attr = nothing + if logpdf isa Function + samples = f(args_with_rng...) + + # Assume that logpdf parameters follow `(sample, args...)` convention. + logpdf_args = (samples, args...) + + logpdf_mlir = TracedUtils.make_mlir_fn( + logpdf, + logpdf_args, + (), + string(logpdf), + false; + do_transpose=false, + args_in_result=:result, + ) + + logpdf_sym = TracedUtils.get_attribute_by_name(logpdf_mlir.f, "sym_name") + logpdf_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(logpdf_sym)) + end + + sample_op = MLIR.Dialects.enzyme.sample( + inputs; + outputs=out_tys, + fn=fn_attr, + logpdf=logpdf_attr, + symbol=symbol_attr, + name=Base.String(symbol), + ) + + process_probprog_outputs( + sample_op, linear_results, result, f, args_with_rng, fnwrap, resprefix, argprefix + ) + + return result +end + +function untraced_call(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where {Nargs} + args_with_rng = (rng, args...) + mlir_fn_res, argprefix, resprefix, _ = process_probprog_function( + f, args_with_rng, "call" + ) + + (; result, linear_args, in_tys, linear_results) = mlir_fn_res + fnwrap = mlir_fn_res.fnwrapped + func2 = mlir_fn_res.f + + inputs = process_probprog_inputs(linear_args, f, args_with_rng, fnwrap, argprefix) + out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + + fname = TracedUtils.get_attribute_by_name(func2, "sym_name") + fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) + + call_op = MLIR.Dialects.enzyme.untracedCall(inputs; outputs=out_tys, fn=fn_attr) + + process_probprog_outputs( + call_op, linear_results, result, f, args_with_rng, fnwrap, resprefix, argprefix + ) + + return result +end + +# Gen-like helper function. +function simulate_(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where {Nargs} + trace = nothing + + compiled_fn = @compile optimize = :probprog simulate(rng, f, args...) + + seed_buffer = only(rng.seed.data).buffer + GC.@preserve seed_buffer begin + t, _, _ = compiled_fn(rng, f, args...) + trace = from_trace_tensor(t) + end + + return trace, trace.weight +end + +function simulate(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where {Nargs} + args = (rng, args...) + mlir_fn_res, argprefix, resprefix, _ = process_probprog_function(f, args, "simulate") + + (; result, linear_args, in_tys, linear_results) = mlir_fn_res + fnwrap = mlir_fn_res.fnwrapped + func2 = mlir_fn_res.f + + inputs = process_probprog_inputs(linear_args, f, args, fnwrap, argprefix) + out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + + fname = TracedUtils.get_attribute_by_name(func2, "sym_name") + fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) + + trace_ty = @ccall MLIR.API.mlir_c.enzymeTraceTypeGet( + MLIR.IR.context()::MLIR.API.MlirContext + )::MLIR.IR.Type + weight_ty = MLIR.IR.TensorType(Int64[], MLIR.IR.Type(Float64)) + + simulate_op = MLIR.Dialects.enzyme.simulate( + inputs; trace=trace_ty, weight=weight_ty, outputs=out_tys, fn=fn_attr + ) + + process_probprog_outputs( + simulate_op, linear_results, result, f, args, fnwrap, resprefix, argprefix, 2 + ) + + trace = MLIR.IR.result( + MLIR.Dialects.builtin.unrealized_conversion_cast( + [MLIR.IR.result(simulate_op, 1)]; + outputs=[MLIR.IR.TensorType(Int64[], MLIR.IR.Type(UInt64))], + ), + 1, + ) + + trace = TracedRArray{UInt64,0}((), trace, ()) + weight = TracedRArray{Float64,0}((), MLIR.IR.result(simulate_op, 2), ()) + + return trace, weight, result +end + +# Gen-like helper function. +function generate_( + rng::AbstractRNG, + f::Function, + args::Vararg{Any,Nargs}; + constraint::Constraint=Constraint(), +) where {Nargs} + trace = nothing + + constraint_ptr = ConcreteRNumber(reinterpret(UInt64, pointer_from_objref(constraint))) + + constrained_addresses = extract_addresses(constraint) + + function wrapper_fn(rng, constraint_ptr, args...) + return generate(rng, f, args...; constraint_ptr, constrained_addresses) + end + + compiled_fn = @compile optimize = :probprog wrapper_fn(rng, constraint_ptr, args...) + + seed_buffer = only(rng.seed.data).buffer + GC.@preserve seed_buffer constraint begin + t, _, _ = compiled_fn(rng, constraint_ptr, args...) + trace = from_trace_tensor(t) + end + + return trace, trace.weight +end + +function generate( + rng::AbstractRNG, + f::Function, + args::Vararg{Any,Nargs}; + constraint_ptr::TracedRNumber, + constrained_addresses::Set{Address}, +) where {Nargs} + args = (rng, args...) + mlir_fn_res, argprefix, resprefix, _ = process_probprog_function(f, args, "generate") + + (; result, linear_args, in_tys, linear_results) = mlir_fn_res + fnwrap = mlir_fn_res.fnwrapped + func2 = mlir_fn_res.f + + inputs = process_probprog_inputs(linear_args, f, args, fnwrap, argprefix) + out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + + fname = TracedUtils.get_attribute_by_name(func2, "sym_name") + fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) + + constraint_ty = @ccall MLIR.API.mlir_c.enzymeConstraintTypeGet( + MLIR.IR.context()::MLIR.API.MlirContext + )::MLIR.IR.Type + + constraint_val = MLIR.IR.result( + MLIR.Dialects.builtin.unrealized_conversion_cast( + [TracedUtils.get_mlir_data(constraint_ptr)]; outputs=[constraint_ty] + ), + 1, + ) + + constrained_addresses_attr = MLIR.IR.Attribute[] + for address in constrained_addresses + address_attr = MLIR.IR.Attribute[] + for sym in address.path + sym_addr = reinterpret(UInt64, pointer_from_objref(sym)) + push!( + address_attr, + @ccall MLIR.API.mlir_c.enzymeSymbolAttrGet( + MLIR.IR.context()::MLIR.API.MlirContext, sym_addr::UInt64 + )::MLIR.IR.Attribute + ) + end + push!(constrained_addresses_attr, MLIR.IR.Attribute(address_attr)) + end + + trace_ty = @ccall MLIR.API.mlir_c.enzymeTraceTypeGet( + MLIR.IR.context()::MLIR.API.MlirContext + )::MLIR.IR.Type + weight_ty = MLIR.IR.TensorType(Int64[], MLIR.IR.Type(Float64)) + + generate_op = MLIR.Dialects.enzyme.generate( + inputs, + constraint_val; + trace=trace_ty, + weight=weight_ty, + outputs=out_tys, + fn=fn_attr, + constrained_addresses=MLIR.IR.Attribute(constrained_addresses_attr), + ) + + process_probprog_outputs( + generate_op, linear_results, result, f, args, fnwrap, resprefix, argprefix, 2 + ) + + trace = MLIR.IR.result( + MLIR.Dialects.builtin.unrealized_conversion_cast( + [MLIR.IR.result(generate_op, 1)]; + outputs=[MLIR.IR.TensorType(Int64[], MLIR.IR.Type(UInt64))], + ), + 1, + ) + + trace = TracedRArray{UInt64,0}((), trace, ()) + weight = TracedRArray{Float64,0}((), MLIR.IR.result(generate_op, 2), ()) + + return trace, weight, result +end diff --git a/src/probprog/ProbProg.jl b/src/probprog/ProbProg.jl new file mode 100644 index 0000000000..23de59676f --- /dev/null +++ b/src/probprog/ProbProg.jl @@ -0,0 +1,28 @@ +module ProbProg + +using ..Reactant: + MLIR, TracedUtils, AbstractRNG, TracedRArray, TracedRNumber, ConcreteRNumber +using ..Compiler: @jit, @compile + +include("Types.jl") +include("FFI.jl") +include("Modeling.jl") +include("Display.jl") +include("MH.jl") +include("HMC.jl") + +# Types. +export ProbProgTrace, Constraint, Selection, Address + +# Utility functions. +export get_choices, select +export to_trace_tensor, from_trace_tensor +export to_constraint_tensor, from_constraint_tensor + +# Core MLIR ops. +export sample, untraced_call, simulate, generate, mh, hmc + +# Gen-like helper functions. +export simulate_, generate_ + +end diff --git a/src/probprog/Types.jl b/src/probprog/Types.jl new file mode 100644 index 0000000000..98f189d9a0 --- /dev/null +++ b/src/probprog/Types.jl @@ -0,0 +1,77 @@ +using Base: ReentrantLock + +mutable struct ProbProgTrace + choices::Dict{Symbol,Any} + retval::Any + weight::Any + subtraces::Dict{Symbol,Any} + + function ProbProgTrace() + return new(Dict{Symbol,Any}(), nothing, nothing, Dict{Symbol,Any}()) + end +end + +struct Address + path::Vector{Symbol} + + Address(path::Vector{Symbol}) = new(path) +end + +Address(sym::Symbol) = Address([sym]) +Address(syms::Symbol...) = Address([syms...]) + +Base.:(==)(a::Address, b::Address) = a.path == b.path +Base.hash(a::Address, h::UInt) = hash(a.path, h) + +mutable struct Constraint <: AbstractDict{Address,Any} + dict::Dict{Address,Any} + + function Constraint(pairs::Pair...) + dict = Dict{Address,Any}() + for pair in pairs + symbols = Symbol[] + current = pair + while isa(current, Pair) && isa(current.first, Symbol) + push!(symbols, current.first) + current = current.second + end + dict[Address(symbols...)] = current + end + return new(dict) + end + + Constraint() = new(Dict{Address,Any}()) + Constraint(d::Dict{Address,Any}) = new(d) +end + +Base.getindex(c::Constraint, k::Address) = c.dict[k] +Base.setindex!(c::Constraint, v, k::Address) = (c.dict[k] = v) +Base.delete!(c::Constraint, k::Address) = delete!(c.dict, k) +Base.keys(c::Constraint) = keys(c.dict) +Base.values(c::Constraint) = values(c.dict) +Base.iterate(c::Constraint) = iterate(c.dict) +Base.iterate(c::Constraint, state) = iterate(c.dict, state) +Base.length(c::Constraint) = length(c.dict) +Base.isempty(c::Constraint) = isempty(c.dict) +Base.haskey(c::Constraint, k::Address) = haskey(c.dict, k) +Base.get(c::Constraint, k::Address, default) = get(c.dict, k, default) + +extract_addresses(constraint::Constraint) = Set(keys(constraint)) + +const Selection = Set{Address} + +const _probprog_ref_lock = ReentrantLock() +const _probprog_refs = IdDict() + +function _keepalive!(tr::Any) + lock(_probprog_ref_lock) + try + _probprog_refs[tr] = tr + finally + unlock(_probprog_ref_lock) + end + return tr +end + +get_choices(trace::ProbProgTrace) = trace.choices +select(addrs::Address...) = Set{Address}([addrs...]) diff --git a/src/probprog/Utils.jl b/src/probprog/Utils.jl new file mode 100644 index 0000000000..7e8c197d5a --- /dev/null +++ b/src/probprog/Utils.jl @@ -0,0 +1,154 @@ +using ..Reactant: MLIR, TracedUtils, Ops, TracedRArray +import ..Reactant: promote_to + +""" + process_probprog_function(f, args_with_rng, op_name) + +This function handles the probprog argument convention where: +- **Index 1**: RNG state +- **Index 2**: Function `f` (when wrapped) +- **Index 3+**: Remaining arguments + +This wrapper ensures the RNG state is threaded through as the first result, +followed by the actual function results. +""" +function process_probprog_function(f, args_with_rng, op_name) + argprefix = gensym(op_name * "arg") + resprefix = gensym(op_name * "result") + resargprefix = gensym(op_name * "resarg") + + wrapper_fn = (all_args...) -> begin + res = f(all_args...) + (all_args[1], (res isa Tuple ? res : (res,))...) + end + + mlir_fn_res = TracedUtils.make_mlir_fn( + wrapper_fn, + args_with_rng, + (), + string(f), + false; + do_transpose=false, + args_in_result=:result, + argprefix, + resprefix, + resargprefix, + ) + + return mlir_fn_res, argprefix, resprefix, resargprefix +end + +""" + process_probprog_inputs(linear_args, f, args_with_rng, fnwrap, argprefix) + +This function handles the probprog argument convention where: +- **Index 1**: RNG state +- **Index 2**: Function `f` (when `fnwrap` is true) +- **Index 3+**: Other arguments +""" +function process_probprog_inputs(linear_args, f, args_with_rng, fnwrap, argprefix) + inputs = MLIR.IR.Value[] + for a in linear_args + idx, path = TracedUtils.get_argidx(a, argprefix) + if idx == 2 && fnwrap + TracedUtils.push_val!(inputs, f, path[3:end]) + else + if fnwrap && idx > 1 + idx -= 1 + end + TracedUtils.push_val!(inputs, args_with_rng[idx], path[3:end]) + end + end + return inputs +end + +""" + process_probprog_outputs(op, linear_results, result, f, args_with_rng, fnwrap, resprefix, argprefix, start_idx=0, rng_only=false) + +This function handles the probprog argument convention where: +- **Index 1**: RNG state +- **Index 2**: Function `f` (when `fnwrap` is true) +- **Index 3+**: Other arguments + +When setting results, the function checks: +1. If result path matches `resprefix`, store in `result` +2. If result path matches `argprefix`, store in `args_with_rng` (adjust indices for wrapped function) + +`start_idx` varies depending on the ProbProg operation: +- `sample` and `untraced_call` return only function outputs: + Use `start_idx=0`: `linear_results[i]` corresponds to `op.result[i]` +- `simulate` and `generate` return trace, weight, then outputs: + Use `start_idx=2`: `linear_results[i]` corresponds to `op.result[i+2]` +- `mh` and `regenerate` return trace, accepted/weight, rng_state (no model outputs): + Use `start_idx=2, rng_only=true`: only process first result (rng_state) + +`rng_only`: When true, only process the first result (RNG state), skipping model outputs +""" +function process_probprog_outputs( + op, + linear_results, + result, + f, + args_with_rng, + fnwrap, + resprefix, + argprefix, + start_idx=0, + rng_only=false, +) + num_to_process = rng_only ? 1 : length(linear_results) + + for i in 1:num_to_process + res = linear_results[i] + resv = MLIR.IR.result(op, i + start_idx) + + if TracedUtils.has_idx(res, resprefix) + path = TracedUtils.get_idx(res, resprefix) + TracedUtils.set!(result, path[2:end], resv) + end + + if TracedUtils.has_idx(res, argprefix) + idx, path = TracedUtils.get_argidx(res, argprefix) + if fnwrap && idx == 2 + TracedUtils.set!(f, path[3:end], resv) + else + if fnwrap && idx > 2 + idx -= 1 + end + TracedUtils.set!(args_with_rng[idx], path[3:end], resv) + end + end + + if !TracedUtils.has_idx(res, resprefix) && !TracedUtils.has_idx(res, argprefix) + TracedUtils.set!(res, (), resv) + end + end +end + +to_trace_tensor(t::ProbProgTrace) = promote_to(TracedRArray{UInt64,0}, t) + +function from_trace_tensor(trace_tensor) + while !isready(trace_tensor) + yield() + end + return unsafe_pointer_to_objref(Ptr{Any}(Array(trace_tensor)[1]))::ProbProgTrace +end + +function promote_to(::Type{TracedRArray{UInt64,0}}, t::ProbProgTrace) + ptr = reinterpret(UInt64, pointer_from_objref(t)) + return Ops.fill(ptr, Int64[]) +end + +to_constraint_tensor(c::Constraint) = promote_to(TracedRArray{UInt64,0}, c) + +function from_constraint_tensor(constraint_tensor) + while !isready(constraint_tensor) + yield() + end + return unsafe_pointer_to_objref(Ptr{Any}(Array(constraint_tensor)[1]))::Constraint +end + +function promote_to(::Type{TracedRArray{UInt64,0}}, c::Constraint) + ptr = reinterpret(UInt64, pointer_from_objref(c)) + return Ops.fill(ptr, Int64[]) +end diff --git a/test/probprog/generate.jl b/test/probprog/generate.jl new file mode 100644 index 0000000000..f5fa4fea38 --- /dev/null +++ b/test/probprog/generate.jl @@ -0,0 +1,142 @@ +using Reactant, Test, Random, Statistics +using Reactant: ProbProg, ReactantRNG + +normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) + +function normal_logpdf(x, μ, σ, _) + return -length(x) * log(σ) - length(x) / 2 * log(2π) - + sum((x .- μ) .^ 2 ./ (2 .* (σ .^ 2))) +end + +function model(rng, μ, σ, shape) + _, s = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:s, logpdf=normal_logpdf) + _, t = ProbProg.sample(rng, normal, s, σ, shape; symbol=:t, logpdf=normal_logpdf) + return t +end + +function two_normals(rng, μ, σ, shape) + _, x = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:x, logpdf=normal_logpdf) + _, y = ProbProg.sample(rng, normal, x, σ, shape; symbol=:y, logpdf=normal_logpdf) + return y +end + +function nested_model(rng, μ, σ, shape) + _, s = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:s, logpdf=normal_logpdf) + _, t = ProbProg.sample(rng, two_normals, s, σ, shape; symbol=:t) + _, u = ProbProg.sample(rng, two_normals, t, σ, shape; symbol=:u) + return u +end + +@testset "Generate" begin + @testset "unconstrained" begin + shape = (1000,) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + trace, weight = ProbProg.generate_(rng, model, μ, σ, shape) + @test mean(trace.retval[1]) ≈ 0.0 atol = 0.05 rtol = 0.05 + end + + @testset "constrained" begin + shape = (10,) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + constraint = ProbProg.Constraint(:s => (fill(0.1, shape),)) + + trace, weight = ProbProg.generate_(rng, model, μ, σ, shape; constraint) + + @test trace.choices[:s][1] == constraint[ProbProg.Address(:s)][1] + + expected_weight = + normal_logpdf(constraint[ProbProg.Address(:s)][1], 0.0, 1.0, shape) + + normal_logpdf( + trace.choices[:t][1], constraint[ProbProg.Address(:s)][1], 1.0, shape + ) + @test weight ≈ expected_weight atol = 1e-6 + end + + @testset "composite addresses" begin + shape = (10,) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + constraint = ProbProg.Constraint( + :s => (fill(0.1, shape),), + :t => :x => (fill(0.2, shape),), + :u => :y => (fill(0.3, shape),), + ) + + trace, weight = ProbProg.generate_(rng, nested_model, μ, σ, shape; constraint) + + @test trace.choices[:s][1] == fill(0.1, shape) + @test trace.subtraces[:t].choices[:x][1] == fill(0.2, shape) + @test trace.subtraces[:u].choices[:y][1] == fill(0.3, shape) + + s_weight = normal_logpdf(fill(0.1, shape), 0.0, 1.0, shape) + tx_weight = normal_logpdf(fill(0.2, shape), fill(0.1, shape), 1.0, shape) + ty_weight = normal_logpdf( + trace.subtraces[:t].choices[:y][1], fill(0.2, shape), 1.0, shape + ) + ux_weight = normal_logpdf( + trace.subtraces[:u].choices[:x][1], + trace.subtraces[:t].choices[:y][1], + 1.0, + shape, + ) + uy_weight = normal_logpdf( + fill(0.3, shape), trace.subtraces[:u].choices[:x][1], 1.0, shape + ) + + expected_weight = s_weight + tx_weight + ty_weight + ux_weight + uy_weight + @test weight ≈ expected_weight atol = 1e-6 + end + + @testset "compiled" begin + shape = (10,) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + constraint1 = ProbProg.Constraint(:s => (fill(0.1, shape),)) + + constrained_addresses = ProbProg.extract_addresses(constraint1) + + constraint_ptr1 = Reactant.ConcreteRNumber( + reinterpret(UInt64, pointer_from_objref(constraint1)) + ) + + wrapper_fn(rng, constraint_ptr, μ, σ) = ProbProg.generate( + rng, model, μ, σ, shape; constraint_ptr, constrained_addresses + ) + + compiled_fn = @compile optimize = :probprog wrapper_fn(rng, constraint_ptr1, μ, σ) + + trace1 = nothing + seed_buffer = only(rng.seed.data).buffer + GC.@preserve seed_buffer constraint1 begin + trace1, _ = compiled_fn(rng, constraint_ptr1, μ, σ) + trace1 = ProbProg.from_trace_tensor(trace1) + end + + constraint2 = ProbProg.Constraint(:s => (fill(0.2, shape),)) + constraint_ptr2 = Reactant.ConcreteRNumber( + reinterpret(UInt64, pointer_from_objref(constraint2)) + ) + + trace2 = nothing + seed_buffer = only(rng.seed.data).buffer + GC.@preserve seed_buffer constraint2 begin + trace2, _ = compiled_fn(rng, constraint_ptr2, μ, σ) + trace2 = ProbProg.from_trace_tensor(trace2) + end + + @test trace1.choices[:s][1] != trace2.choices[:s][1] + end +end diff --git a/test/probprog/hmc.jl b/test/probprog/hmc.jl new file mode 100644 index 0000000000..03c3394a40 --- /dev/null +++ b/test/probprog/hmc.jl @@ -0,0 +1,138 @@ +using Reactant, Test, Random +using Statistics +using Reactant: ProbProg, ReactantRNG + +normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) + +function normal_logpdf(x, μ, σ, _) + return -length(x) * log(σ) - length(x) / 2 * log(2π) - + sum((x .- μ) .^ 2 ./ (2 .* (σ .^ 2))) +end + +function model(rng, xs) + _, param_a = ProbProg.sample( + rng, normal, 0.0, 5.0, (1,); symbol=:param_a, logpdf=normal_logpdf + ) + _, param_b = ProbProg.sample( + rng, normal, 0.0, 5.0, (1,); symbol=:param_b, logpdf=normal_logpdf + ) + + _, ys_a = ProbProg.sample( + rng, normal, param_a .+ xs[1:5], 0.5, (5,); symbol=:ys_a, logpdf=normal_logpdf + ) + + _, ys_b = ProbProg.sample( + rng, normal, param_b .+ xs[6:10], 0.5, (5,); symbol=:ys_b, logpdf=normal_logpdf + ) + + return vcat(ys_a, ys_b) +end + +function hmc_program( + rng, + model, + xs, + step_size, + num_steps, + mass, + initial_momentum, + constraint_ptr, + constrained_addresses, +) + t, _, _ = ProbProg.generate( + rng, + model, + xs; + constraint_ptr=constraint_ptr, + constrained_addresses=constrained_addresses, + ) + + t, accepted, _ = ProbProg.hmc( + rng, + t, + model, + xs; + selection=ProbProg.select(ProbProg.Address(:param_a), ProbProg.Address(:param_b)), + mass=mass, + step_size=step_size, + num_steps=num_steps, + initial_momentum=initial_momentum, + ) + + return t, accepted +end + +@testset "hmc" begin + seed = Reactant.to_rarray(UInt64[1, 5]) + rng = ReactantRNG(seed) + + xs = [-4.5, -3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5, 4.5] + ys_a = [-2.3, -1.6, -0.4, 0.6, 1.4] + ys_b = [-2.6, -1.4, -0.6, 0.4, 1.6] + obs = ProbProg.Constraint( + :param_a => ([0.0],), :param_b => ([0.0],), :ys_a => (ys_a,), :ys_b => (ys_b,) + ) + constrained_addresses = ProbProg.extract_addresses(obs) + constraint_ptr = ConcreteRNumber(reinterpret(UInt64, pointer_from_objref(obs))) + + step_size = ConcreteRNumber(0.001) + num_steps_compile = ConcreteRNumber(1000) + num_steps_run = ConcreteRNumber(40000000) + mass = nothing + initial_momentum = ConcreteRArray([0.0, 0.0]) + + code = @code_hlo optimize = :probprog hmc_program( + rng, + model, + xs, + step_size, + num_steps_compile, + mass, + initial_momentum, + constraint_ptr, + constrained_addresses, + ) + @test contains(repr(code), "enzyme_probprog_get_flattened_samples_from_trace") + @test contains(repr(code), "enzyme_probprog_get_weight_from_trace") + @test !contains(repr(code), "enzyme.mh") + @test !contains(repr(code), "enzyme.mcmc") + + compile_time_s = @elapsed begin + compiled_fn = @compile optimize = :probprog hmc_program( + rng, + model, + xs, + step_size, + num_steps_compile, + mass, + initial_momentum, + constraint_ptr, + constrained_addresses, + ) + end + println("HMC compile time: $(round(compile_time_s * 1000, digits=2)) ms") + + seed_buffer = only(rng.seed.data).buffer + trace = nothing + GC.@preserve seed_buffer obs begin + run_time_s = @elapsed begin + trace_ptr, _ = compiled_fn( + rng, + model, + xs, + step_size, + num_steps_run, + mass, + initial_momentum, + constraint_ptr, + constrained_addresses, + ) + trace = ProbProg.from_trace_tensor(trace_ptr) + end + println("HMC run time: $(round(run_time_s * 1000, digits=2)) ms") + end + + # NumPyro results + @test only(trace.choices[:param_a])[1] ≈ 0.01327671 rtol = 1e-6 + @test only(trace.choices[:param_b])[1] ≈ -0.01965474 rtol = 1e-6 +end diff --git a/test/probprog/mh.jl b/test/probprog/mh.jl new file mode 100644 index 0000000000..b05021827d --- /dev/null +++ b/test/probprog/mh.jl @@ -0,0 +1,114 @@ +using Reactant, Test, Random +using Reactant: ProbProg, ReactantRNG + +# Reference: https://www.gen.dev/docs/stable/getting_started/linear_regression/ + +normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) + +function normal_logpdf(x, μ, σ, _) + return -length(x) * log(σ) - length(x) / 2 * log(2π) - + sum((x .- μ) .^ 2 ./ (2 .* (σ .^ 2))) +end + +function model(rng, xs) + _, slope = ProbProg.sample( + rng, normal, 0.0, 2.0, (1,); symbol=:slope, logpdf=normal_logpdf + ) + _, intercept = ProbProg.sample( + rng, normal, 0.0, 10.0, (1,); symbol=:intercept, logpdf=normal_logpdf + ) + + _, ys = ProbProg.sample( + rng, + normal, + slope .* xs .+ intercept, + 1.0, + (length(xs),); + symbol=:ys, + logpdf=normal_logpdf, + ) + + return ys +end + +function mh_program(rng, model, xs, num_iters, constraint_ptr, constrained_addresses) + init_trace, _, _ = ProbProg.generate( + rng, + model, + xs; + constraint_ptr=constraint_ptr, + constrained_addresses=constrained_addresses, + ) + + trace_ptr = init_trace + @trace for _ in 1:num_iters + trace_ptr, _ = ProbProg.mh( + rng, trace_ptr, model, xs; selection=ProbProg.select(ProbProg.Address(:slope)) + ) + trace_ptr, _ = ProbProg.mh( + rng, + trace_ptr, + model, + xs; + selection=ProbProg.select(ProbProg.Address(:intercept)), + ) + end + + return trace_ptr +end + +@testset "linear_regression" begin + @testset "simulate" begin + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + + xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + xs_r = Reactant.to_rarray(xs) + + trace, _ = ProbProg.simulate_(rng, model, xs_r) + + @test haskey(trace.choices, :slope) + @test haskey(trace.choices, :intercept) + @test haskey(trace.choices, :ys) + end + + @testset "inference" begin + seed = Reactant.to_rarray(UInt64[1, 5]) + rng = ReactantRNG(seed) + + xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + ys = [8.23, 5.87, 3.99, 2.59, 0.23, -0.66, -3.53, -6.91, -7.24, -9.90] + obs = ProbProg.Constraint(:ys => (ys,)) + num_iters = ConcreteRNumber(10000) + constrained_addresses = ProbProg.extract_addresses(obs) + constraint_ptr = ConcreteRNumber(reinterpret(UInt64, pointer_from_objref(obs))) + + code = @code_hlo optimize = :probprog mh_program( + rng, model, xs, 10000, constraint_ptr, constrained_addresses + ) + @test contains(repr(code), "enzyme_probprog_get_sample_from_trace") + @test contains(repr(code), "enzyme_probprog_get_weight_from_trace") + @test !contains(repr(code), "enzyme.mh") + + compiled_fn = @compile optimize = :probprog mh_program( + rng, model, xs, num_iters, constraint_ptr, constrained_addresses + ) + + trace = nothing + seed_buffer = only(rng.seed.data).buffer + num_iters = ConcreteRNumber(1000) + GC.@preserve seed_buffer obs begin + trace_ptr = compiled_fn( + rng, model, xs, num_iters, constraint_ptr, constrained_addresses + ) + trace = ProbProg.from_trace_tensor(trace_ptr) + end + + slope = only(trace.choices[:slope])[1] + intercept = only(trace.choices[:intercept])[1] + @show slope, intercept + + @test slope ≈ -2.0 rtol = 0.1 + @test intercept ≈ 10.0 rtol = 0.1 + end +end diff --git a/test/probprog/sample.jl b/test/probprog/sample.jl new file mode 100644 index 0000000000..b7889c46dd --- /dev/null +++ b/test/probprog/sample.jl @@ -0,0 +1,88 @@ +using Reactant, Test, Random +using Reactant: ProbProg, ReactantRNG + +normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) + +function one_sample(rng, μ, σ, shape) + _, s = ProbProg.sample(rng, normal, μ, σ, shape) + return s +end + +function two_samples(rng, μ, σ, shape) + _ = ProbProg.sample(rng, normal, μ, σ, shape) + _, t = ProbProg.sample(rng, normal, μ, σ, shape) + return t +end + +function compose(rng, μ, σ, shape) + _, s = ProbProg.sample(rng, normal, μ, σ, shape) + _, t = ProbProg.sample(rng, normal, s, σ, shape) + return t +end + +@testset "test" begin + @testset "normal_hlo" begin + shape = (10,) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + code = @code_hlo optimize = false ProbProg.sample(rng, normal, μ, σ, shape) + @test contains(repr(code), "enzyme.sample") + end + + @testset "two_samples_hlo" begin + shape = (10,) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + code = @code_hlo optimize = false ProbProg.sample(rng, two_samples, μ, σ, shape) + @test contains(repr(code), "enzyme.sample") + end + + @testset "compose" begin + shape = (10,) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + before = @code_hlo optimize = false ProbProg.untraced_call( + rng, compose, μ, σ, shape + ) + @test contains(repr(before), "enzyme.sample") + + after = @code_hlo optimize = :probprog ProbProg.untraced_call( + rng, compose, μ, σ, shape + ) + @test !contains(repr(after), "enzyme.sample") + end + + @testset "rng_state" begin + shape = (10,) + + seed = Reactant.to_rarray(UInt64[1, 4]) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + rng1 = ReactantRNG(copy(seed)) + + _, X = @jit optimize = :probprog ProbProg.untraced_call( + rng1, one_sample, μ, σ, shape + ) + @test !all(rng1.seed .== seed) + + rng2 = ReactantRNG(copy(seed)) + _, Y = @jit optimize = :probprog ProbProg.untraced_call( + rng2, two_samples, μ, σ, shape + ) + + @test !all(rng2.seed .== seed) + @test !all(rng2.seed .== rng1.seed) + + @test !all(X .≈ Y) + end +end diff --git a/test/probprog/simulate.jl b/test/probprog/simulate.jl new file mode 100644 index 0000000000..3be45bc256 --- /dev/null +++ b/test/probprog/simulate.jl @@ -0,0 +1,115 @@ +using Reactant, Test, Random +using Reactant: ProbProg, ReactantRNG + +normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) + +function normal_logpdf(x, μ, σ, _) + return -length(x) * log(σ) - length(x) / 2 * log(2π) - + sum((x .- μ) .^ 2 ./ (2 .* (σ .^ 2))) +end + +function product_two_normals(rng, μ, σ, shape) + _, a = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:a, logpdf=normal_logpdf) + _, b = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:b, logpdf=normal_logpdf) + return a .* b +end + +function model(rng, μ, σ, shape) + _, s = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:s, logpdf=normal_logpdf) + _, t = ProbProg.sample(rng, normal, s, σ, shape; symbol=:t, logpdf=normal_logpdf) + return t +end + +function model2(rng, μ, σ, shape) + _, s = ProbProg.sample(rng, product_two_normals, μ, σ, shape; symbol=:s) + _, t = ProbProg.sample(rng, product_two_normals, s, σ, shape; symbol=:t) + return t +end + +@testset "Simulate" begin + @testset "hlo" begin + shape = (3, 3, 3) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + before = @code_hlo optimize = false ProbProg.simulate(rng, model, μ, σ, shape) + @test contains(repr(before), "enzyme.simulate") + + after = @code_hlo optimize = :probprog ProbProg.simulate(rng, model, μ, σ, shape) + @test !contains(repr(after), "enzyme.simulate") + @test !contains(repr(after), "enzyme.addSampleToTrace") + @test !contains(repr(after), "enzyme.addWeightToTrace") + @test !contains(repr(after), "enzyme.addRetvalToTrace") + end + + @testset "normal_simulate" begin + shape = (3, 3, 3) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + trace, weight = ProbProg.simulate_(rng, model, μ, σ, shape) + + @test size(trace.retval[1]) == shape + @test haskey(trace.choices, :s) + @test haskey(trace.choices, :t) + @test size(trace.choices[:s][1]) == shape + @test size(trace.choices[:t][1]) == shape + @test trace.weight isa Float64 + end + + @testset "simple_fake" begin + op(_, x, y) = x * y' + logpdf(res, _, _) = sum(res) + function fake_model(rng, x, y) + _, res = ProbProg.sample(rng, op, x, y; symbol=:matmul, logpdf=logpdf) + return res + end + + x = reshape(collect(Float64, 1:12), (4, 3)) + y = reshape(collect(Float64, 1:12), (4, 3)) + x_ra = Reactant.to_rarray(x) + y_ra = Reactant.to_rarray(y) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + + trace, weight = ProbProg.simulate_(rng, fake_model, x_ra, y_ra) + + @test Array(trace.retval[1]) == op(rng, x, y) + @test haskey(trace.choices, :matmul) + @test trace.choices[:matmul][1] == op(rng, x, y) + @test trace.weight == logpdf(op(rng, x, y), x, y) + end + + @testset "submodel_fake" begin + shape = (3, 3, 3) + seed = Reactant.to_rarray(UInt64[1, 4]) + rng = ReactantRNG(seed) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + trace, weight = ProbProg.simulate_(rng, model2, μ, σ, shape) + + @test size(trace.retval[1]) == shape + + @test length(trace.choices) == 2 + @test haskey(trace.choices, :s) + @test haskey(trace.choices, :t) + + @test length(trace.subtraces) == 2 + @test haskey(trace.subtraces[:s].choices, :a) + @test haskey(trace.subtraces[:s].choices, :b) + @test haskey(trace.subtraces[:t].choices, :a) + @test haskey(trace.subtraces[:t].choices, :b) + + @test size(trace.choices[:s][1]) == shape + @test size(trace.choices[:t][1]) == shape + + @test trace.weight isa Float64 + + @test trace.weight ≈ trace.subtraces[:s].weight + trace.subtraces[:t].weight + end +end diff --git a/test/runtests.jl b/test/runtests.jl index f812deee5c..0ed50fa776 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -69,4 +69,11 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @safetestset "Lux Integration" include("nn/lux.jl") end end + + if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "probprog" + @safetestset "ProbProg Sample" include("probprog/sample.jl") + @safetestset "ProbProg Simulate" include("probprog/simulate.jl") + @safetestset "ProbProg Generate" include("probprog/generate.jl") + @safetestset "ProbProg HMC" include("probprog/hmc.jl") + end end