@@ -13,6 +13,7 @@ using ..Reactant:
13
13
using .. Compiler: @jit , @compile
14
14
using Enzyme
15
15
using Base: ReentrantLock
16
+ using Random
16
17
17
18
mutable struct ProbProgTrace
18
19
fn:: Union{Nothing,Function}
@@ -21,13 +22,18 @@ mutable struct ProbProgTrace
21
22
retval:: Any
22
23
weight:: Any
23
24
subtraces:: Dict{Symbol,Any}
25
+ rng:: Union{Nothing,AbstractRNG}
24
26
25
27
function ProbProgTrace (fn:: Function , args:: Tuple )
26
- return new (fn, args, Dict {Symbol,Any} (), nothing , nothing , Dict {Symbol,Any} ())
28
+ return new (
29
+ fn, args, Dict {Symbol,Any} (), nothing , nothing , Dict {Symbol,Any} (), nothing
30
+ )
27
31
end
28
32
29
33
function ProbProgTrace ()
30
- return new (nothing , (), Dict {Symbol,Any} (), nothing , nothing , Dict {Symbol,Any} ())
34
+ return new (
35
+ nothing , (), Dict {Symbol,Any} (), nothing , nothing , Dict {Symbol,Any} (), nothing
36
+ )
31
37
end
32
38
end
33
39
@@ -587,6 +593,10 @@ function simulate(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where
587
593
588
594
trace = unsafe_pointer_to_objref (Ptr {Any} (Array (trace)[1 ]))
589
595
596
+ trace. fn = f
597
+ trace. args = args
598
+ trace. rng = rng
599
+
590
600
return trace, trace. weight
591
601
end
592
602
@@ -702,7 +712,7 @@ function generate(
702
712
trace = nothing
703
713
704
714
constraint_ptr = ConcreteRNumber (reinterpret (UInt64, pointer_from_objref (constraint)))
705
- constrained_symbols = collect (keys (constraint))
715
+ constrained_symbols = Set (keys (constraint))
706
716
707
717
function wrapper_fn (rng, constraint_ptr, args... )
708
718
return generate_internal (rng, f, args... ; constraint_ptr, constrained_symbols)
@@ -719,6 +729,10 @@ function generate(
719
729
720
730
trace = unsafe_pointer_to_objref (Ptr {Any} (Array (trace)[1 ]))
721
731
732
+ trace. fn = f
733
+ trace. args = args
734
+ trace. rng = rng
735
+
722
736
return trace, trace. weight
723
737
end
724
738
@@ -727,7 +741,7 @@ function generate_internal(
727
741
f:: Function ,
728
742
args:: Vararg{Any,Nargs} ;
729
743
constraint_ptr:: TracedRNumber ,
730
- constrained_symbols:: Vector {Symbol} ,
744
+ constrained_symbols:: Set {Symbol} ,
731
745
) where {Nargs}
732
746
argprefix:: Symbol = gensym (" generatearg" )
733
747
resprefix:: Symbol = gensym (" generateresult" )
947
961
948
962
get_choices (trace:: ProbProgTrace ) = trace. choices
949
963
964
+ const Selection = Set{Symbol}
965
+ select (syms:: Symbol... ) = Set (syms)
966
+ choicemap () = Constraint ()
967
+ const CompiledFnCache = Dict{Tuple{Type,Set{Symbol}},Any}
968
+
969
+ function metropolis_hastings (
970
+ trace:: ProbProgTrace ,
971
+ sel:: Selection ;
972
+ compiled_cache:: Union{Nothing,CompiledFnCache} = nothing ,
973
+ )
974
+ if trace. fn === nothing || trace. rng === nothing
975
+ error (" MH requires a trace with fn and rng recorded (use generate to create trace)" )
976
+ end
977
+
978
+ constraints = Dict {Symbol,Any} ()
979
+ constrained_symbols = Set {Symbol} ()
980
+
981
+ for (sym, val) in trace. choices
982
+ if ! (sym in sel)
983
+ constraints[sym] = val
984
+ push! (constrained_symbols, sym)
985
+ end
986
+ end
987
+
988
+ cache_key = (typeof (trace. fn), constrained_symbols)
989
+
990
+ compiled_fn = nothing
991
+ if compiled_cache != = nothing
992
+ compiled_fn = get (compiled_cache, cache_key, nothing )
993
+ end
994
+
995
+ if compiled_fn === nothing
996
+ function wrapper_fn (rng, constraint_ptr, args... )
997
+ return generate_internal (
998
+ rng, trace. fn, args... ; constraint_ptr, constrained_symbols
999
+ )
1000
+ end
1001
+
1002
+ constraint_ptr = ConcreteRNumber (
1003
+ reinterpret (UInt64, pointer_from_objref (constraints))
1004
+ )
1005
+
1006
+ compiled_fn = @compile optimize = :probprog wrapper_fn (
1007
+ trace. rng, constraint_ptr, trace. args...
1008
+ )
1009
+
1010
+ if compiled_cache != = nothing
1011
+ compiled_cache[cache_key] = compiled_fn
1012
+ end
1013
+ end
1014
+
1015
+ constraint_ptr = ConcreteRNumber (reinterpret (UInt64, pointer_from_objref (constraints)))
1016
+
1017
+ old_gc_state = GC. enable (false )
1018
+ new_trace_ptr = nothing
1019
+ try
1020
+ new_trace_ptr, _, _ = compiled_fn (trace. rng, constraint_ptr, trace. args... )
1021
+ finally
1022
+ GC. enable (old_gc_state)
1023
+ end
1024
+
1025
+ new_trace = unsafe_pointer_to_objref (Ptr {Any} (Array (new_trace_ptr)[1 ]))
1026
+
1027
+ new_trace. fn = trace. fn
1028
+ new_trace. args = trace. args
1029
+ new_trace. rng = trace. rng
1030
+
1031
+ log_alpha = new_trace. weight - trace. weight
1032
+
1033
+ if log (rand ()) < log_alpha
1034
+ return (new_trace, true )
1035
+ else
1036
+ return (trace, false )
1037
+ end
1038
+ end
1039
+
950
1040
end
0 commit comments