diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 7413ab084b..44499a79f5 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -367,6 +367,21 @@ enzymeActivityAttrGet(MlirContext ctx, int32_t val) { (mlir::enzyme::Activity)val)); } +extern "C" MLIR_CAPI_EXPORTED MlirType enzymeTraceTypeGet(MlirContext ctx) { + return wrap(mlir::enzyme::TraceType::get(unwrap(ctx))); +} + +extern "C" MLIR_CAPI_EXPORTED MlirType +enzymeConstraintTypeGet(MlirContext ctx) { + return wrap(mlir::enzyme::ConstraintType::get(unwrap(ctx))); +} + +extern "C" MLIR_CAPI_EXPORTED MlirAttribute +enzymeSymbolAttrGet(MlirContext ctx, uint64_t symbol) { + mlir::Attribute attr = mlir::enzyme::SymbolAttr::get(unwrap(ctx), symbol); + return wrap(attr); +} + // Create profiler session and start profiling extern "C" tsl::ProfilerSession * CreateProfilerSession(uint32_t device_tracer_level, diff --git a/src/CompileOptions.jl b/src/CompileOptions.jl index 30dfda915f..9b01785c11 100644 --- a/src/CompileOptions.jl +++ b/src/CompileOptions.jl @@ -221,6 +221,8 @@ function CompileOptions(; :canonicalize, :just_batch, :none, + :probprog, + :probprog_no_lowering, ] end diff --git a/src/Compiler.jl b/src/Compiler.jl index 7bd66fff17..e1c6313f83 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1231,6 +1231,7 @@ end # TODO we want to be able to run the more advanced passes via transform dialect as an enzyme intermediate # However, this errs as we cannot attach the transform with to the funcop itself [as we run a functionpass]. 
const enzyme_pass::String = "enzyme{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize\"}" +const probprog_pass::String = "probprog{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize\"}" function run_pass_pipeline!(mod, pass_pipeline, key=""; enable_verifier=true) pm = MLIR.IR.PassManager() @@ -1641,6 +1642,7 @@ function compile_mlir!( blas_int_width = sizeof(BLAS.BlasInt) * 8 lower_enzymexla_linalg_pass = "lower-enzymexla-linalg{backend=$backend \ blas_int_width=$blas_int_width}" + lower_enzyme_probprog_pass = "lower-enzyme-probprog{backend=$backend}" legalize_chlo_to_stablehlo = if legalize_stablehlo_to_mhlo || compile_options.legalize_chlo_to_stablehlo @@ -1807,6 +1809,122 @@ function compile_mlir!( ), "no_enzyme", ) + elseif compile_options.optimization_passes === :probprog_no_lowering + run_pass_pipeline!( + mod, + join( + if compile_options.raise_first + [ + "mark-func-memory-effects", + opt_passes, + kern, + raise_passes, + "enzyme-batch", + opt_passes2, + enzyme_pass, + probprog_pass, + opt_passes2, + "canonicalize", + "remove-unnecessary-enzyme-ops", + "enzyme-simplify-math", + ( + if compile_options.legalize_chlo_to_stablehlo + ["func.func(chlo-legalize-to-stablehlo)"] + else + [] + end + )..., + opt_passes2, + ] + else + [ + "mark-func-memory-effects", + opt_passes, + "enzyme-batch", + opt_passes2, + enzyme_pass, + probprog_pass, + opt_passes2, + "canonicalize", + "remove-unnecessary-enzyme-ops", + "enzyme-simplify-math", + ( + if compile_options.legalize_chlo_to_stablehlo + ["func.func(chlo-legalize-to-stablehlo)"] + else + [] + end + )..., + opt_passes2, + kern, + raise_passes, + ] + end, + ",", + ), + "probprog_no_lowering", + ) + elseif compile_options.optimization_passes === :probprog + run_pass_pipeline!( + mod, + join( + if compile_options.raise_first + [ + "mark-func-memory-effects", + opt_passes, + kern, + raise_passes, + 
"enzyme-batch", + opt_passes2, + enzyme_pass, + probprog_pass, + opt_passes2, + "canonicalize", + "remove-unnecessary-enzyme-ops", + "enzyme-simplify-math", + ( + if compile_options.legalize_chlo_to_stablehlo + ["func.func(chlo-legalize-to-stablehlo)"] + else + [] + end + )..., + opt_passes2, + lower_enzymexla_linalg_pass, + lower_enzyme_probprog_pass, + jit, + ] + else + [ + "mark-func-memory-effects", + opt_passes, + "enzyme-batch", + opt_passes2, + enzyme_pass, + probprog_pass, + opt_passes2, + "canonicalize", + "remove-unnecessary-enzyme-ops", + "enzyme-simplify-math", + ( + if compile_options.legalize_chlo_to_stablehlo + ["func.func(chlo-legalize-to-stablehlo)"] + else + [] + end + )..., + opt_passes2, + kern, + raise_passes, + lower_enzymexla_linalg_pass, + lower_enzyme_probprog_pass, + jit, + ] + end, + ",", + ), + "probprog", + ) elseif compile_options.optimization_passes === :only_enzyme run_pass_pipeline!( mod, diff --git a/src/Reactant.jl b/src/Reactant.jl index c3af41cc88..68c9bdc9f9 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -191,6 +191,7 @@ include("Tracing.jl") include("Compiler.jl") include("Overlay.jl") +include("probprog/ProbProg.jl") # Serialization include("serialization/Serialization.jl") diff --git a/src/Types.jl b/src/Types.jl index 48221a431d..4cb153b723 100644 --- a/src/Types.jl +++ b/src/Types.jl @@ -215,6 +215,7 @@ function ConcretePJRTArray( end Base.wait(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = foreach(wait, x.data) +Base.isready(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = all(isready, x.data) XLA.client(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = XLA.client(x.data) function XLA.device(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) x.sharding isa Sharding.NoShardInfo && return XLA.device(only(x.data)) @@ -412,6 +413,7 @@ function ConcreteIFRTArray( end Base.wait(x::Union{ConcreteIFRTArray,ConcreteIFRTNumber}) = wait(x.data) +Base.isready(x::Union{ConcreteIFRTArray,ConcreteIFRTNumber}) = isready(x.data) 
# Reference: https://github.com/probcomp/Gen.jl/blob/91d798f2d2f0c175b1be3dc6daf3a10a8acf5da3/src/choice_map.jl#L104
"""
    _show_pretty(io, trace, pre, vert_bars)

Print `trace` to `io` as a box-drawing tree.

Entries are emitted in this order: `retval` (if set), `weight` (if set), the
choices sorted by symbol, then the subtraces sorted by symbol. Each subtrace is
printed recursively, indented by four more columns.

# Arguments
- `pre`: number of blank columns placed before the branch glyph.
- `vert_bars`: 1-based column indices inside that prefix where a vertical bar is
  drawn, so that branches of enclosing traces continue through nested output.
"""
function _show_pretty(io::IO, trace::ProbProgTrace, pre::Int, vert_bars::Tuple)
    VERT = '\u2502'
    PLUS = '\u251C'
    HORZ = '\u2500'
    LAST = '\u2514'

    padding = fill(' ', pre)
    indent_vert = vcat(padding, Char[VERT, '\n'])
    indent = vcat(padding, Char[PLUS, HORZ, HORZ, ' '])
    indent_last = vcat(padding, Char[LAST, HORZ, HORZ, ' '])

    for i in vert_bars
        indent_vert[i] = VERT
        indent[i] = VERT
        indent_last[i] = VERT
    end

    indent_vert_str = join(indent_vert)
    indent_str = join(indent)
    indent_last_str = join(indent_last)

    sorted_choices = sort(collect(trace.choices); by=first)
    sorted_subtraces = sort(collect(trace.subtraces); by=first)

    # Count *every* entry up front (subtraces included). The "last entry" glyph
    # must only be used for the final line of this level; counting subtraces
    # later made the last choice look like the end of the tree.
    n = length(sorted_choices) + length(sorted_subtraces)
    if trace.retval !== nothing
        n += 1
    end
    if trace.weight !== nothing
        n += 1
    end

    cur = 1

    if trace.retval !== nothing
        print(io, indent_vert_str)
        print(io, (cur == n ? indent_last_str : indent_str) * "retval : $(trace.retval)\n")
        cur += 1
    end

    if trace.weight !== nothing
        print(io, indent_vert_str)
        print(io, (cur == n ? indent_last_str : indent_str) * "weight : $(trace.weight)\n")
        cur += 1
    end

    for (key, value) in sorted_choices
        print(io, indent_vert_str)
        print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key)) : $value\n")
        cur += 1
    end

    for (key, subtrace) in sorted_subtraces
        print(io, indent_vert_str)
        print(io, (cur == n ? indent_last_str : indent_str) * "subtrace on $(repr(key))\n")
        # A non-final subtrace keeps a vertical bar in this level's column so the
        # branch visibly continues past the nested output.
        _show_pretty(
            io, subtrace, pre + 4, cur == n ? (vert_bars...,) : (vert_bars..., pre + 1)
        )
        cur += 1
    end

    return nothing
end
"""
    addSubtrace(trace_ptr_ptr, symbol_ptr_ptr, subtrace_ptr_ptr)

Runtime callback: store the trace referenced by `subtrace_ptr_ptr` in the
`subtraces` dictionary of the trace referenced by `trace_ptr_ptr`, keyed by the
`Symbol` referenced by `symbol_ptr_ptr`. Each argument points at a slot holding
the address of a Julia object. Always returns `nothing`.
"""
function addSubtrace(
    trace_ptr_ptr::Ptr{Ptr{Any}},
    symbol_ptr_ptr::Ptr{Ptr{Any}},
    subtrace_ptr_ptr::Ptr{Ptr{Any}},
)
    # Load the object address out of the slot and turn it back into a reference.
    deref(slot) = unsafe_pointer_to_objref(unsafe_load(slot))

    parent = deref(trace_ptr_ptr)::ProbProgTrace
    key = deref(symbol_ptr_ptr)::Symbol
    child = deref(subtrace_ptr_ptr)::ProbProgTrace

    parent.subtraces[key] = child
    return nothing
end
"""
    getSampleFromConstraint(constraint_ptr_ptr, symbol_ptr_ptr, sample_ptr_array,
                            num_samples_ptr, ndims_array, shape_ptr_array, width_array)

Runtime callback: look up `Address(symbol)` in the referenced `Constraint` and
copy the `i`-th stored value into the `i`-th destination buffer.

For each of the `num_samples` outputs, `width_array[i]` selects the element type
(32 → `Float32`, 64 → `Float64`, 1 → `Bool`), `ndims_array[i]` gives the rank and
`shape_ptr_array[i]` points at the extents. Rank-0 outputs are written with
`unsafe_store!`; others are filled by broadcasting into a wrapped array.

On any problem (missing constraint, unsupported width, element-type mismatch,
shape mismatch) a diagnostic is written with C `printf` and the function returns
early, leaving the remaining buffers untouched. Always returns `nothing`.
"""
function getSampleFromConstraint(
    constraint_ptr_ptr::Ptr{Ptr{Any}},
    symbol_ptr_ptr::Ptr{Ptr{Any}},
    sample_ptr_array::Ptr{Ptr{Any}},
    num_samples_ptr::Ptr{UInt64},
    ndims_array::Ptr{UInt64},
    shape_ptr_array::Ptr{Ptr{UInt64}},
    width_array::Ptr{UInt64},
)
    constraint = unsafe_pointer_to_objref(unsafe_load(constraint_ptr_ptr))::Constraint
    symbol = unsafe_pointer_to_objref(unsafe_load(symbol_ptr_ptr))::Symbol
    num_samples = unsafe_load(num_samples_ptr)
    ndims_vec = unsafe_wrap(Array, ndims_array, num_samples)
    width_vec = unsafe_wrap(Array, width_array, num_samples)
    shape_ptrs = unsafe_wrap(Array, shape_ptr_array, num_samples)
    sample_ptrs = unsafe_wrap(Array, sample_ptr_array, num_samples)

    tostore = get(constraint, Address(symbol), nothing)

    if tostore === nothing
        # `printf` is variadic: arguments after the format go behind `;` so that
        # `@ccall` uses the platform's varargs calling convention.
        @ccall printf(
            "No constraint found for symbol: %s\n"::Cstring; string(symbol)::Cstring
        )::Cint
        return nothing
    end

    for i in 1:num_samples
        ndims = ndims_vec[i]
        width = width_vec[i]
        shape_ptr = shape_ptrs[i]
        sample_ptr = sample_ptrs[i]

        julia_type = if width == 32
            Float32
        elseif width == 64
            Float64
        elseif width == 1
            Bool
        else
            nothing
        end

        if julia_type === nothing
            @ccall printf(
                "Unsupported datatype width: %zd\n"::Cstring; width::Csize_t
            )::Cint
            return nothing
        end

        value = tostore[i]

        if julia_type != eltype(value)
            @ccall printf(
                "Type mismatch in constrained sample: %s != %s\n"::Cstring;
                string(julia_type)::Cstring,
                string(eltype(value))::Cstring,
            )::Cint
            return nothing
        end

        if ndims == 0
            unsafe_store!(Ptr{julia_type}(sample_ptr), value)
        else
            shape = unsafe_wrap(Array, shape_ptr, ndims)
            dest = unsafe_wrap(Array, Ptr{julia_type}(sample_ptr), Tuple(shape))

            dest_size = size(dest)
            value_size = size(value)

            if dest_size != value_size
                if length(dest_size) != length(value_size)
                    @ccall printf(
                        "Shape size mismatch in constrained sample: %zd != %zd\n"::Cstring;
                        length(dest_size)::Csize_t,
                        length(value_size)::Csize_t,
                    )::Cint
                    return nothing
                end
                # The dimension index gets its own name: reusing `i` here would
                # shadow the sample index and make `tostore[i]` read the wrong
                # element.
                for dim in 1:length(dest_size)
                    if dest_size[dim] != value_size[dim]
                        @ccall printf(
                            "Shape mismatch in `%zd`th dimension of constrained sample: %zd != %zd\n"::Cstring;
                            dim::Csize_t,
                            dest_size[dim]::Csize_t,
                            value_size[dim]::Csize_t,
                        )::Cint
                        return nothing
                    end
                end
            end

            dest .= value
        end
    end

    return nothing
end
nothing +end + +function __init__() + init_trace_ptr = @cfunction(initTrace, Cvoid, (Ptr{Ptr{Any}},)) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_init_trace::Cstring, init_trace_ptr::Ptr{Cvoid} + )::Cvoid + + add_sample_to_trace_ptr = @cfunction( + addSampleToTrace, + Cvoid, + ( + Ptr{Ptr{Any}}, + Ptr{Ptr{Any}}, + Ptr{Ptr{Any}}, + Ptr{UInt64}, + Ptr{UInt64}, + Ptr{Ptr{UInt64}}, + Ptr{UInt64}, + ) + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_add_sample_to_trace::Cstring, add_sample_to_trace_ptr::Ptr{Cvoid} + )::Cvoid + + add_subtrace_ptr = @cfunction( + addSubtrace, Cvoid, (Ptr{Ptr{Any}}, Ptr{Ptr{Any}}, Ptr{Ptr{Any}}) + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_add_subtrace::Cstring, add_subtrace_ptr::Ptr{Cvoid} + )::Cvoid + + add_weight_to_trace_ptr = @cfunction(addWeightToTrace, Cvoid, (Ptr{Ptr{Any}}, Ptr{Any})) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_add_weight_to_trace::Cstring, add_weight_to_trace_ptr::Ptr{Cvoid} + )::Cvoid + + add_retval_to_trace_ptr = @cfunction( + addRetvalToTrace, + Cvoid, + ( + Ptr{Ptr{Any}}, + Ptr{Ptr{Any}}, + Ptr{UInt64}, + Ptr{UInt64}, + Ptr{Ptr{UInt64}}, + Ptr{UInt64}, + ), + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_add_retval_to_trace::Cstring, add_retval_to_trace_ptr::Ptr{Cvoid} + )::Cvoid + + get_sample_from_constraint_ptr = @cfunction( + getSampleFromConstraint, + Cvoid, + ( + Ptr{Ptr{Any}}, + Ptr{Ptr{Any}}, + Ptr{Ptr{Any}}, + Ptr{UInt64}, + Ptr{UInt64}, + Ptr{Ptr{UInt64}}, + Ptr{UInt64}, + ) + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_get_sample_from_constraint::Cstring, + get_sample_from_constraint_ptr::Ptr{Cvoid}, + )::Cvoid + + get_subconstraint_ptr = @cfunction( + getSubconstraint, Cvoid, (Ptr{Ptr{Any}}, Ptr{Ptr{Any}}, Ptr{Ptr{Any}}) + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_get_subconstraint::Cstring, get_subconstraint_ptr::Ptr{Cvoid} + )::Cvoid + + return 
"""
    metropolis_hastings(trace, sel; compiled_cache=nothing)

Run one Metropolis–Hastings step on `trace`.

Every choice of `trace` whose symbol is *not* in `sel` is turned into a
constraint; the model recorded in `trace.fn` is then re-run through
`generate_internal` with `trace.rng` and `trace.args`. The proposal is accepted
when `log(rand()) < new_trace.weight - trace.weight`.

# Keywords
- `compiled_cache`: optional `CompiledFnCache`. Compiled functions are looked up
  and stored under `(typeof(trace.fn), constrained_addresses)`.

# Returns
`(new_trace, true)` when the proposal is accepted, `(trace, false)` otherwise.

# Throws
`ErrorException` when `trace` lacks `fn`, `rng`, `args` or `weight`.
"""
function metropolis_hastings(
    trace::ProbProgTrace,
    sel::Selection;
    compiled_cache::Union{Nothing,CompiledFnCache}=nothing,
)
    if trace.fn === nothing || trace.rng === nothing
        error("MH requires a trace with fn and rng recorded (use generate to create trace)")
    end
    # `args` is splatted and `weight` is subtracted below; fail early with a
    # readable message instead of a MethodError on `nothing`.
    if trace.args === nothing || trace.weight === nothing
        error(
            "MH requires a trace with args and weight recorded (use generate to create trace)",
        )
    end

    # Everything outside the selection stays pinned to its current value.
    constraint_pairs = Pair{Symbol,Any}[
        sym => val for (sym, val) in trace.choices if !(sym in sel)
    ]
    constraint = Constraint(constraint_pairs...)

    constrained_addresses = extract_addresses(constraint)
    cache_key = (typeof(trace.fn), constrained_addresses)

    # The constraint's object address, built once and used for both compilation
    # and execution.
    constraint_ptr = ConcreteRNumber(reinterpret(UInt64, pointer_from_objref(constraint)))

    compiled_fn = if compiled_cache === nothing
        nothing
    else
        get(compiled_cache, cache_key, nothing)
    end

    if compiled_fn === nothing
        function wrapper_fn(rng, constraint_ptr, args...)
            return generate_internal(
                rng, trace.fn, args...; constraint_ptr, constrained_addresses
            )
        end

        compiled_fn = @compile optimize = :probprog wrapper_fn(
            trace.rng, constraint_ptr, trace.args...
        )

        if compiled_cache !== nothing
            compiled_cache[cache_key] = compiled_fn
        end
    end

    # NOTE(review): GC is switched off for the duration of the call, presumably so
    # that `constraint` (passed only as a raw address) cannot be collected while
    # the compiled code runs; confirm whether `GC.@preserve`, as used by
    # `generate`, would be sufficient here.
    old_gc_state = GC.enable(false)
    new_trace_ptr = nothing
    try
        new_trace_ptr, _, _ = compiled_fn(trace.rng, constraint_ptr, trace.args...)
    finally
        GC.enable(old_gc_state)
    end

    # The compiled function returns the address of the trace object as a UInt64.
    new_trace = unsafe_pointer_to_objref(Ptr{Any}(Array(new_trace_ptr)[1]))

    new_trace.fn = trace.fn
    new_trace.args = trace.args
    new_trace.rng = trace.rng

    log_alpha = new_trace.weight - trace.weight

    if log(rand()) < log_alpha
        return (new_trace, true)
    else
        return (trace, false)
    end
end
"""
    sample(rng, f, args...; symbol=gensym("sample"), logpdf=nothing)

Emit a sample operation for `f(rng, args...)` via `sample_internal` and return
its outputs without the leading entry of the internal result. A single output is
returned bare; several outputs are returned as the sliced collection.
"""
function sample(
    rng::AbstractRNG,
    f::Function,
    args::Vararg{Any,Nargs};
    symbol::Symbol=gensym("sample"),
    logpdf::Union{Nothing,Function}=nothing,
) where {Nargs}
    traced = sample_internal(rng, f, args...; symbol=symbol, logpdf=logpdf)

    # Skip the first slot of the internal result; only the remaining entries are
    # handed back to the caller.
    outputs = traced[2:end]

    if length(outputs) == 1
        return outputs[1]
    end
    return outputs
end
"""
    call(rng, f, args...)

JIT-compile and run `call_internal(rng, f, args...)` with the `:probprog`
pipeline. The first entry of the result is dropped, every remaining
`AbstractConcreteArray` is converted with `Array`, and a single remaining value
is returned bare.
"""
function call(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where {Nargs}
    raw = @jit optimize = :probprog call_internal(rng, f, args...)

    # Concrete device arrays are materialised as plain `Array`s; anything else is
    # passed through untouched.
    to_host(r) = r isa AbstractConcreteArray ? Array(r) : r
    outputs = map(to_host, raw[2:end])

    if length(outputs) == 1
        return outputs[1]
    end
    return outputs
end
"""
    simulate(rng, f, args...)

Compile `simulate_internal(rng, f, args...)` with the `:probprog` pipeline, run
it, and recover the resulting `ProbProgTrace` from the object address returned
by the compiled function. The trace is annotated with `rng`, `f` and `args`
before being returned.

Returns the tuple `(trace, trace.weight)`.
"""
function simulate(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where {Nargs}
    compiled = @compile optimize = :probprog simulate_internal(rng, f, args...)

    seed_buffer = only(rng.seed.data).buffer
    # Keep the seed buffer rooted until the result reports ready.
    trace_handle = GC.@preserve seed_buffer begin
        handle, _, _ = compiled(rng, f, args...)
        while !isready(handle)
            yield()
        end
        handle
    end

    # The handle holds the address of the trace object as an integer.
    trace_addr = Array(trace_handle)[1]
    recovered = unsafe_pointer_to_objref(Ptr{Any}(trace_addr))

    recovered.rng = rng
    recovered.fn = f
    recovered.args = args

    return recovered, recovered.weight
end
"""
    generate(rng, f, args...; constraint=Constraint())

Compile and run `f` under `constraint` with the `:probprog` pipeline. The
constraint is passed to the compiled code as its object address, and the set of
constrained addresses is passed to `generate_internal` at trace time. The
resulting `ProbProgTrace` is recovered from the returned object address and
annotated with `rng`, `f` and `args`.

Returns the tuple `(trace, trace.weight)`.
"""
function generate(
    rng::AbstractRNG,
    f::Function,
    args::Vararg{Any,Nargs};
    constraint::Constraint=Constraint(),
) where {Nargs}
    constrained_addresses = extract_addresses(constraint)
    constraint_addr = reinterpret(UInt64, pointer_from_objref(constraint))
    constraint_ptr = ConcreteRNumber(constraint_addr)

    function wrapper_fn(rng, constraint_ptr, args...)
        return generate_internal(rng, f, args...; constraint_ptr, constrained_addresses)
    end

    compiled = @compile optimize = :probprog wrapper_fn(rng, constraint_ptr, args...)

    seed_buffer = only(rng.seed.data).buffer
    # Keep the seed buffer and the constraint rooted: the compiled code only sees
    # the constraint through its raw address.
    trace_handle = GC.@preserve seed_buffer constraint begin
        handle, _, _ = compiled(rng, constraint_ptr, args...)
        while !isready(handle)
            yield()
        end
        handle
    end

    trace_addr = Array(trace_handle)[1]
    recovered = unsafe_pointer_to_objref(Ptr{Any}(trace_addr))

    recovered.rng = rng
    recovered.fn = f
    recovered.args = args

    return recovered, recovered.weight
end
# Record of one execution of a probabilistic program.
#
# The type is mutable: its fields are filled in after construction by the runtime
# callbacks (`addSampleToTrace`, `addSubtrace`, `addWeightToTrace`,
# `addRetvalToTrace`) and by `simulate` / `generate` / `metropolis_hastings`, and
# its address is taken with `pointer_from_objref` in `initTrace`.
mutable struct ProbProgTrace
    # Sampled values per sample symbol; `addSampleToTrace` stores a tuple with one
    # entry per output of the sample site.
    choices::Dict{Symbol,Any}
    # Tuple of return values written by `addRetvalToTrace`; `nothing` until set.
    retval::Any
    # Written by `addWeightToTrace` (loaded as a `Float64`); `nothing` until set.
    weight::Any
    # Nested traces keyed by symbol, written by `addSubtrace`.
    subtraces::Dict{Symbol,Any}
    # The RNG, model function and arguments the trace was produced with; set by
    # `simulate` / `generate` / `metropolis_hastings`, `nothing` otherwise.
    rng::Union{Nothing,AbstractRNG}
    fn::Union{Nothing,Function}
    args::Union{Nothing,Tuple}

    # Create an empty trace: no choices, no subtraces, every other field `nothing`.
    function ProbProgTrace()
        return new(
            Dict{Symbol,Any}(),
            nothing,
            nothing,
            Dict{Symbol,Any}(),
            nothing,
            nothing,
            nothing,
        )
    end
end
"""
    Constraint(pairs::Pair...)
    Constraint()
    Constraint(d::Dict{Address,Any})

Dictionary-like mapping from `Address` to a constrained value.

With pairs, each argument is unfolded along its chain of `Symbol` keys:
`:a => :b => v` stores `v` under `Address(:a, :b)`, and `:a => v` stores `v`
under `Address(:a)`.

The type is kept `mutable` because `generate` and `metropolis_hastings` take its
address with `pointer_from_objref`, which Julia only supports for mutable
objects.

# Throws
`ArgumentError` if the first key of a pair is not a `Symbol`.
"""
mutable struct Constraint <: AbstractDict{Address,Any}
    dict::Dict{Address,Any}

    function Constraint(pairs::Pair...)
        dict = Dict{Address,Any}()
        for pair in pairs
            path = Symbol[]
            current = pair
            # Peel off leading Symbol keys; whatever remains is the value.
            while current isa Pair && current.first isa Symbol
                push!(path, current.first)
                current = current.second
            end
            if isempty(path)
                throw(
                    ArgumentError(
                        "Constraint keys must be Symbols, got $(repr(pair.first))"
                    ),
                )
            end
            # `path` is already a `Vector{Symbol}`; no need to splat and re-collect.
            dict[Address(path)] = current
        end
        return new(dict)
    end

    Constraint() = new(Dict{Address,Any}())
    Constraint(d::Dict{Address,Any}) = new(d)
end
= Set(syms)
+
+function with_compiled_cache(f)  # call `f` with a fresh, empty CompiledFnCache and return f's result
+    cache = CompiledFnCache()
+    return f(cache)
+end
diff --git a/test/probprog/blr.jl b/test/probprog/blr.jl
new file mode 100644
index 0000000000..1dafcce76c
--- /dev/null
+++ b/test/probprog/blr.jl
@@ -0,0 +1,44 @@
+using Reactant, Test, Random
+using Reactant: ProbProg, ReactantRNG
+
+normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape)  # randn draws of `shape`, scaled by σ and shifted by μ
+
+function normal_logpdf(x, μ, σ, _)  # summed normal log-density of x; the 4th argument is ignored
+    return -sum(log.(σ)) - length(x) / 2 * log(2π) - sum((x .- μ) .^ 2 ./ (2 .* (σ .^ 2)))
+end  # NOTE(review): for scalar σ, sum(log.(σ)) counts log(σ) once, not once per element of x — confirm intended
+
+bernoulli_logit(rng, logit, shape) = rand(rng, shape...) .< (1 ./ (1 .+ exp.(-logit)))  # rand draw compared against sigmoid(logit)
+bernoulli_logit_logpdf(x, logit, _) = sum(x .* logit .- log1p.(exp.(logit)))  # summed x*logit - log(1 + exp(logit)); 3rd argument ignored
+
+# https://github.com/facebookresearch/pplbench/blob/main/pplbench/models/logistic_regression.py
+function blr(rng, N, K)
+    # α ~ Normal(0, 10, size = 1)
+    α = ProbProg.sample(rng, normal, 0, 10, (1,); symbol=:α, logpdf=normal_logpdf)
+
+    # β ~ Normal(0, 2.5, size = K)
+    β = ProbProg.sample(rng, normal, 0, 2.5, (K,); symbol=:β, logpdf=normal_logpdf)
+
+    # X ~ Normal(0, 10, size = (N, K))
+    X = ProbProg.sample(rng, normal, 0, 10, (N, K); symbol=:X, logpdf=normal_logpdf)
+
+    # μ = α .+ X * β
+    μ = α .+ X * β
+
+    Y = ProbProg.sample(
+        rng, bernoulli_logit, μ, (N,); symbol=:Y, logpdf=bernoulli_logit_logpdf
+    )
+
+    return Y
+end
+
+@testset "BLR" begin
+    N = 5 # number of observations
+    K = 3 # number of features
+    seed = Reactant.to_rarray(UInt64[1, 4])
+
+    rng = ReactantRNG(seed)
+
+    trace, _ = ProbProg.simulate(rng, blr, N, K)  # only the trace is used; the second result is discarded
+
+    @test size(trace.retval[1]) == (N,)  # first element of retval should have one entry per observation
+end
diff --git a/test/probprog/generate.jl b/test/probprog/generate.jl
new file mode 100644
index 0000000000..2bae4df530
--- /dev/null
+++ b/test/probprog/generate.jl
@@ -0,0 +1,149 @@
+using Reactant, Test, Random, Statistics
+using Reactant: ProbProg, ReactantRNG
+
+normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape)  # randn draws of `shape`, scaled by σ and shifted by μ
+
+function normal_logpdf(x, μ, σ, _)  # summed normal log-density of x; the 4th argument is ignored
+    return -sum(log.(σ)) - length(x) / 2 * log(2π) - sum((x .- μ) .^
2 ./ (2 .* (σ .^ 2)))
+end
+
+function model(rng, μ, σ, shape)  # two chained normal draws: :t is sampled with :s as its mean
+    s = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:s, logpdf=normal_logpdf)
+    t = ProbProg.sample(rng, normal, s, σ, shape; symbol=:t, logpdf=normal_logpdf)
+    return t
+end
+
+function two_normals(rng, μ, σ, shape)  # same chain as `model`, with addresses :x and :y
+    x = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:x, logpdf=normal_logpdf)
+    y = ProbProg.sample(rng, normal, x, σ, shape; symbol=:y, logpdf=normal_logpdf)
+    return y
+end
+
+function nested_model(rng, μ, σ, shape)  # :t and :u sample the `two_normals` submodel; no logpdf is passed for them
+    s = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:s, logpdf=normal_logpdf)
+    t = ProbProg.sample(rng, two_normals, s, σ, shape; symbol=:t)
+    u = ProbProg.sample(rng, two_normals, t, σ, shape; symbol=:u)
+    return u
+end
+
+@testset "Generate" begin
+    @testset "unconstrained" begin
+        shape = (1000,)
+        seed = Reactant.to_rarray(UInt64[1, 4])
+        rng = ReactantRNG(seed)
+        μ = Reactant.ConcreteRNumber(0.0)
+        σ = Reactant.ConcreteRNumber(1.0)
+        trace, weight = ProbProg.generate(rng, model, μ, σ, shape)  # no `constraint` keyword given
+        @test mean(trace.retval[1]) ≈ 0.0 atol = 0.05 rtol = 0.05
+    end
+
+    @testset "constrained" begin
+        shape = (10,)
+        seed = Reactant.to_rarray(UInt64[1, 4])
+        rng = ReactantRNG(seed)
+        μ = Reactant.ConcreteRNumber(0.0)
+        σ = Reactant.ConcreteRNumber(1.0)
+
+        constraint = ProbProg.Constraint(:s => (fill(0.1, shape),))  # constrained value is wrapped in a 1-tuple
+
+        trace, weight = ProbProg.generate(rng, model, μ, σ, shape; constraint)
+
+        @test trace.choices[:s][1] == constraint[ProbProg.Address(:s)][1]  # the constrained value must appear unchanged in the trace
+
+        expected_weight =
+            normal_logpdf(constraint[ProbProg.Address(:s)][1], 0.0, 1.0, shape) +
+            normal_logpdf(
+                trace.choices[:t][1], constraint[ProbProg.Address(:s)][1], 1.0, shape
+            )
+        @test weight ≈ expected_weight atol = 1e-6
+    end
+
+    @testset "composite addresses" begin
+        shape = (10,)
+        seed = Reactant.to_rarray(UInt64[1, 4])
+        rng = ReactantRNG(seed)
+        μ = Reactant.ConcreteRNumber(0.0)
+        σ = Reactant.ConcreteRNumber(1.0)
+
+        constraint = ProbProg.Constraint(  # nested pairs address choices inside the :t and :u submodels
+            :s => (fill(0.1, shape),),
+            :t => :x => (fill(0.2, shape),),
+            :u => :y => (fill(0.3, shape),),
+        )
+
+        trace, weight = ProbProg.generate(rng, nested_model, μ, σ, shape; constraint)
+
+        @test trace.choices[:s][1] == fill(0.1, shape)
+        @test trace.subtraces[:t].choices[:x][1] == fill(0.2, shape)
+        @test trace.subtraces[:u].choices[:y][1] == fill(0.3, shape)
+
+        s_weight = normal_logpdf(fill(0.1, shape), 0.0, 1.0, shape)
+        tx_weight = normal_logpdf(fill(0.2, shape), fill(0.1, shape), 1.0, shape)
+        ty_weight = normal_logpdf(
+            trace.subtraces[:t].choices[:y][1], fill(0.2, shape), 1.0, shape
+        )
+        ux_weight = normal_logpdf(
+            trace.subtraces[:u].choices[:x][1],
+            trace.subtraces[:t].choices[:y][1],
+            1.0,
+            shape,
+        )
+        uy_weight = normal_logpdf(
+            fill(0.3, shape), trace.subtraces[:u].choices[:x][1], 1.0, shape
+        )
+
+        expected_weight = s_weight + tx_weight + ty_weight + ux_weight + uy_weight  # one term per sampled address (:s, :t/:x, :t/:y, :u/:x, :u/:y)
+        @test weight ≈ expected_weight atol = 1e-6
+    end
+
+    @testset "compiled" begin
+        shape = (10,)
+        seed = Reactant.to_rarray(UInt64[1, 4])
+        rng = ReactantRNG(seed)
+        μ = Reactant.ConcreteRNumber(0.0)
+        σ = Reactant.ConcreteRNumber(1.0)
+
+        constraint1 = ProbProg.Constraint(:s => (fill(0.1, shape),))
+
+        constrained_addresses = ProbProg.extract_addresses(constraint1)
+
+        constraint_ptr1 = Reactant.ConcreteRNumber(  # address of the Constraint object, carried as a UInt64
+            reinterpret(UInt64, pointer_from_objref(constraint1))
+        )
+
+        wrapper_fn(rng, constraint_ptr, μ, σ) = ProbProg.generate_internal(
+            rng, model, μ, σ, shape; constraint_ptr, constrained_addresses
+        )
+
+        compiled_fn = @compile optimize = :probprog wrapper_fn(rng, constraint_ptr1, μ, σ)  # compiled once, then reused below with a second constraint
+
+        trace1 = nothing
+        seed_buffer = only(rng.seed.data).buffer
+        GC.@preserve seed_buffer constraint1 begin  # keep both objects rooted for the duration of the call and the wait
+            trace1, _ = compiled_fn(rng, constraint_ptr1, μ, σ)
+
+            while !isready(trace1)  # yield to other tasks until the result reports ready
+                yield()
+            end
+        end
+        trace1 = unsafe_pointer_to_objref(Ptr{Any}(Array(trace1)[1]))  # the returned UInt64 is treated as an object pointer and turned back into the Julia object
+
+        constraint2 = ProbProg.Constraint(:s => (fill(0.2, shape),))
+        constraint_ptr2 = Reactant.ConcreteRNumber(
+            reinterpret(UInt64, pointer_from_objref(constraint2))
+        )
+
+        trace2 = nothing
+        seed_buffer = only(rng.seed.data).buffer
+        GC.@preserve seed_buffer constraint2 begin
+            trace2, _ = compiled_fn(rng, constraint_ptr2, μ, σ)
+
+            while !isready(trace2)
+                yield()
+            end
+        end
+        trace2 = unsafe_pointer_to_objref(Ptr{Any}(Array(trace2)[1]))
+
+        @test trace1.choices[:s][1] != trace2.choices[:s][1]  # same compiled function, different constrained values for :s
+    end
+end
diff --git a/test/probprog/linear_regression.jl b/test/probprog/linear_regression.jl
new file mode 100644
index 0000000000..1a4ff0b181
--- /dev/null
+++ b/test/probprog/linear_regression.jl
@@ -0,0 +1,87 @@
+using Reactant, Test, Random
+using Reactant: ProbProg, ReactantRNG
+
+# Reference: https://www.gen.dev/docs/stable/getting_started/linear_regression/
+
+normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape)  # randn draws of `shape`, scaled by σ and shifted by μ
+
+function normal_logpdf(x, μ, σ, _)  # summed normal log-density of x; the 4th argument is ignored
+    return -sum(log.(σ)) - length(x) / 2 * log(2π) - sum((x .- μ) .^ 2 ./ (2 .* (σ .^ 2)))
+end  # NOTE(review): for scalar σ, sum(log.(σ)) counts log(σ) once, not once per element of x — confirm intended
+
+function my_model(rng, xs)  # samples :slope and :intercept, then :ys around slope .* xs .+ intercept
+    slope = ProbProg.sample(
+        rng, normal, 0.0, 2.0, (1,); symbol=:slope, logpdf=normal_logpdf
+    )
+    intercept = ProbProg.sample(
+        rng, normal, 0.0, 10.0, (1,); symbol=:intercept, logpdf=normal_logpdf
+    )
+
+    ys = ProbProg.sample(
+        rng,
+        normal,
+        slope .* xs .+ intercept,
+        1.0,
+        (length(xs),);
+        symbol=:ys,
+        logpdf=normal_logpdf,
+    )
+
+    return ys
+end
+
+function my_inference_program(xs, ys, num_iters)  # returns the final (slope, intercept) values as plain numbers
+    xs_r = Reactant.to_rarray(xs)
+
+    observations = ProbProg.Constraint(:ys => (ys,))  # :ys is constrained to the observed data
+
+    seed = Reactant.to_rarray(UInt64[1, 4])
+    rng = ReactantRNG(seed)
+
+    trace, _ = ProbProg.generate(rng, my_model, xs_r; constraint=observations)
+
+    trace = ProbProg.with_compiled_cache() do cache  # one cache is shared by every metropolis_hastings call below
+        local t = trace
+        for _ in 1:num_iters  # each iteration updates :slope, then :intercept
+            t, _ = ProbProg.metropolis_hastings(
+                t, ProbProg.select(:slope); compiled_cache=cache
+            )
+            t, _ = ProbProg.metropolis_hastings(
+                t, ProbProg.select(:intercept); compiled_cache=cache
+            )
+        end
+        return t
+    end
+
+    choices = ProbProg.get_choices(trace)
+    return (Array(choices[:slope][1])[1],
Array(choices[:intercept][1])[1])
+end
+
+@testset "linear_regression" begin
+    @testset "simulate" begin
+        seed = Reactant.to_rarray(UInt64[1, 4])
+        rng = ReactantRNG(seed)
+
+        xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
+        xs_r = Reactant.to_rarray(xs)
+
+        trace, _ = ProbProg.simulate(rng, my_model, xs_r)
+
+        @test haskey(trace.choices, :slope)  # every address sampled by my_model must be recorded
+        @test haskey(trace.choices, :intercept)
+        @test haskey(trace.choices, :ys)
+    end
+
+    @testset "inference" begin
+        Random.seed!(1) # For Julia side RNG
+        xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
+        ys = [8.23, 5.87, 3.99, 2.59, 0.23, -0.66, -3.53, -6.91, -7.24, -9.90]
+
+        slope, intercept = my_inference_program(xs, ys, 10000)  # 10000 iterations, two metropolis_hastings calls each
+
+        @show slope, intercept
+
+        @test slope ≈ -2.0 rtol = 0.05
+        @test intercept ≈ 10.0 rtol = 0.05
+    end
+end
diff --git a/test/probprog/sample.jl b/test/probprog/sample.jl
new file mode 100644
index 0000000000..28a4ab6ee9
--- /dev/null
+++ b/test/probprog/sample.jl
@@ -0,0 +1,80 @@
+using Reactant, Test, Random
+using Reactant: ProbProg, ReactantRNG
+
+normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape)  # randn draws of `shape`, scaled by σ and shifted by μ
+
+function one_sample(rng, μ, σ, shape)  # a single sample call, no symbol or logpdf
+    s = ProbProg.sample(rng, normal, μ, σ, shape)
+    return s
+end
+
+function two_samples(rng, μ, σ, shape)  # two sample calls with identical arguments; only the second result is returned
+    _ = ProbProg.sample(rng, normal, μ, σ, shape)
+    t = ProbProg.sample(rng, normal, μ, σ, shape)
+    return t
+end
+
+function compose(rng, μ, σ, shape)  # the second draw uses the first draw as its mean
+    s = ProbProg.sample(rng, normal, μ, σ, shape)
+    t = ProbProg.sample(rng, normal, s, σ, shape)
+    return t
+end
+
+@testset "test" begin  # NOTE(review): generic testset name; the cases below exercise ProbProg.sample and ProbProg.call
+    @testset "normal_hlo" begin
+        shape = (10,)
+        seed = Reactant.to_rarray(UInt64[1, 4])
+        rng = ReactantRNG(seed)
+        μ = Reactant.ConcreteRNumber(0.0)
+        σ = Reactant.ConcreteRNumber(1.0)
+
+        code = @code_hlo optimize = false ProbProg.sample(rng, normal, μ, σ, shape)
+        @test contains(repr(code), "enzyme.sample")  # the unoptimized IR should still contain the sample op
+    end
+
+    @testset "two_samples_hlo" begin
+        shape = (10,)
+        seed = Reactant.to_rarray(UInt64[1, 4])
+        rng = ReactantRNG(seed)
+        μ = Reactant.ConcreteRNumber(0.0)
+        σ = Reactant.ConcreteRNumber(1.0)
+
+        code = @code_hlo optimize = false ProbProg.sample(rng, two_samples, μ, σ, shape)
+        @test contains(repr(code), "enzyme.sample")
+    end
+
+    @testset "compose" begin
+        shape = (10,)
+        seed = Reactant.to_rarray(UInt64[1, 4])
+        rng = ReactantRNG(seed)
+        μ = Reactant.ConcreteRNumber(0.0)
+        σ = Reactant.ConcreteRNumber(1.0)
+
+        before = @code_hlo optimize = false ProbProg.call(rng, compose, μ, σ, shape)
+        @test contains(repr(before), "enzyme.sample")
+
+        after = @code_hlo optimize = :probprog ProbProg.call(rng, compose, μ, σ, shape)
+        @test !contains(repr(after), "enzyme.sample")  # the :probprog pipeline must leave no sample op behind
+    end
+
+    @testset "rng_state" begin
+        shape = (10,)
+
+        seed = Reactant.to_rarray(UInt64[1, 4])
+        μ = Reactant.ConcreteRNumber(0.0)
+        σ = Reactant.ConcreteRNumber(1.0)
+
+        rng1 = ReactantRNG(copy(seed))  # each RNG gets its own copy so `seed` stays available for comparison
+
+        X = ProbProg.call(rng1, one_sample, μ, σ, shape)
+        @test !all(rng1.seed .== seed)  # the call must have changed the RNG's seed
+
+        rng2 = ReactantRNG(copy(seed))
+        Y = ProbProg.call(rng2, two_samples, μ, σ, shape)
+
+        @test !all(rng2.seed .== seed)
+        @test !all(rng2.seed .== rng1.seed)  # two draws must leave a different seed than one draw
+
+        @test !all(X .≈ Y)
+    end
+end
diff --git a/test/probprog/simulate.jl b/test/probprog/simulate.jl
new file mode 100644
index 0000000000..0a5a870b14
--- /dev/null
+++ b/test/probprog/simulate.jl
@@ -0,0 +1,125 @@
+using Reactant, Test, Random
+using Reactant: ProbProg, ReactantRNG
+
+normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape)  # randn draws of `shape`, scaled by σ and shifted by μ
+
+function normal_logpdf(x, μ, σ, _)  # summed normal log-density of x; the 4th argument is ignored
+    return -sum(log.(σ)) - length(x) / 2 * log(2π) - sum((x .- μ) .^ 2 ./ (2 .* (σ .^ 2)))
+end  # NOTE(review): for scalar σ, sum(log.(σ)) counts log(σ) once, not once per element of x — confirm intended
+
+function product_two_normals(rng, μ, σ, shape)  # elementwise product of two draws recorded at :a and :b
+    a = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:a, logpdf=normal_logpdf)
+    b = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:b, logpdf=normal_logpdf)
+    return a .* b
+end
+
+function model(rng, μ, σ, shape)  # two chained normal draws: :t is sampled with :s as its mean
+    s = ProbProg.sample(rng, normal, μ, σ, shape; symbol=:s, logpdf=normal_logpdf)
+    t = ProbProg.sample(rng, normal, s, σ, shape; symbol=:t,
logpdf=normal_logpdf)
+    return t
+end
+
+function model2(rng, μ, σ, shape)  # both addresses sample the `product_two_normals` submodel; no logpdf is passed
+    s = ProbProg.sample(rng, product_two_normals, μ, σ, shape; symbol=:s)
+    t = ProbProg.sample(rng, product_two_normals, s, σ, shape; symbol=:t)
+    return t
+end
+
+@testset "Simulate" begin
+    @testset "hlo" begin
+        shape = (3, 3, 3)
+        seed = Reactant.to_rarray(UInt64[1, 4])
+        rng = ReactantRNG(seed)
+        μ = Reactant.ConcreteRNumber(0.0)
+        σ = Reactant.ConcreteRNumber(1.0)
+
+        before = @code_hlo optimize = false ProbProg.simulate_internal(
+            rng, model, μ, σ, shape
+        )
+        @test contains(repr(before), "enzyme.simulate")
+
+        unlowered = @code_hlo optimize = :probprog_no_lowering ProbProg.simulate_internal(
+            rng, model, μ, σ, shape
+        )
+        @test !contains(repr(unlowered), "enzyme.simulate")  # simulate op is gone, but the trace ops below are still present
+        @test contains(repr(unlowered), "enzyme.addSampleToTrace")
+        @test contains(repr(unlowered), "enzyme.addWeightToTrace")
+        @test contains(repr(unlowered), "enzyme.addRetvalToTrace")
+
+        after = @code_hlo optimize = :probprog ProbProg.simulate_internal(
+            rng, model, μ, σ, shape
+        )
+        @test !contains(repr(after), "enzyme.simulate")  # the full :probprog pipeline must remove the trace ops as well
+        @test !contains(repr(after), "enzyme.addSampleToTrace")
+        @test !contains(repr(after), "enzyme.addWeightToTrace")
+        @test !contains(repr(after), "enzyme.addRetvalToTrace")
+    end
+
+    @testset "normal_simulate" begin
+        shape = (3, 3, 3)
+        seed = Reactant.to_rarray(UInt64[1, 4])
+        rng = ReactantRNG(seed)
+        μ = Reactant.ConcreteRNumber(0.0)
+        σ = Reactant.ConcreteRNumber(1.0)
+
+        trace, weight = ProbProg.simulate(rng, model, μ, σ, shape)
+
+        @test size(trace.retval[1]) == shape
+        @test haskey(trace.choices, :s)
+        @test haskey(trace.choices, :t)
+        @test size(trace.choices[:s][1]) == shape
+        @test size(trace.choices[:t][1]) == shape
+        @test trace.weight isa Float64
+    end
+
+    @testset "simple_fake" begin
+        op(_, x, y) = x * y'  # deterministic stand-in for a sampler: ignores its first (rng) argument
+        logpdf(res, _, _) = sum(res)  # stand-in logpdf: depends only on the sampled result
+        function fake_model(rng, x, y)
+            return ProbProg.sample(rng, op, x, y; symbol=:matmul, logpdf=logpdf)
+        end
+
+        x = reshape(collect(Float64, 1:12), (4, 3))
+        y = reshape(collect(Float64, 1:12), (4, 3))
+        x_ra = Reactant.to_rarray(x)
+        y_ra = Reactant.to_rarray(y)
+        seed = Reactant.to_rarray(UInt64[1, 4])
+        rng = ReactantRNG(seed)
+
+        trace, weight = ProbProg.simulate(rng, fake_model, x_ra, y_ra)
+
+        @test Array(trace.retval[1]) == op(rng, x, y)  # compared against the same op evaluated on the plain arrays
+        @test haskey(trace.choices, :matmul)
+        @test trace.choices[:matmul][1] == op(rng, x, y)
+        @test trace.weight == logpdf(op(rng, x, y), x, y)
+    end
+
+    @testset "submodel_fake" begin
+        shape = (3, 3, 3)
+        seed = Reactant.to_rarray(UInt64[1, 4])
+        rng = ReactantRNG(seed)
+        μ = Reactant.ConcreteRNumber(0.0)
+        σ = Reactant.ConcreteRNumber(1.0)
+
+        trace, weight = ProbProg.simulate(rng, model2, μ, σ, shape)
+
+        @test size(trace.retval[1]) == shape
+
+        @test length(trace.choices) == 2
+        @test haskey(trace.choices, :s)
+        @test haskey(trace.choices, :t)
+
+        @test length(trace.subtraces) == 2  # one subtrace per submodel address
+        @test haskey(trace.subtraces[:s].choices, :a)
+        @test haskey(trace.subtraces[:s].choices, :b)
+        @test haskey(trace.subtraces[:t].choices, :a)
+        @test haskey(trace.subtraces[:t].choices, :b)
+
+        @test size(trace.choices[:s][1]) == shape
+        @test size(trace.choices[:t][1]) == shape
+
+        @test trace.weight isa Float64
+
+        @test trace.weight ≈ trace.subtraces[:s].weight + trace.subtraces[:t].weight  # top-level weight should equal the sum of the two subtrace weights
+    end
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 411cf443ea..e7998129f7 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -60,4 +60,12 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
             @safetestset "Lux Integration" include("nn/lux.jl")
         end
     end
+
+    if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "probprog"  # runs with the full suite or when the "probprog" group is selected
+        @safetestset "ProbProg Sample" include("probprog/sample.jl")
+        @safetestset "ProbProg BLR" include("probprog/blr.jl")
+        @safetestset "ProbProg Simulate" include("probprog/simulate.jl")
+        @safetestset "ProbProg Generate" include("probprog/generate.jl")
+        @safetestset "ProbProg Linear Regression" include("probprog/linear_regression.jl")
+    end
 end