Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
145 commits
Select commit Hold shift + click to select a range
902ced9
generate
sbrantq May 2, 2025
e2c77e4
refactor
sbrantq May 2, 2025
e204d13
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into p…
sbrantq May 2, 2025
327b10a
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into p…
sbrantq May 6, 2025
d611ae4
add probprog pass to :all
sbrantq May 7, 2025
3672d83
improve test
sbrantq May 7, 2025
b70843e
only probprog opt mode
sbrantq May 8, 2025
597fa89
fix up test
sbrantq May 8, 2025
e6c2c0a
move
sbrantq May 12, 2025
9b9395e
simplify
sbrantq May 12, 2025
b3ba477
fix up
sbrantq May 14, 2025
47e9fe3
saving changes
sbrantq May 15, 2025
982b2bf
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into p…
sbrantq May 16, 2025
06b7464
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into p…
sbrantq May 17, 2025
bd73c62
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into p…
sbrantq May 19, 2025
a6fcca3
fix sample op
sbrantq May 20, 2025
e51e04b
save tests
sbrantq May 20, 2025
ce68f6a
temporarily removing probprog pass from :all as MLIR pass is not merg…
sbrantq May 20, 2025
94bbe62
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into p…
sbrantq May 20, 2025
d31bba6
undo enzyme binding change
sbrantq May 20, 2025
573fa02
format
sbrantq May 20, 2025
0264a3d
format
sbrantq May 20, 2025
2e18bdf
improve
sbrantq May 20, 2025
1f19979
improve
sbrantq May 20, 2025
096d790
get rid of result_and_mutated too
sbrantq May 22, 2025
bb319a3
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into p…
sbrantq May 31, 2025
9ac6535
working trace object pointer hacks + tests
sbrantq Jun 5, 2025
b24766f
Assuming scalar samples for now; simple Bayesian linear regression test
sbrantq Jun 5, 2025
3c52b39
exclamation mark
sbrantq Jun 5, 2025
af3d055
sample metadata
sbrantq Jun 6, 2025
6c7ffa3
fix up copy
sbrantq Jun 6, 2025
4e017d0
fix up copy
sbrantq Jun 6, 2025
e53fc7c
working vectorized blr test
sbrantq Jun 6, 2025
1dbf5c7
fix test warning
sbrantq Jun 11, 2025
dd9dcab
hacks to temporarily remove world age issue in tests
sbrantq Jun 11, 2025
a344726
partial refactoring
sbrantq Jun 12, 2025
ebeceb8
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into p…
sbrantq Jun 13, 2025
ef2e770
fixed tracing infra
sbrantq Jun 14, 2025
46e0f6b
transpose fix up
sbrantq Jun 16, 2025
1c5297c
minor changes
sbrantq Jun 17, 2025
d707053
reorder
sbrantq Jun 17, 2025
91a0850
API change
sbrantq Jun 20, 2025
561b051
better print
sbrantq Jun 20, 2025
99d7608
unconstrained real generate op
sbrantq Jun 25, 2025
b13f8bf
probprog postpasses
sbrantq Jun 25, 2025
6e4dc0c
bug fix for alising outputs
sbrantq Jun 26, 2025
5b5c1d1
generate op with constraints
sbrantq Jun 26, 2025
1ad167a
untraced call
sbrantq Jun 26, 2025
8f66b5f
working metropolis hastings (with hacks)
sbrantq Jun 26, 2025
850e3c4
set julia rng
sbrantq Jun 27, 2025
e1b3bcb
remove print
sbrantq Jun 27, 2025
659b963
less iterations. hiding prints
sbrantq Jun 27, 2025
537de49
add probprog test group
sbrantq Jun 27, 2025
04d2e44
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into p…
sbrantq Jun 27, 2025
8260fee
format
sbrantq Jun 27, 2025
0f94166
add probprog compile opt
sbrantq Jun 27, 2025
7f611fe
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into p…
sbrantq Jul 3, 2025
a05d2c2
pass all args even when w/o rng
sbrantq Jul 3, 2025
f40960f
updated probprog frontend for refactored simulate op
sbrantq Jul 3, 2025
f6ee849
probprog attr mlir api
sbrantq Jul 3, 2025
38e33de
adding cfunction mapping for AddWeightToTrace and AddRetvalToTrace ops
sbrantq Jul 4, 2025
127126d
adding traced_output_indices attr to simulate op
sbrantq Jul 4, 2025
3d66c7a
update tests
sbrantq Jul 4, 2025
1585483
refactored generate op
sbrantq Jul 8, 2025
34f35c4
@compile for generate op
sbrantq Jul 8, 2025
f4a6415
improve api
sbrantq Jul 8, 2025
b92a733
compiled generate test
sbrantq Jul 8, 2025
160561e
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into p…
sbrantq Jul 12, 2025
bbfa3f6
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into p…
sbrantq Jul 14, 2025
f4c4a88
save gc change
sbrantq Jul 15, 2025
d1be27c
enforcing calling convention (rng being the 0th operand) for sample &…
sbrantq Jul 16, 2025
b666813
enforcing calling convention (rng being 0th operand) for simulate/gen…
sbrantq Jul 17, 2025
c57a1e4
clean up
sbrantq Jul 18, 2025
2b81db9
refactored mh inference steps with new calling convention enforced
sbrantq Jul 18, 2025
e647b0d
improve
sbrantq Jul 20, 2025
4fe55a6
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into p…
sbrantq Jul 20, 2025
94b9e3a
reorganize
sbrantq Jul 20, 2025
b2d583a
format
sbrantq Jul 20, 2025
65d3595
fix up tests
sbrantq Jul 21, 2025
a29fbed
remove redundant cast
sbrantq Jul 21, 2025
87ced72
generate op fixup: replacing constrained_symbols with constrained_add…
sbrantq Jul 29, 2025
ebec467
minor
sbrantq Jul 29, 2025
f771bcb
update legacy inference API
sbrantq Jul 29, 2025
1908188
simplify
sbrantq Jul 29, 2025
0b71444
cleanup
sbrantq Jul 29, 2025
1a23c2e
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into p…
sbrantq Jul 29, 2025
9bd1dee
fix deadlock
sbrantq Jul 31, 2025
c9ff7c0
fix test
sbrantq Jul 31, 2025
3196989
don't print
sbrantq Jul 31, 2025
4afda71
clean up postpasses
sbrantq Jul 31, 2025
d9ca225
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into p…
sbrantq Aug 18, 2025
20b9b9c
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into p…
sbrantq Aug 29, 2025
a2951d5
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into p…
sbrantq Sep 12, 2025
f6172df
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into p…
sbrantq Sep 16, 2025
d635a37
Merge branch 'main' into probprog-trace-operand
sbrantq Sep 20, 2025
f7aab5b
format
sbrantq Sep 20, 2025
75a187a
Merge branch 'main' into probprog-trace-operand
avik-pal Sep 26, 2025
0ed081f
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into p…
sbrantq Sep 30, 2025
834de05
remove probprog_no_lowering
sbrantq Sep 30, 2025
d2dd57c
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into p…
sbrantq Sep 30, 2025
7ad561d
undo jll change
sbrantq Sep 30, 2025
a31f439
undo jll change
sbrantq Sep 30, 2025
9656914
clean
sbrantq Sep 30, 2025
1532269
clean up and improve
sbrantq Sep 30, 2025
eda74e3
format
sbrantq Sep 30, 2025
db314a7
remove invokelatest
sbrantq Sep 30, 2025
df764c1
clean up
sbrantq Sep 30, 2025
c0cc686
ci
sbrantq Sep 30, 2025
1d3a7d8
format
sbrantq Sep 30, 2025
06b44c3
utils
sbrantq Oct 1, 2025
e50931f
fmt
sbrantq Oct 1, 2025
6a491f3
minor fix
sbrantq Oct 2, 2025
9fd8cf3
plausible trace object threading
sbrantq Oct 2, 2025
0041f6f
ffi for new ops
sbrantq Oct 7, 2025
46222d9
test
sbrantq Oct 8, 2025
03580e1
format
sbrantq Oct 8, 2025
a5bfeda
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into p…
sbrantq Oct 8, 2025
beb26e9
format
sbrantq Oct 8, 2025
7ac92a7
format
sbrantq Oct 8, 2025
5470d7e
format
sbrantq Oct 8, 2025
c9cf848
Merge branch 'probprog-mh' of https://github.com/EnzymeAD/Reactant.jl…
sbrantq Oct 8, 2025
30020d1
fix
sbrantq Oct 12, 2025
1e7a23e
minor
sbrantq Oct 14, 2025
df6d3ca
FFIs for getFlattenedSamplesFromTrace and dump
sbrantq Oct 14, 2025
49bf991
jll change
sbrantq Oct 14, 2025
fc474b8
split proboprog pipeline
sbrantq Oct 14, 2025
96919da
hmc frontend
sbrantq Oct 14, 2025
f584ca5
clean up
sbrantq Oct 15, 2025
5241512
don't try to setfield immutable ReactantRNG
sbrantq Oct 15, 2025
224ca16
initial momentum arg
sbrantq Oct 17, 2025
85d13df
fix logpdf
sbrantq Oct 17, 2025
44ce684
temporary tracer fix
sbrantq Oct 17, 2025
de404f0
static hmc correctness test
sbrantq Oct 17, 2025
13bfce0
clean up
sbrantq Oct 19, 2025
24305aa
simplify interface
sbrantq Oct 20, 2025
9293bba
clean up
sbrantq Oct 20, 2025
46d839b
clean up
sbrantq Oct 20, 2025
796cd41
simplify
sbrantq Oct 20, 2025
25386c5
time
sbrantq Oct 20, 2025
f0b73ea
put generate and hmc into a single function
sbrantq Oct 20, 2025
8d6ae3e
simplify
sbrantq Oct 20, 2025
81a6f43
improve
sbrantq Oct 21, 2025
61fdff8
simplify; undo custom tracing which doesn't work
sbrantq Oct 21, 2025
e35351e
simplify and profiler call
sbrantq Oct 21, 2025
5e863fe
more iters
sbrantq Oct 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions deps/ReactantExtra/API.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,18 @@ enzymeSymbolAttrGet(MlirContext ctx, uint64_t symbol) {
return wrap(attr);
}

extern "C" MLIR_CAPI_EXPORTED MlirAttribute
enzymeRngDistributionAttrGet(MlirContext ctx, int32_t val) {
return wrap(mlir::enzyme::RngDistributionAttr::get(
unwrap(ctx), (mlir::enzyme::RngDistribution)val));
}

extern "C" MLIR_CAPI_EXPORTED MlirAttribute
enzymeMCMCAlgorithmAttrGet(MlirContext ctx, int32_t val) {
return wrap(mlir::enzyme::MCMCAlgorithmAttr::get(
unwrap(ctx), (mlir::enzyme::MCMCAlgorithm)val));
}

// Create profiler session and start profiling
REACTANT_ABI tsl::ProfilerSession *
CreateProfilerSession(uint32_t device_tracer_level,
Expand Down
2 changes: 1 addition & 1 deletion deps/ReactantExtra/tblgen/jl-generators.cc
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ end
operandname = "operand_" + std::to_string(i);
}
if (named_operand.isOptional()) {
operandsegmentsizes += "(" + operandname + "==nothing) ? 0 : 1";
operandsegmentsizes += "(" + operandname + "==nothing) ? 0 : 1, ";
continue;
}
operandsegmentsizes += named_operand.isVariadic()
Expand Down
1 change: 1 addition & 0 deletions src/CompileOptions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ function CompileOptions(;
:canonicalize,
:just_batch,
:none,
:probprog,
]
end

Expand Down
66 changes: 66 additions & 0 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1302,6 +1302,7 @@ end
# TODO we want to be able to run the more advanced passes via transform dialect as an enzyme intermediate
# However, this errs as we cannot attach the transform with to the funcop itself [as we run a functionpass].
const enzyme_pass::String = "enzyme{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize\"}"
const probprog_pass::String = "probprog{postpasses=\"arith-raise{stablehlo=true}\"}"

function run_pass_pipeline!(mod, pass_pipeline, key=""; enable_verifier=true)
pm = MLIR.IR.PassManager()
Expand Down Expand Up @@ -1885,6 +1886,71 @@ function compile_mlir!(
),
"no_enzyme",
)
elseif compile_options.optimization_passes === :probprog
run_pass_pipeline!(
mod,
join(
if compile_options.raise_first
[
"mark-func-memory-effects",
opt_passes,
kern,
raise_passes,
"enzyme-batch",
opt_passes2,
probprog_pass,
"lower-probprog-to-stablehlo{backend=$backend}",
"outline-enzyme-regions",
enzyme_pass,
opt_passes2,
"canonicalize",
"remove-unnecessary-enzyme-ops",
"enzyme-simplify-math",
(
if compile_options.legalize_chlo_to_stablehlo
["func.func(chlo-legalize-to-stablehlo)"]
else
[]
end
)...,
opt_passes2,
lower_enzymexla_linalg_pass,
"lower-probprog-trace-ops{backend=$backend}",
jit,
]
else
[
"mark-func-memory-effects",
opt_passes,
"enzyme-batch",
opt_passes2,
probprog_pass,
"lower-probprog-to-stablehlo{backend=$backend}",
"outline-enzyme-regions",
enzyme_pass,
opt_passes2,
"canonicalize",
"remove-unnecessary-enzyme-ops",
"enzyme-simplify-math",
(
if compile_options.legalize_chlo_to_stablehlo
["func.func(chlo-legalize-to-stablehlo)"]
else
[]
end
)...,
opt_passes2,
kern,
raise_passes,
lower_enzymexla_linalg_pass,
"lower-probprog-trace-ops{backend=$backend}",
jit,
]
end,
",",
),
"probprog",
)
elseif compile_options.optimization_passes === :only_enzyme
run_pass_pipeline!(
mod,
Expand Down
1 change: 1 addition & 0 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ include("Tracing.jl")
include("Compiler.jl")

include("Overlay.jl")
include("probprog/ProbProg.jl")

# Serialization
include("serialization/Serialization.jl")
Expand Down
2 changes: 2 additions & 0 deletions src/Types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ function ConcretePJRTArray(
end

Base.wait(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = foreach(wait, x.data)
Base.isready(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = all(isready, x.data)
XLA.client(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = XLA.client(x.data)
function XLA.device(x::Union{ConcretePJRTArray,ConcretePJRTNumber})
x.sharding isa Sharding.NoShardInfo && return XLA.device(only(x.data))
Expand Down Expand Up @@ -405,6 +406,7 @@ function ConcreteIFRTArray(
end

Base.wait(x::Union{ConcreteIFRTArray,ConcreteIFRTNumber}) = wait(x.data)
Base.isready(x::Union{ConcreteIFRTArray,ConcreteIFRTNumber}) = isready(x.data)
XLA.client(x::Union{ConcreteIFRTArray,ConcreteIFRTNumber}) = XLA.client(x.data)
function XLA.device(x::Union{ConcreteIFRTArray,ConcreteIFRTNumber})
return XLA.device(x.data)
Expand Down
87 changes: 87 additions & 0 deletions src/probprog/Display.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Reference: https://github.com/probcomp/Gen.jl/blob/91d798f2d2f0c175b1be3dc6daf3a10a8acf5da3/src/choice_map.jl#L104
function _show_pretty(io::IO, trace::ProbProgTrace, pre::Int, vert_bars::Tuple)
VERT = '\u2502'
PLUS = '\u251C'
HORZ = '\u2500'
LAST = '\u2514'

indent_vert = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n'])
indent = vcat(Char[' ' for _ in 1:pre], Char[PLUS, HORZ, HORZ, ' '])
indent_last = vcat(Char[' ' for _ in 1:pre], Char[LAST, HORZ, HORZ, ' '])

for i in vert_bars
indent_vert[i] = VERT
indent[i] = VERT
indent_last[i] = VERT
end

indent_vert_str = join(indent_vert)
indent_str = join(indent)
indent_last_str = join(indent_last)

sorted_choices = sort(collect(trace.choices); by=x -> x[1])
n = length(sorted_choices)

if trace.retval !== nothing
n += 1
end

if trace.weight !== nothing
n += 1
end

cur = 1

if trace.retval !== nothing
print(io, indent_vert_str)
print(io, (cur == n ? indent_last_str : indent_str) * "retval : $(trace.retval)\n")
cur += 1
end

if trace.weight !== nothing
print(io, indent_vert_str)
print(io, (cur == n ? indent_last_str : indent_str) * "weight : $(trace.weight)\n")
cur += 1
end

for (key, value) in sorted_choices
print(io, indent_vert_str)
print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key)) : $value\n")
cur += 1
end

sorted_subtraces = sort(collect(trace.subtraces); by=x -> x[1])
n += length(sorted_subtraces)

for (key, subtrace) in sorted_subtraces
print(io, indent_vert_str)
print(io, (cur == n ? indent_last_str : indent_str) * "subtrace on $(repr(key))\n")
_show_pretty(
io, subtrace, pre + 4, cur == n ? (vert_bars...,) : (vert_bars..., pre + 1)
)
cur += 1
end
end

function Base.show(io::IO, ::MIME"text/plain", trace::ProbProgTrace)
println(io, "ProbProgTrace:")
if isempty(trace.choices) && trace.retval === nothing && trace.weight === nothing
println(io, " (empty)")
else
_show_pretty(io, trace, 0, ())
end
end

function Base.show(io::IO, trace::ProbProgTrace)
if get(io, :compact, false)
choices_count = length(trace.choices)
has_retval = trace.retval !== nothing
print(io, "ProbProgTrace($(choices_count) choices")
if has_retval
print(io, ", retval=$(trace.retval), weight=$(trace.weight)")
end
print(io, ")")
else
show(io, MIME"text/plain"(), trace)
end
end
Loading
Loading