@@ -10,7 +10,7 @@ using ..Reactant:
10
10
TracedRNumber,
11
11
ConcreteRNumber,
12
12
Ops
13
- using .. Compiler: @jit
13
+ using .. Compiler: @jit , @compile
14
14
using Enzyme
15
15
using Base: ReentrantLock
16
16
@@ -705,10 +705,15 @@ function generate(
705
705
weight = nothing
706
706
res = nothing
707
707
708
+ constraint_ptr = ConcreteRNumber (reinterpret (UInt64, pointer_from_objref (constraint)))
709
+
710
+ function wrapper_fn (constraint_ptr, args... )
711
+ return generate_internal (f, args... ; constraint_ptr, constraint)
712
+ end
713
+
708
714
try
709
- trace, weight, res = @jit optimize = :probprog generate_internal (
710
- f, args... ; constraint
711
- )
715
+ compiled_fn = @compile optimize = :probprog wrapper_fn (constraint_ptr, args... )
716
+ trace, weight, res = compiled_fn (constraint_ptr, args... )
712
717
finally
713
718
GC. enable (old_gc_state)
714
719
end
@@ -719,7 +724,10 @@ function generate(
719
724
end
720
725
721
726
function generate_internal (
722
- f:: Function , args:: Vararg{Any,Nargs} ; constraint:: Constraint = Dict {Symbol,Any} ()
727
+ f:: Function ,
728
+ args:: Vararg{Any,Nargs} ;
729
+ constraint_ptr:: TracedRNumber ,
730
+ constraint:: Constraint = Dict {Symbol,Any} (),
723
731
) where {Nargs}
724
732
argprefix:: Symbol = gensym (" generatearg" )
725
733
resprefix:: Symbol = gensym (" generateresult" )
@@ -771,12 +779,9 @@ function generate_internal(
771
779
MLIR. IR. context ():: MLIR.API.MlirContext
772
780
):: MLIR.IR.Type
773
781
774
- constraint_addr = ConcreteRNumber (reinterpret (UInt64, pointer_from_objref (constraint)))
775
- constraint_mlir_val = TracedUtils. get_mlir_data (Ops. constant (constraint_addr))
776
-
777
782
constraint_val = MLIR. IR. result (
778
783
MLIR. Dialects. builtin. unrealized_conversion_cast (
779
- [constraint_mlir_val ]; outputs= [constraint_ty]
784
+ [TracedUtils . get_mlir_data (constraint_ptr) ]; outputs= [constraint_ty]
780
785
),
781
786
1 ,
782
787
)
0 commit comments