@@ -19,9 +19,11 @@ inits_to_json <- function(inits) {
1919}
2020
2121write_inits <- function (inits , init_filepath ) {
22- dat_string <- inits_to_json(inits )
23- writeLines(dat_string , con = init_filepath )
24- dat_string
22+ lapply(seq_len(length(inits )), function (i ) {
23+ dat_string <- inits_to_json(inits [[i ]])
24+ writeLines(dat_string , con = init_filepath [[i ]])
25+ dat_string
26+ })
2527}
2628
2729prepare_and_write_json <- function (what , input_list ) {
@@ -44,10 +46,9 @@ with_env <- function(f, e=parent.frame()) {
4446 f
4547}
4648
47- prepare_function <- function (fn , inits , extra_args_list , grad = FALSE ) {
48- fn_wrapper <- function (v ) { do.call(fn , c(list (v ), extra_args_list )) }
49+ validate_function <- function (fn , inits , extra_args_list , grad = FALSE ) {
4950 fn_type <- ifelse(isTRUE(grad ), " Gradient" , " Log-Likelihood" )
50- test_fn <- try(invisible (fn_wrapper (inits )), silent = TRUE )
51+ test_fn <- try(invisible (fn (inits )), silent = TRUE )
5152 correct_length <- ifelse(isTRUE(grad ), length(inits ), 1 )
5253
5354 if (inherits(test_fn , " try-error" )) {
@@ -60,12 +61,12 @@ prepare_function <- function(fn, inits, extra_args_list, grad = FALSE) {
6061 stop(fn_type , " function should have return of length " , correct_length ,
6162 " , but return was length " , length(test_fn ), " instead!" , call. = FALSE )
6263 } else {
63- fn_wrapper
64+ invisible ( NULL )
6465 }
6566}
6667
6768prepare_inputs <- function (fn , par_inits , n_pars , extra_args_list , grad_fun , lower , upper ,
68- globals , packages , eval_standalone , output_dir , output_basename ) {
69+ globals , packages , eval_standalone , output_dir , output_basename , num_chains = 1 ) {
6970 user_inits <- TRUE
7071 if (is.null(par_inits )) {
7172 if (is.null(n_pars )) {
@@ -83,7 +84,26 @@ prepare_inputs <- function(fn, par_inits, n_pars, extra_args_list, grad_fun, low
8384 user_inits <- FALSE
8485 }
8586
86- fn1 <- prepare_function(fn , par_inits , extra_args_list , grad = FALSE )
87+ inits <- NULL
88+ if (is.list(par_inits )) {
89+ if (length(par_inits ) != num_chains ) {
90+ stop(" If par_inits is a list, it must have length equal to num_chains" ,
91+ call. = FALSE )
92+ }
93+ inits <- par_inits
94+ } else if (is.numeric(par_inits )) {
95+ inits <- lapply(seq_len(num_chains ), function (i ) { par_inits })
96+ } else if (is.function(par_inits )) {
97+ inits <- lapply(seq_len(num_chains ), function (i ) { par_inits(i ) })
98+ } else {
99+ stop(" par_inits must be NULL, a numeric vector, a list of numeric vectors, or a function" ,
100+ call. = FALSE )
101+ }
102+
103+ fn1 <- function (v ) { do.call(fn , c(list (v ), extra_args_list )) }
104+ for (chain in seq_len(num_chains )) {
105+ validate_function(fn1 , inits [[chain ]], extra_args_list , grad = FALSE )
106+ }
87107 fun_globals <- NULL
88108 fun_packages <- NULL
89109 if (isTRUE(eval_standalone )) {
@@ -101,7 +121,10 @@ prepare_inputs <- function(fn, par_inits, n_pars, extra_args_list, grad_fun, low
101121 fun_packages <- c(gp $ packages , packages )
102122 }
103123 if (! is.null(grad_fun )) {
104- gr1 <- prepare_function(grad_fun , par_inits , extra_args_list , grad = TRUE )
124+ gr1 <- function (v ) { do.call(grad_fun , c(list (v ), extra_args_list )) }
125+ for (chain in seq_len(num_chains )) {
126+ validate_function(gr1 , inits [[chain ]], extra_args_list , grad = TRUE )
127+ }
105128 if (isTRUE(eval_standalone )) {
106129 gr_gp <- future :: getGlobalsAndPackages(grad_fun , globals = globals )
107130 fun_globals <- c(fun_globals , gr_gp $ globals )
@@ -111,13 +134,13 @@ prepare_inputs <- function(fn, par_inits, n_pars, extra_args_list, grad_fun, low
111134 gr1 <- fn1
112135 }
113136
114- if ((length(par_inits ) > 1 ) && (length(lower ) == 1 )) {
115- lower <- rep(lower , length(par_inits ))
137+ if ((length(inits [[ 1 ]] ) > 1 ) && (length(lower ) == 1 )) {
138+ lower <- rep(lower , length(inits [[ 1 ]] ))
116139 }
117- if ((length(par_inits ) > 1 ) && (length(upper ) == 1 )) {
118- upper <- rep(upper , length(par_inits ))
140+ if ((length(inits [[ 1 ]] ) > 1 ) && (length(upper ) == 1 )) {
141+ upper <- rep(upper , length(inits [[ 1 ]] ))
119142 }
120- bounds_types <- sapply(seq_len(length(par_inits )), function (i ) {
143+ bounds_types <- sapply(seq_len(length(inits [[ 1 ]] )), function (i ) {
121144 if (lower [i ] != - Inf && upper [i ] != Inf ) {
122145 3
123146 } else if (lower [i ] != - Inf ) {
@@ -139,7 +162,9 @@ prepare_inputs <- function(fn, par_inits, n_pars, extra_args_list, grad_fun, low
139162
140163 init_filepath <- NULL
141164 if (user_inits ) {
142- init_filepath <- tempfile(fileext = " .json" , tmpdir = output_dir )
165+ init_filepath <- sapply(seq_len(num_chains ), function (i ) {
166+ tempfile(fileext = " .json" , tmpdir = output_dir )
167+ })
143168 }
144169
145170 structured <- list (
@@ -148,9 +173,9 @@ prepare_inputs <- function(fn, par_inits, n_pars, extra_args_list, grad_fun, low
148173 globals = fun_globals ,
149174 packages = fun_packages ,
150175 eval_standalone = eval_standalone ,
151- inits = par_inits ,
176+ inits = inits ,
152177 finite_diff = as.integer(is.null(grad_fun )),
153- Npars = length(par_inits ),
178+ Npars = length(inits [[ 1 ]] ),
154179 lower_bounds = lower ,
155180 upper_bounds = upper ,
156181 bounds_types = bounds_types ,
0 commit comments