@@ -9,11 +9,14 @@ function mcmc(
99 algorithm:: Symbol = :HMC,
1010 inverse_mass_matrix= nothing ,
1111 step_size= nothing ,
12- num_steps :: Int = 10 ,
12+ trajectory_length :: Float64 = 2 π ,
1313 max_tree_depth:: Int = 10 ,
1414 max_delta_energy:: Float64 = 1000.0 ,
15+ num_warmup:: Int = 0 ,
1516 num_samples:: Int = 1 ,
1617 thinning:: Int = 1 ,
18+ adapt_step_size:: Bool = true ,
19+ adapt_mass_matrix:: Bool = true ,
1720) where {Nargs}
1821 args = (rng, args... )
1922 (; f_name, mlir_caller_args, mlir_result_types, traced_result, linear_results, fnwrapped, argprefix, resprefix) = process_probprog_function(
@@ -63,13 +66,15 @@ function mcmc(
6366
6467 if algorithm == :HMC
6568 hmc_config_attr = @ccall MLIR. API. mlir_c. enzymeHMCConfigAttrGet(
66- MLIR. IR. context():: MLIR.API.MlirContext , num_steps :: Int64
69+ MLIR. IR. context():: MLIR.API.MlirContext , trajectory_length :: Float64 , adapt_step_size :: Bool , adapt_mass_matrix :: Bool
6770 ):: MLIR.IR.Attribute
6871 elseif algorithm == :NUTS
6972 nuts_config_attr = @ccall MLIR. API. mlir_c. enzymeNUTSConfigAttrGet(
7073 MLIR. IR. context():: MLIR.API.MlirContext ,
7174 max_tree_depth:: Int64 ,
72- max_delta_energy:: Float64
75+ max_delta_energy:: Float64 ,
76+ adapt_step_size:: Bool ,
77+ adapt_mass_matrix:: Bool
7378 ):: MLIR.IR.Attribute
7479 else
7580 error(" Unknown MCMC algorithm: $algorithm . Supported algorithms are :HMC and :NUTS" )
@@ -97,6 +102,7 @@ function mcmc(
97102 selection= MLIR. IR. Attribute(selection_attr),
98103 hmc_config= hmc_config_attr,
99104 nuts_config= nuts_config_attr,
105+ num_warmup= Int64(num_warmup),
100106 num_samples= Int64(num_samples),
101107 thinning= Int64(thinning),
102108 )
0 commit comments