Skip to content

Commit 8833ea0

Browse files
committed
probprog squash
1 parent 757e887 commit 8833ea0

25 files changed

+3189
-1
lines changed

.github/workflows/CI.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ jobs:
5353
- core
5454
- nn
5555
- integration
56+
- probprog
5657
runtime:
5758
- "pjrt"
5859
- "ifrt"
@@ -95,6 +96,7 @@ jobs:
9596
- core
9697
- nn
9798
- integration
99+
- probprog
98100
runtime:
99101
- "pjrt"
100102
- "ifrt"

Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
3434
Sockets = "6462fe0b-24de-5631-8697-dd941f90decc"
3535
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
3636
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
37+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
3738
p7zip_jll = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"
3839

3940
[weakdeps]
@@ -45,6 +46,7 @@ FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
4546
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
4647
Float8s = "81dfefd7-55b0-40c6-a251-db853704e186"
4748
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
49+
MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0"
4850
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
4951
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
5052
NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605"
@@ -54,7 +56,6 @@ PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
5456
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
5557
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
5658
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
57-
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
5859
YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"
5960
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
6061

@@ -70,6 +71,7 @@ ReactantFFTWExt = ["FFTW", "AbstractFFTs", "LinearAlgebra"]
7071
ReactantFillArraysExt = "FillArrays"
7172
ReactantFloat8sExt = "Float8s"
7273
ReactantKernelAbstractionsExt = "KernelAbstractions"
74+
ReactantMCMCDiagnosticToolsExt = "MCMCDiagnosticTools"
7375
ReactantMPIExt = "MPI"
7476
ReactantNNlibExt = ["NNlib", "Statistics"]
7577
ReactantNPZExt = "NPZ"
@@ -108,6 +110,7 @@ LLVM = "9.4"
108110
LLVMOpenMP_jll = "18.1.7"
109111
Libdl = "1.10"
110112
LinearAlgebra = "1.10"
113+
MCMCDiagnosticTools = "0.3.11"
111114
MPI = "0.20"
112115
NNlib = "0.9.26"
113116
NPZ = "0.4"
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
module ReactantMCMCDiagnosticToolsExt
2+
3+
using Reactant.ProbProg: ProbProg
4+
using MCMCDiagnosticTools: ess, rhat
5+
6+
function ProbProg._compute_ess(samples::AbstractVector)
7+
x = collect(Float64, samples)
8+
n = length(x)
9+
if n < 4
10+
return Float64(n)
11+
end
12+
x_matrix = reshape(x, n, 1)
13+
return ess(x_matrix)
14+
end
15+
16+
function ProbProg._compute_rhat(samples::AbstractVector)
17+
x = collect(Float64, samples)
18+
n = length(x)
19+
if n < 4
20+
return NaN
21+
end
22+
x_matrix = reshape(x, n, 1)
23+
return rhat(x_matrix)
24+
end
25+
26+
end

src/CompileOptions.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,7 @@ function CompileOptions(;
275275
:canonicalize,
276276
:just_batch,
277277
:none,
278+
:probprog,
278279
]
279280
end
280281

src/Compiler.jl

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ import ..ReactantCore: correct_maybe_bcast_call
3131
const DEBUG_PRINT_CODEGEN = Ref(false)
3232
const DEBUG_DISABLE_RESHARDING = Ref(false)
3333
const DEBUG_ALIASED_BUFFER_ASSIGNMENT_ERROR = Ref(false)
34+
const DEBUG_PROBPROG_DUMP_VALUE = Ref(false)
35+
const DEBUG_PROBPROG_DISABLE_OPT = Ref(true)
3436

3537
const DEBUG_BUFFER_POINTERS_STORE_DICT = Base.IdDict()
3638

@@ -1444,6 +1446,16 @@ end
14441446
# However, this errs as we cannot attach the transform with to the funcop itself [as we run a functionpass].
14451447
const enzyme_pass::String = "enzyme{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize,arith-raise{stablehlo=true}\"}"
14461448

1449+
function probprog_pass(;
1450+
debug_dump::Bool=DEBUG_PROBPROG_DUMP_VALUE[],
1451+
disable_optimizations::Bool=DEBUG_PROBPROG_DISABLE_OPT[],
1452+
)
1453+
if !disable_optimizations
1454+
# TODO(#2063): Add probprog optimization passes
1455+
end
1456+
return "probprog{debug-dump=$debug_dump postpasses=\"arith-raise{stablehlo=true}\"}"
1457+
end
1458+
14471459
function run_pass_pipeline!(mod, pass_pipeline, key=""; enable_verifier=true)
14481460
pm = MLIR.IR.PassManager()
14491461
MLIR.IR.enable_verifier!(pm, enable_verifier)
@@ -2068,6 +2080,71 @@ function compile_mlir!(
20682080
),
20692081
"no_enzyme",
20702082
)
2083+
elseif compile_options.optimization_passes === :probprog
2084+
run_pass_pipeline!(
2085+
mod,
2086+
join(
2087+
if compile_options.raise_first
2088+
[
2089+
"mark-func-memory-effects",
2090+
opt_passes,
2091+
kern,
2092+
raise_passes,
2093+
"enzyme-batch",
2094+
opt_passes2,
2095+
probprog_pass(),
2096+
"lower-probprog-to-stablehlo{backend=$backend}",
2097+
"outline-enzyme-regions",
2098+
enzyme_pass,
2099+
opt_passes2,
2100+
"canonicalize",
2101+
"remove-unnecessary-enzyme-ops",
2102+
"enzyme-simplify-math",
2103+
(
2104+
if compile_options.legalize_chlo_to_stablehlo
2105+
["func.func(chlo-legalize-to-stablehlo)"]
2106+
else
2107+
[]
2108+
end
2109+
)...,
2110+
opt_passes2,
2111+
lower_enzymexla_linalg_pass,
2112+
"lower-probprog-trace-ops{backend=$backend}",
2113+
jit,
2114+
]
2115+
else
2116+
[
2117+
"mark-func-memory-effects",
2118+
opt_passes,
2119+
"enzyme-batch",
2120+
opt_passes2,
2121+
probprog_pass(),
2122+
"lower-probprog-to-stablehlo{backend=$backend}",
2123+
"outline-enzyme-regions",
2124+
enzyme_pass,
2125+
opt_passes2,
2126+
"canonicalize",
2127+
"remove-unnecessary-enzyme-ops",
2128+
"enzyme-simplify-math",
2129+
(
2130+
if compile_options.legalize_chlo_to_stablehlo
2131+
["func.func(chlo-legalize-to-stablehlo)"]
2132+
else
2133+
[]
2134+
end
2135+
)...,
2136+
opt_passes2,
2137+
kern,
2138+
raise_passes,
2139+
lower_enzymexla_linalg_pass,
2140+
"lower-probprog-trace-ops{backend=$backend}",
2141+
jit,
2142+
]
2143+
end,
2144+
",",
2145+
),
2146+
"probprog",
2147+
)
20712148
elseif compile_options.optimization_passes === :only_enzyme
20722149
run_pass_pipeline!(
20732150
mod,

src/Reactant.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,9 @@ include("Overlay.jl")
268268
# Serialization
269269
include("serialization/Serialization.jl")
270270

271+
# ProbProg
272+
include("probprog/ProbProg.jl")
273+
271274
using .Compiler: @compile, @code_hlo, @code_mhlo, @jit, @code_xla, traced_getfield, compile
272275
export ConcreteRArray,
273276
ConcreteRNumber,

src/probprog/Display.jl

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
# Reference: https://github.com/probcomp/Gen.jl/blob/91d798f2d2f0c175b1be3dc6daf3a10a8acf5da3/src/choice_map.jl#L104
2+
3+
function _format_array(arr::AbstractArray; n_show::Int=3, indent::Int=0)
4+
nd = ndims(arr)
5+
if nd == 0
6+
return string(arr[])
7+
elseif nd == 1
8+
len = length(arr)
9+
if len <= 2 * n_show
10+
return "[" * join(arr, " ") * "]"
11+
end
12+
first_part = join(arr[1:n_show], " ")
13+
last_part = join(arr[(end - n_show + 1):end], " ")
14+
return "[$first_part ... $last_part]"
15+
else
16+
n_slices = size(arr, 1)
17+
indent_str = " "^(indent + 1)
18+
19+
if n_slices <= 2 * n_show
20+
slice_strs = [
21+
_format_array(selectdim(arr, 1, i); n_show=n_show, indent=indent + 1) for
22+
i in 1:n_slices
23+
]
24+
return "[" * join(slice_strs, "\n" * indent_str) * "]"
25+
else
26+
first_slices = [
27+
_format_array(selectdim(arr, 1, i); n_show=n_show, indent=indent + 1) for
28+
i in 1:n_show
29+
]
30+
last_slices = [
31+
_format_array(selectdim(arr, 1, i); n_show=n_show, indent=indent + 1) for
32+
i in (n_slices - n_show + 1):n_slices
33+
]
34+
return "[" *
35+
join(first_slices, "\n" * indent_str) *
36+
"\n" *
37+
indent_str *
38+
"..." *
39+
"\n" *
40+
indent_str *
41+
join(last_slices, "\n" * indent_str) *
42+
"]"
43+
end
44+
end
45+
end
46+
47+
function _format_digest(value; n_show::Int=3)
48+
if isa(value, Tuple)
49+
if length(value) == 1
50+
return _format_digest(value[1]; n_show=n_show)
51+
else
52+
formatted = [_format_digest(v; n_show=n_show) for v in value]
53+
return "(" * join(formatted, ", ") * ")"
54+
end
55+
elseif isa(value, AbstractArray)
56+
return _format_array(value; n_show=n_show, indent=0)
57+
else
58+
return string(value)
59+
end
60+
end
61+
62+
function _show_pretty(io::IO, trace::Trace, pre::Int, vert_bars::Tuple)
63+
VERT = '\u2502'
64+
PLUS = '\u251C'
65+
HORZ = '\u2500'
66+
LAST = '\u2514'
67+
68+
indent_vert = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n'])
69+
indent = vcat(Char[' ' for _ in 1:pre], Char[PLUS, HORZ, HORZ, ' '])
70+
indent_last = vcat(Char[' ' for _ in 1:pre], Char[LAST, HORZ, HORZ, ' '])
71+
72+
for i in vert_bars
73+
indent_vert[i] = VERT
74+
indent[i] = VERT
75+
indent_last[i] = VERT
76+
end
77+
78+
indent_vert_str = join(indent_vert)
79+
indent_str = join(indent)
80+
indent_last_str = join(indent_last)
81+
82+
sorted_choices = sort(collect(trace.choices); by=x -> x[1])
83+
n = length(sorted_choices)
84+
85+
if trace.retval !== nothing
86+
n += 1
87+
end
88+
89+
if trace.weight !== nothing
90+
n += 1
91+
end
92+
93+
cur = 1
94+
95+
if trace.retval !== nothing
96+
print(io, indent_vert_str)
97+
retval_str = _format_digest(trace.retval)
98+
print(io, (cur == n ? indent_last_str : indent_str) * "retval : $retval_str\n")
99+
cur += 1
100+
end
101+
102+
if trace.weight !== nothing
103+
print(io, indent_vert_str)
104+
print(io, (cur == n ? indent_last_str : indent_str) * "weight : $(trace.weight)\n")
105+
cur += 1
106+
end
107+
108+
for (key, value) in sorted_choices
109+
print(io, indent_vert_str)
110+
value_str = _format_digest(value)
111+
if contains(value_str, '\n')
112+
indent_continuation = " "^(length(indent_str) + length(repr(key)) + 3)
113+
value_str = replace(value_str, "\n" => "\n" * indent_continuation)
114+
end
115+
print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key)) : $value_str\n")
116+
cur += 1
117+
end
118+
119+
sorted_subtraces = sort(collect(trace.subtraces); by=x -> x[1])
120+
n += length(sorted_subtraces)
121+
122+
for (key, subtrace) in sorted_subtraces
123+
print(io, indent_vert_str)
124+
print(io, (cur == n ? indent_last_str : indent_str) * "subtrace on $(repr(key))\n")
125+
_show_pretty(
126+
io, subtrace, pre + 4, cur == n ? (vert_bars...,) : (vert_bars..., pre + 1)
127+
)
128+
cur += 1
129+
end
130+
end
131+
132+
function Base.show(io::IO, ::MIME"text/plain", trace::Trace)
133+
println(io, "Trace:")
134+
if isempty(trace.choices) && trace.retval === nothing && trace.weight === nothing
135+
println(io, " (empty)")
136+
else
137+
_show_pretty(io, trace, 0, ())
138+
end
139+
end
140+
141+
function Base.show(io::IO, trace::Trace)
142+
if get(io, :compact, false)
143+
choices_count = length(trace.choices)
144+
has_retval = trace.retval !== nothing
145+
print(io, "Trace($(choices_count) choices")
146+
if has_retval
147+
print(io, ", retval=$(trace.retval), weight=$(trace.weight)")
148+
end
149+
print(io, ")")
150+
else
151+
show(io, MIME"text/plain"(), trace)
152+
end
153+
end

0 commit comments

Comments
 (0)