Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
babc7d0
static HMC
sbrantq Oct 26, 2025
a3f69d2
test
sbrantq Oct 26, 2025
a0a47ff
refactored with callcache
sbrantq Oct 30, 2025
292635e
automated tracing of ProbProgTrace and Constraint structs
sbrantq Oct 31, 2025
4de87c9
clean up
sbrantq Oct 31, 2025
865f76f
format
sbrantq Oct 31, 2025
5c5ce42
clean up
sbrantq Nov 2, 2025
84eadae
inv mass matrices instead
sbrantq Nov 7, 2025
f741d07
CI
sbrantq Nov 9, 2025
0aa6b23
rename
sbrantq Nov 11, 2025
3141753
rename test
sbrantq Nov 11, 2025
be04b84
enzyme.randomSplit op test
sbrantq Dec 8, 2025
2ce63e7
enzyme.random stat property check
sbrantq Dec 8, 2025
0368a56
pointwise compare against `jax.random.*`
sbrantq Dec 9, 2025
9da1a4f
save
sbrantq Dec 9, 2025
af6e113
complex seed test
sbrantq Dec 10, 2025
06a912b
fix
sbrantq Dec 29, 2025
820b406
debug
sbrantq Dec 29, 2025
76a5b13
ordered set
sbrantq Jan 4, 2026
d9da436
layout fix
sbrantq Jan 4, 2026
8b35b9c
extend mcmc
sbrantq Jan 4, 2026
9f042d1
interface
sbrantq Jan 7, 2026
d0987e6
fix ordering
sbrantq Jan 13, 2026
f0cd0d3
distributions boilerplate
sbrantq Jan 13, 2026
d390e67
context -> current_context
sbrantq Jan 31, 2026
0033208
dist transform
sbrantq Jan 31, 2026
24ebaec
constraint transform handling
sbrantq Jan 31, 2026
a47bf32
fix up tests
sbrantq Feb 1, 2026
3d4259d
display stat
sbrantq Feb 1, 2026
4602094
clean up
sbrantq Feb 1, 2026
428c65e
format
sbrantq Feb 1, 2026
2771ba1
pass debug dump flag
sbrantq Feb 1, 2026
b270e5b
clean up
sbrantq Feb 1, 2026
d9d6510
deps
sbrantq Feb 1, 2026
f4025ab
qa fix
sbrantq Feb 1, 2026
13aa869
fix version
sbrantq Feb 1, 2026
cf7afa3
make mcmcdiagnostics ext
sbrantq Feb 1, 2026
e0e52af
run probprog tests
sbrantq Feb 1, 2026
bc3f44e
remove
sbrantq Feb 4, 2026
95168e5
clean up
sbrantq Feb 4, 2026
7e58c99
clean up
sbrantq Feb 4, 2026
a9ae0bb
clean up
sbrantq Feb 4, 2026
9b663a1
simulate
sbrantq Feb 4, 2026
9840dc1
generate
sbrantq Feb 4, 2026
3b0e2c5
mh
sbrantq Feb 4, 2026
398db9e
tests fix up
sbrantq Feb 5, 2026
09ef76e
hmc/nuts migrate
sbrantq Feb 5, 2026
4b148e0
debug dump fix up
sbrantq Feb 5, 2026
19db286
nuts pointwise checks against numpyro
sbrantq Feb 5, 2026
ec2d826
fmt
sbrantq Feb 5, 2026
c8718e5
hmc pointwise
sbrantq Feb 5, 2026
bc631bb
[DROPME] enzymexla hash
sbrantq Feb 6, 2026
df6a63d
fix
sbrantq Feb 6, 2026
84a8b0a
fix
sbrantq Feb 6, 2026
d66bf9f
format
sbrantq Feb 6, 2026
8e3279b
get rid of user end scopedvalue
sbrantq Feb 6, 2026
4ccf8ef
sample test fix
sbrantq Feb 6, 2026
9c44d8a
qa
sbrantq Feb 6, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ jobs:
- core
- nn
- integration
- probprog
runtime:
- "pjrt"
- "ifrt"
Expand Down Expand Up @@ -95,6 +96,7 @@ jobs:
- core
- nn
- integration
- probprog
runtime:
- "pjrt"
- "ifrt"
Expand Down
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Sockets = "6462fe0b-24de-5631-8697-dd941f90decc"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
p7zip_jll = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"

[weakdeps]
Expand All @@ -45,6 +46,7 @@ FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Float8s = "81dfefd7-55b0-40c6-a251-db853704e186"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605"
Expand All @@ -54,7 +56,6 @@ PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

Expand All @@ -70,6 +71,7 @@ ReactantFFTWExt = ["FFTW", "AbstractFFTs", "LinearAlgebra"]
ReactantFillArraysExt = "FillArrays"
ReactantFloat8sExt = "Float8s"
ReactantKernelAbstractionsExt = "KernelAbstractions"
ReactantMCMCDiagnosticToolsExt = "MCMCDiagnosticTools"
ReactantMPIExt = "MPI"
ReactantNNlibExt = ["NNlib", "Statistics"]
ReactantNPZExt = "NPZ"
Expand Down Expand Up @@ -108,6 +110,7 @@ LLVM = "9.4"
LLVMOpenMP_jll = "18.1.7"
Libdl = "1.10"
LinearAlgebra = "1.10"
MCMCDiagnosticTools = "0.3.11"
MPI = "0.20"
NNlib = "0.9.26"
NPZ = "0.4"
Expand Down
2 changes: 1 addition & 1 deletion deps/ReactantExtra/WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023"

NSYNC_SHA256 = ""

ENZYMEXLA_COMMIT = "07ba19aca13d9456e86ec3e2cc9e72719f0e982d"
ENZYMEXLA_COMMIT = "17fc8bb253c37b663ab6af27796387870230613b"

ENZYMEXLA_SHA256 = ""

Expand Down
26 changes: 26 additions & 0 deletions ext/ReactantMCMCDiagnosticToolsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
module ReactantMCMCDiagnosticToolsExt

using Reactant.ProbProg: ProbProg
using MCMCDiagnosticTools: ess, rhat

function ProbProg._compute_ess(samples::AbstractVector)
x = collect(Float64, samples)
n = length(x)
if n < 4
return Float64(n)
end
x_matrix = reshape(x, n, 1)
return ess(x_matrix)
end

function ProbProg._compute_rhat(samples::AbstractVector)
x = collect(Float64, samples)
n = length(x)
if n < 4
return NaN
end
x_matrix = reshape(x, n, 1)
return rhat(x_matrix)
end

end
1 change: 1 addition & 0 deletions src/CompileOptions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ function CompileOptions(;
:canonicalize,
:just_batch,
:none,
:probprog,
]
end

Expand Down
77 changes: 77 additions & 0 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ import ..ReactantCore: correct_maybe_bcast_call
const DEBUG_PRINT_CODEGEN = Ref(false)
const DEBUG_DISABLE_RESHARDING = Ref(false)
const DEBUG_ALIASED_BUFFER_ASSIGNMENT_ERROR = Ref(false)
const DEBUG_PROBPROG_DUMP_VALUE = Ref(false)
const DEBUG_PROBPROG_DISABLE_OPT = Ref(true)

const DEBUG_BUFFER_POINTERS_STORE_DICT = Base.IdDict()

Expand Down Expand Up @@ -1439,6 +1441,16 @@ end
# However, this errs as we cannot attach the transform with to the funcop itself [as we run a functionpass].
const enzyme_pass::String = "enzyme{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize,arith-raise{stablehlo=true}\"}"

function probprog_pass(;
debug_dump::Bool=DEBUG_PROBPROG_DUMP_VALUE[],
disable_optimizations::Bool=DEBUG_PROBPROG_DISABLE_OPT[],
)
if !disable_optimizations
# TODO(#2063): Add probprog optimization passes
end
return "probprog{debug-dump=$debug_dump postpasses=\"arith-raise{stablehlo=true}\"}"
end

function run_pass_pipeline!(mod, pass_pipeline, key=""; enable_verifier=true)
pm = MLIR.IR.PassManager()
MLIR.IR.enable_verifier!(pm, enable_verifier)
Expand Down Expand Up @@ -2063,6 +2075,71 @@ function compile_mlir!(
),
"no_enzyme",
)
elseif compile_options.optimization_passes === :probprog
run_pass_pipeline!(
mod,
join(
if compile_options.raise_first
[
"mark-func-memory-effects",
opt_passes,
kern,
raise_passes,
"enzyme-batch",
opt_passes2,
probprog_pass(),
"lower-probprog-to-stablehlo{backend=$backend}",
"outline-enzyme-regions",
enzyme_pass,
opt_passes2,
"canonicalize",
"remove-unnecessary-enzyme-ops",
"enzyme-simplify-math",
(
if compile_options.legalize_chlo_to_stablehlo
["func.func(chlo-legalize-to-stablehlo)"]
else
[]
end
)...,
opt_passes2,
lower_enzymexla_linalg_pass,
"lower-probprog-trace-ops{backend=$backend}",
jit,
]
else
[
"mark-func-memory-effects",
opt_passes,
"enzyme-batch",
opt_passes2,
probprog_pass(),
"lower-probprog-to-stablehlo{backend=$backend}",
"outline-enzyme-regions",
enzyme_pass,
opt_passes2,
"canonicalize",
"remove-unnecessary-enzyme-ops",
"enzyme-simplify-math",
(
if compile_options.legalize_chlo_to_stablehlo
["func.func(chlo-legalize-to-stablehlo)"]
else
[]
end
)...,
opt_passes2,
kern,
raise_passes,
lower_enzymexla_linalg_pass,
"lower-probprog-trace-ops{backend=$backend}",
jit,
]
end,
",",
),
"probprog",
)
elseif compile_options.optimization_passes === :only_enzyme
run_pass_pipeline!(
mod,
Expand Down
3 changes: 3 additions & 0 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,9 @@ include("Overlay.jl")
# Serialization
include("serialization/Serialization.jl")

# ProbProg
include("probprog/ProbProg.jl")

using .Compiler: @compile, @code_hlo, @code_mhlo, @jit, @code_xla, traced_getfield, compile
export ConcreteRArray,
ConcreteRNumber,
Expand Down
153 changes: 153 additions & 0 deletions src/probprog/Display.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# Reference: https://github.com/probcomp/Gen.jl/blob/91d798f2d2f0c175b1be3dc6daf3a10a8acf5da3/src/choice_map.jl#L104

function _format_array(arr::AbstractArray; n_show::Int=3, indent::Int=0)
nd = ndims(arr)
if nd == 0
return string(arr[])
elseif nd == 1
len = length(arr)
if len <= 2 * n_show
return "[" * join(arr, " ") * "]"
end
first_part = join(arr[1:n_show], " ")
last_part = join(arr[(end - n_show + 1):end], " ")
return "[$first_part ... $last_part]"
else
n_slices = size(arr, 1)
indent_str = " "^(indent + 1)

if n_slices <= 2 * n_show
slice_strs = [
_format_array(selectdim(arr, 1, i); n_show=n_show, indent=indent + 1) for
i in 1:n_slices
]
return "[" * join(slice_strs, "\n" * indent_str) * "]"
else
first_slices = [
_format_array(selectdim(arr, 1, i); n_show=n_show, indent=indent + 1) for
i in 1:n_show
]
last_slices = [
_format_array(selectdim(arr, 1, i); n_show=n_show, indent=indent + 1) for
i in (n_slices - n_show + 1):n_slices
]
return "[" *
join(first_slices, "\n" * indent_str) *
"\n" *
indent_str *
"..." *
"\n" *
indent_str *
join(last_slices, "\n" * indent_str) *
"]"
end
end
end

function _format_digest(value; n_show::Int=3)
if isa(value, Tuple)
if length(value) == 1
return _format_digest(value[1]; n_show=n_show)
else
formatted = [_format_digest(v; n_show=n_show) for v in value]
return "(" * join(formatted, ", ") * ")"
end
elseif isa(value, AbstractArray)
return _format_array(value; n_show=n_show, indent=0)
else
return string(value)
end
end

function _show_pretty(io::IO, trace::Trace, pre::Int, vert_bars::Tuple)
VERT = '\u2502'
PLUS = '\u251C'
HORZ = '\u2500'
LAST = '\u2514'

indent_vert = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n'])
indent = vcat(Char[' ' for _ in 1:pre], Char[PLUS, HORZ, HORZ, ' '])
indent_last = vcat(Char[' ' for _ in 1:pre], Char[LAST, HORZ, HORZ, ' '])

for i in vert_bars
indent_vert[i] = VERT
indent[i] = VERT
indent_last[i] = VERT
end

indent_vert_str = join(indent_vert)
indent_str = join(indent)
indent_last_str = join(indent_last)

sorted_choices = sort(collect(trace.choices); by=x -> x[1])
n = length(sorted_choices)

if trace.retval !== nothing
n += 1
end

if trace.weight !== nothing
n += 1
end

cur = 1

if trace.retval !== nothing
print(io, indent_vert_str)
retval_str = _format_digest(trace.retval)
print(io, (cur == n ? indent_last_str : indent_str) * "retval : $retval_str\n")
cur += 1
end

if trace.weight !== nothing
print(io, indent_vert_str)
print(io, (cur == n ? indent_last_str : indent_str) * "weight : $(trace.weight)\n")
cur += 1
end

for (key, value) in sorted_choices
print(io, indent_vert_str)
value_str = _format_digest(value)
if contains(value_str, '\n')
indent_continuation = " "^(length(indent_str) + length(repr(key)) + 3)
value_str = replace(value_str, "\n" => "\n" * indent_continuation)
end
print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key)) : $value_str\n")
cur += 1
end

sorted_subtraces = sort(collect(trace.subtraces); by=x -> x[1])
n += length(sorted_subtraces)

for (key, subtrace) in sorted_subtraces
print(io, indent_vert_str)
print(io, (cur == n ? indent_last_str : indent_str) * "subtrace on $(repr(key))\n")
_show_pretty(
io, subtrace, pre + 4, cur == n ? (vert_bars...,) : (vert_bars..., pre + 1)
)
cur += 1
end
end

function Base.show(io::IO, ::MIME"text/plain", trace::Trace)
println(io, "Trace:")
if isempty(trace.choices) && trace.retval === nothing && trace.weight === nothing
println(io, " (empty)")
else
_show_pretty(io, trace, 0, ())
end
end

function Base.show(io::IO, trace::Trace)
if get(io, :compact, false)
choices_count = length(trace.choices)
has_retval = trace.retval !== nothing
print(io, "Trace($(choices_count) choices")
if has_retval
print(io, ", retval=$(trace.retval), weight=$(trace.weight)")
end
print(io, ")")
else
show(io, MIME"text/plain"(), trace)
end
end
Loading
Loading