Skip to content

Commit 096b063

Browse files
committed
interface
1 parent 9f486ea commit 096b063

File tree

2 files changed

+21
-8
lines changed

2 files changed

+21
-8
lines changed

deps/ReactantExtra/API.cpp

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -403,17 +403,24 @@ enzymeRngDistributionAttrGet(MlirContext ctx, int32_t val) {
403403
}
404404

405405
REACTANT_ABI MLIR_CAPI_EXPORTED MlirAttribute
406-
enzymeHMCConfigAttrGet(MlirContext ctx, int64_t num_steps) {
407-
return wrap(mlir::enzyme::HMCConfigAttr::get(unwrap(ctx), num_steps));
406+
enzymeHMCConfigAttrGet(MlirContext ctx, double trajectory_length,
407+
bool adapt_step_size, bool adapt_mass_matrix) {
408+
auto *context = unwrap(ctx);
409+
mlir::FloatAttr trajectoryLengthAttr =
410+
mlir::FloatAttr::get(mlir::Float64Type::get(context), trajectory_length);
411+
return wrap(mlir::enzyme::HMCConfigAttr::get(
412+
context, trajectoryLengthAttr, adapt_step_size, adapt_mass_matrix));
408413
}
409414

410415
REACTANT_ABI MLIR_CAPI_EXPORTED MlirAttribute enzymeNUTSConfigAttrGet(
411-
MlirContext ctx, int64_t max_tree_depth, double max_delta_energy) {
416+
MlirContext ctx, int64_t max_tree_depth, double max_delta_energy,
417+
bool adapt_step_size, bool adapt_mass_matrix) {
412418
auto *context = unwrap(ctx);
413419
mlir::FloatAttr maxDeltaEnergyAttr =
414420
mlir::FloatAttr::get(mlir::Float64Type::get(context), max_delta_energy);
415-
return wrap(mlir::enzyme::NUTSConfigAttr::get(context, max_tree_depth,
416-
maxDeltaEnergyAttr));
421+
return wrap(mlir::enzyme::NUTSConfigAttr::get(
422+
context, max_tree_depth, maxDeltaEnergyAttr, adapt_step_size,
423+
adapt_mass_matrix));
417424
}
418425

419426
// Create profiler session and start profiling

src/probprog/MCMC.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,14 @@ function mcmc(
99
algorithm::Symbol=:HMC,
1010
inverse_mass_matrix=nothing,
1111
step_size=nothing,
12-
num_steps::Int=10,
12+
trajectory_length::Float64=2π,
1313
max_tree_depth::Int=10,
1414
max_delta_energy::Float64=1000.0,
15+
num_warmup::Int=0,
1516
num_samples::Int=1,
1617
thinning::Int=1,
18+
adapt_step_size::Bool=true,
19+
adapt_mass_matrix::Bool=true,
1720
) where {Nargs}
1821
args = (rng, args...)
1922
(; f_name, mlir_caller_args, mlir_result_types, traced_result, linear_results, fnwrapped, argprefix, resprefix) = process_probprog_function(
@@ -63,13 +66,15 @@ function mcmc(
6366

6467
if algorithm == :HMC
6568
hmc_config_attr = @ccall MLIR.API.mlir_c.enzymeHMCConfigAttrGet(
66-
MLIR.IR.context()::MLIR.API.MlirContext, num_steps::Int64
69+
MLIR.IR.context()::MLIR.API.MlirContext, trajectory_length::Float64, adapt_step_size::Bool, adapt_mass_matrix::Bool
6770
)::MLIR.IR.Attribute
6871
elseif algorithm == :NUTS
6972
nuts_config_attr = @ccall MLIR.API.mlir_c.enzymeNUTSConfigAttrGet(
7073
MLIR.IR.context()::MLIR.API.MlirContext,
7174
max_tree_depth::Int64,
72-
max_delta_energy::Float64
75+
max_delta_energy::Float64,
76+
adapt_step_size::Bool,
77+
adapt_mass_matrix::Bool
7378
)::MLIR.IR.Attribute
7479
else
7580
error("Unknown MCMC algorithm: $algorithm. Supported algorithms are :HMC and :NUTS")
@@ -97,6 +102,7 @@ function mcmc(
97102
selection=MLIR.IR.Attribute(selection_attr),
98103
hmc_config=hmc_config_attr,
99104
nuts_config=nuts_config_attr,
105+
num_warmup=Int64(num_warmup),
100106
num_samples=Int64(num_samples),
101107
thinning=Int64(thinning),
102108
)

0 commit comments

Comments
 (0)