Skip to content

Commit 34f35c4

Browse files
committed
@compile for generate op
1 parent 1585483 commit 34f35c4

File tree

2 files changed

+14
-28
lines changed

2 files changed

+14
-28
lines changed

src/ProbProg.jl

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ using ..Reactant:
1010
TracedRNumber,
1111
ConcreteRNumber,
1212
Ops
13-
using ..Compiler: @jit
13+
using ..Compiler: @jit, @compile
1414
using Enzyme
1515
using Base: ReentrantLock
1616

@@ -705,10 +705,15 @@ function generate(
705705
weight = nothing
706706
res = nothing
707707

708+
constraint_ptr = ConcreteRNumber(reinterpret(UInt64, pointer_from_objref(constraint)))
709+
710+
function wrapper_fn(constraint_ptr, args...)
711+
return generate_internal(f, args...; constraint_ptr, constraint)
712+
end
713+
708714
try
709-
trace, weight, res = @jit optimize = :probprog generate_internal(
710-
f, args...; constraint
711-
)
715+
compiled_fn = @compile optimize = :probprog wrapper_fn(constraint_ptr, args...)
716+
trace, weight, res = compiled_fn(constraint_ptr, args...)
712717
finally
713718
GC.enable(old_gc_state)
714719
end
@@ -719,7 +724,10 @@ function generate(
719724
end
720725

721726
function generate_internal(
722-
f::Function, args::Vararg{Any,Nargs}; constraint::Constraint=Dict{Symbol,Any}()
727+
f::Function,
728+
args::Vararg{Any,Nargs};
729+
constraint_ptr::TracedRNumber,
730+
constraint::Constraint=Dict{Symbol,Any}(),
723731
) where {Nargs}
724732
argprefix::Symbol = gensym("generatearg")
725733
resprefix::Symbol = gensym("generateresult")
@@ -771,12 +779,9 @@ function generate_internal(
771779
MLIR.IR.context()::MLIR.API.MlirContext
772780
)::MLIR.IR.Type
773781

774-
constraint_addr = ConcreteRNumber(reinterpret(UInt64, pointer_from_objref(constraint)))
775-
constraint_mlir_val = TracedUtils.get_mlir_data(Ops.constant(constraint_addr))
776-
777782
constraint_val = MLIR.IR.result(
778783
MLIR.Dialects.builtin.unrealized_conversion_cast(
779-
[constraint_mlir_val]; outputs=[constraint_ty]
784+
[TracedUtils.get_mlir_data(constraint_ptr)]; outputs=[constraint_ty]
780785
),
781786
1,
782787
)

test/probprog/generate.jl

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,25 +13,6 @@ function model(seed, μ, σ, shape)
1313
end
1414

1515
@testset "Generate" begin
16-
@testset "hlo" begin
17-
shape = (10,)
18-
seed = Reactant.to_rarray(UInt64[1, 4])
19-
μ = Reactant.ConcreteRNumber(0.0)
20-
σ = Reactant.ConcreteRNumber(1.0)
21-
22-
before = @code_hlo optimize = :no_enzyme ProbProg.generate_internal(
23-
model, seed, μ, σ, shape
24-
)
25-
@test contains(repr(before), "enzyme.generate")
26-
@test contains(repr(before), "enzyme.sample")
27-
28-
after = @code_hlo optimize = :probprog ProbProg.generate_internal(
29-
model, seed, μ, σ, shape
30-
)
31-
@test !contains(repr(after), "enzyme.generate")
32-
@test !contains(repr(after), "enzyme.sample")
33-
end
34-
3516
@testset "unconstrained" begin
3617
shape = (1000,)
3718
seed = Reactant.to_rarray(UInt64[1, 4])

0 commit comments

Comments
 (0)