Skip to content

Commit 0942101

Browse files
committed
Support list and function inputs for inits
1 parent 9eef7d0 commit 0942101

File tree

14 files changed

+119
-43
lines changed

14 files changed

+119
-43
lines changed

DESCRIPTION

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: StanEstimators
22
Title: Estimate Parameters for Arbitrary R Functions using 'Stan'
3-
Version: 0.2.0
3+
Version: 0.2.0.9000
44
Authors@R: c(
55
person(given = c("Andrew", "R."), family = "Johnson", role = c("aut", "cre"),
66
email = "andrew.johnson@arjohnsonau.com",
@@ -20,7 +20,7 @@ Description: Allows for the estimation of parameters for 'R' functions using the
2020
License: MIT + file LICENSE
2121
Encoding: UTF-8
2222
Roxygen: list(markdown = TRUE)
23-
RoxygenNote: 7.3.1
23+
RoxygenNote: 7.3.3
2424
NeedsCompilation: yes
2525
UseLTO: true
2626
SystemRequirements: GNU make

R/diagnose.R

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@
33
#' Check gradient estimation using Stan's 'Diagnose' method
44
#'
55
#' @param fn Function to estimate parameters for
6-
#' @param par_inits Initial values for parameters
6+
#' @param par_inits Initial values for parameters. This can either be a numeric vector
7+
#' of initial values (which will be used for all chains), a list of numeric vectors (of length
8+
#' equal to the number of chains), a function taking a single argument (the chain ID) and
9+
#' returning a numeric vector of initial values, or NULL (in which case Stan will
10+
#' generate initial values automatically).
711
#' (must be specified if `n_pars` is NULL)
812
#' @param n_pars Number of parameters to estimate
913
#' (must be specified if `par_inits` is NULL)
@@ -51,7 +55,7 @@ stan_diagnose <- function(fn, par_inits = NULL, n_pars = NULL, additional_args =
5155
args <- build_stan_call(method = "diagnose",
5256
method_args = "",
5357
data_file = inputs$data_filepath,
54-
init = inputs$init_filepath,
58+
init = inputs$init_filepath[1],
5559
seed = seed,
5660
output_args = output)
5761
call_stan(args, inputs, quiet)

R/laplace.R

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,11 @@ setMethod("summary", "StanLaplace", function(object, ...) {
1818
#' Estimate parameters using Stan's laplace algorithm
1919
#'
2020
#' @param fn Function to estimate parameters for
21-
#' @param par_inits Initial values for parameters
21+
#' @param par_inits Initial values for parameters. This can either be a numeric vector
22+
#' of initial values (which will be used for all chains), a list of numeric vectors (of length
23+
#' equal to the number of chains), a function taking a single argument (the chain ID) and
24+
#' returning a numeric vector of initial values, or NULL (in which case Stan will
25+
#' generate initial values automatically).
2226
#' (must be specified if `n_pars` is NULL)
2327
#' @param n_pars Number of parameters to estimate
2428
#' (must be specified if `par_inits` is NULL)
@@ -94,11 +98,11 @@ stan_laplace <- function(fn, par_inits = NULL, n_pars = NULL, additional_args =
9498
mode_vals <- opt@estimates[, -1]
9599
}
96100
mode_vals <- as.numeric(mode_vals)
97-
if (length(mode_vals) != length(par_inits)) {
101+
if (length(mode_vals) != length(inputs$inits[[1]])) {
98102
stop("The number of mode values does not match the number of parameter ",
99103
"inits!", .call = FALSE)
100104
}
101-
write_inits(mode_vals, mode_file)
105+
write_inits(list(mode_vals), list(mode_file))
102106
method_args <- list(
103107
mode = mode_file,
104108
jacobian = format_bool(jacobian),
@@ -115,7 +119,7 @@ stan_laplace <- function(fn, par_inits = NULL, n_pars = NULL, additional_args =
115119
args <- build_stan_call(method = "laplace",
116120
method_args = method_args,
117121
data_file = inputs$data_filepath,
118-
init = inputs$init_filepath,
122+
init = inputs$init_filepath[1],
119123
seed = seed,
120124
output_args = output)
121125

R/optimize.R

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,11 @@ setMethod("summary", "StanOptimize", function(object, ...) {
1818
#' Estimate parameters using Stan's optimization algorithms
1919
#'
2020
#' @param fn Function to estimate parameters for
21-
#' @param par_inits Initial values for parameters
21+
#' @param par_inits Initial values for parameters. This can either be a numeric vector
22+
#' of initial values (which will be used for all chains), a list of numeric vectors (of length
23+
#' equal to the number of chains), a function taking a single argument (the chain ID) and
24+
#' returning a numeric vector of initial values, or NULL (in which case Stan will
25+
#' generate initial values automatically).
2226
#' (must be specified if `n_pars` is NULL)
2327
#' @param n_pars Number of parameters to estimate
2428
#' (must be specified if `par_inits` is NULL)
@@ -108,7 +112,7 @@ stan_optimize <- function(fn, par_inits = NULL, n_pars = NULL, additional_args =
108112
args <- build_stan_call(method = "optimize",
109113
method_args = method_args,
110114
data_file = inputs$data_filepath,
111-
init = inputs$init_filepath,
115+
init = inputs$init_filepath[1],
112116
seed = seed,
113117
output_args = output)
114118
call_stan(args, inputs, quiet)

R/pathfinder.R

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,11 @@ setMethod("summary", "StanPathfinder", function(object, ...) {
1818
#' Estimate parameters using Stan's pathfinder algorithm
1919
#'
2020
#' @param fn Function to estimate parameters for
21-
#' @param par_inits Initial values for parameters
21+
#' @param par_inits Initial values for parameters. This can either be a numeric vector
22+
#' of initial values (which will be used for all chains), a list of numeric vectors (of length
23+
#' equal to the number of chains), a function taking a single argument (the chain ID) and
24+
#' returning a numeric vector of initial values, or NULL (in which case Stan will
25+
#' generate initial values automatically).
2226
#' (must be specified if `n_pars` is NULL)
2327
#' @param n_pars Number of parameters to estimate
2428
#' (must be specified if `par_inits` is NULL)
@@ -80,7 +84,8 @@ stan_pathfinder <- function(fn, par_inits = NULL, n_pars = NULL, additional_args
8084
max_lbfgs_iters = NULL, num_draws = NULL,
8185
num_elbo_draws = NULL) {
8286
inputs <- prepare_inputs(fn, par_inits, n_pars, additional_args, grad_fun, lower, upper,
83-
globals, packages, eval_standalone, output_dir, output_basename)
87+
globals, packages, eval_standalone, output_dir, output_basename,
88+
ifelse(is.null(num_paths), 4, num_paths)) # Default to 4 pathfinders
8489
method_args <- list(
8590
init_alpha = init_alpha,
8691
tol_obj = tol_obj,
@@ -108,7 +113,7 @@ stan_pathfinder <- function(fn, par_inits = NULL, n_pars = NULL, additional_args
108113
args <- build_stan_call(method = "pathfinder",
109114
method_args = method_args,
110115
data_file = inputs$data_filepath,
111-
init = inputs$init_filepath,
116+
init = paste0(inputs$init_filepath, collapse = ","), # Pass all inits for multiple pathfinders
112117
seed = seed,
113118
output_args = output)
114119

R/sample.R

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,11 @@ setMethod("summary", "StanMCMC", function(object, ...) {
1818
#' Estimate parameters using Stan's sampling algorithms
1919
#'
2020
#' @param fn Function to estimate parameters for
21-
#' @param par_inits Initial values for parameters
21+
#' @param par_inits Initial values for parameters. This can either be a numeric vector
22+
#' of initial values (which will be used for all chains), a list of numeric vectors (of length
23+
#' equal to the number of chains), a function taking a single argument (the chain ID) and
24+
#' returning a numeric vector of initial values, or NULL (in which case Stan will
25+
#' generate initial values automatically).
2226
#' (must be specified if `n_pars` is NULL)
2327
#' @param n_pars Number of parameters to estimate
2428
#' (must be specified if `par_inits` is NULL)
@@ -121,7 +125,9 @@ stan_sample <- function(fn, par_inits = NULL, n_pars = NULL, additional_args = l
121125
call. = FALSE)
122126
}
123127
inputs <- prepare_inputs(fn, par_inits, n_pars, additional_args, grad_fun, lower, upper,
124-
globals, packages, eval_standalone, output_dir, output_basename)
128+
globals, packages, eval_standalone, output_dir, output_basename,
129+
num_chains)
130+
125131
method_args <- list(
126132
algorithm = algorithm,
127133
algorithm_args = list(
@@ -161,7 +167,7 @@ stan_sample <- function(fn, par_inits = NULL, n_pars = NULL, additional_args = l
161167
args <- build_stan_call(method = "sample",
162168
method_args = method_args,
163169
data_file = inputs$data_filepath,
164-
init = inputs$init_filepath,
170+
init = inputs$init_filepath[chain],
165171
seed = seed,
166172
output_args = output,
167173
id = chain)
@@ -225,7 +231,7 @@ stan_sample <- function(fn, par_inits = NULL, n_pars = NULL, additional_args = l
225231
args <- build_stan_call(method = "sample",
226232
method_args = method_args,
227233
data_file = inputs$data_filepath,
228-
init = inputs$init_filepath,
234+
init = paste0(inputs$init_filepath, collapse = ","),
229235
seed = seed,
230236
output_args = output)
231237
call_stan_impl(args, inputs)

R/util.R

Lines changed: 43 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,11 @@ inits_to_json <- function(inits) {
1919
}
2020

2121
write_inits <- function(inits, init_filepath) {
22-
dat_string <- inits_to_json(inits)
23-
writeLines(dat_string, con = init_filepath)
24-
dat_string
22+
lapply(seq_len(length(inits)), function(i) {
23+
dat_string <- inits_to_json(inits[[i]])
24+
writeLines(dat_string, con = init_filepath[[i]])
25+
dat_string
26+
})
2527
}
2628

2729
prepare_and_write_json <- function(what, input_list) {
@@ -44,10 +46,9 @@ with_env <- function(f, e=parent.frame()) {
4446
f
4547
}
4648

47-
prepare_function <- function(fn, inits, extra_args_list, grad = FALSE) {
48-
fn_wrapper <- function(v) { do.call(fn, c(list(v), extra_args_list)) }
49+
validate_function <- function(fn, inits, extra_args_list, grad = FALSE) {
4950
fn_type <- ifelse(isTRUE(grad), "Gradient", "Log-Likelihood")
50-
test_fn <- try(invisible(fn_wrapper(inits)), silent = TRUE)
51+
test_fn <- try(invisible(fn(inits)), silent = TRUE)
5152
correct_length <- ifelse(isTRUE(grad), length(inits), 1)
5253

5354
if (inherits(test_fn, "try-error")) {
@@ -60,12 +61,12 @@ prepare_function <- function(fn, inits, extra_args_list, grad = FALSE) {
6061
stop(fn_type, " function should have return of length ", correct_length,
6162
", but return was length ", length(test_fn), "instead!", call. = FALSE)
6263
} else {
63-
fn_wrapper
64+
invisible(NULL)
6465
}
6566
}
6667

6768
prepare_inputs <- function(fn, par_inits, n_pars, extra_args_list, grad_fun, lower, upper,
68-
globals, packages, eval_standalone, output_dir, output_basename) {
69+
globals, packages, eval_standalone, output_dir, output_basename, num_chains = 1) {
6970
user_inits <- TRUE
7071
if (is.null(par_inits)) {
7172
if (is.null(n_pars)) {
@@ -83,7 +84,26 @@ prepare_inputs <- function(fn, par_inits, n_pars, extra_args_list, grad_fun, low
8384
user_inits <- FALSE
8485
}
8586

86-
fn1 <- prepare_function(fn, par_inits, extra_args_list, grad = FALSE)
87+
inits <- NULL
88+
if (is.list(par_inits)) {
89+
if (length(par_inits) != num_chains) {
90+
stop("If par_inits is a list, it must have length equal to num_chains",
91+
call. = FALSE)
92+
}
93+
inits <- par_inits
94+
} else if (is.numeric(par_inits)) {
95+
inits <- lapply(seq_len(num_chains), function(i) { par_inits })
96+
} else if (is.function(par_inits)) {
97+
inits <- lapply(seq_len(num_chains), function(i) { par_inits(i) })
98+
} else {
99+
stop("par_inits must be NULL, a numeric vector, a list of numeric vectors, or a function",
100+
call. = FALSE)
101+
}
102+
103+
fn1 <- function(v) { do.call(fn, c(list(v), extra_args_list)) }
104+
for (chain in seq_len(num_chains)) {
105+
validate_function(fn1, inits[[chain]], extra_args_list, grad = FALSE)
106+
}
87107
fun_globals <- NULL
88108
fun_packages <- NULL
89109
if (isTRUE(eval_standalone)) {
@@ -101,7 +121,10 @@ prepare_inputs <- function(fn, par_inits, n_pars, extra_args_list, grad_fun, low
101121
fun_packages <- c(gp$packages, packages)
102122
}
103123
if (!is.null(grad_fun)) {
104-
gr1 <- prepare_function(grad_fun, par_inits, extra_args_list, grad = TRUE)
124+
gr1 <- function(v) { do.call(grad_fun, c(list(v), extra_args_list)) }
125+
for (chain in seq_len(num_chains)) {
126+
validate_function(gr1, inits[[chain]], extra_args_list, grad = TRUE)
127+
}
105128
if (isTRUE(eval_standalone)) {
106129
gr_gp <- future::getGlobalsAndPackages(grad_fun, globals = globals)
107130
fun_globals <- c(fun_globals, gr_gp$globals)
@@ -111,13 +134,13 @@ prepare_inputs <- function(fn, par_inits, n_pars, extra_args_list, grad_fun, low
111134
gr1 <- fn1
112135
}
113136

114-
if ((length(par_inits) > 1) && (length(lower) == 1)) {
115-
lower <- rep(lower, length(par_inits))
137+
if ((length(inits[[1]]) > 1) && (length(lower) == 1)) {
138+
lower <- rep(lower, length(inits[[1]]))
116139
}
117-
if ((length(par_inits) > 1) && (length(upper) == 1)) {
118-
upper <- rep(upper, length(par_inits))
140+
if ((length(inits[[1]]) > 1) && (length(upper) == 1)) {
141+
upper <- rep(upper, length(inits[[1]]))
119142
}
120-
bounds_types <- sapply(seq_len(length(par_inits)), function(i) {
143+
bounds_types <- sapply(seq_len(length(inits[[1]])), function(i) {
121144
if (lower[i] != -Inf && upper[i] != Inf) {
122145
3
123146
} else if (lower[i] != -Inf) {
@@ -139,7 +162,9 @@ prepare_inputs <- function(fn, par_inits, n_pars, extra_args_list, grad_fun, low
139162

140163
init_filepath <- NULL
141164
if (user_inits) {
142-
init_filepath <- tempfile(fileext = ".json", tmpdir = output_dir)
165+
init_filepath <- sapply(seq_len(num_chains), function(i) {
166+
tempfile(fileext = ".json", tmpdir = output_dir)
167+
})
143168
}
144169

145170
structured <- list(
@@ -148,9 +173,9 @@ prepare_inputs <- function(fn, par_inits, n_pars, extra_args_list, grad_fun, low
148173
globals = fun_globals,
149174
packages = fun_packages,
150175
eval_standalone = eval_standalone,
151-
inits = par_inits,
176+
inits = inits,
152177
finite_diff = as.integer(is.null(grad_fun)),
153-
Npars = length(par_inits),
178+
Npars = length(inits[[1]]),
154179
lower_bounds = lower,
155180
upper_bounds = upper,
156181
bounds_types = bounds_types,

R/variational.R

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,11 @@ setMethod("summary", "StanVariational", function(object, ...) {
1818
#' Estimate parameters using Stan's variational inference algorithms
1919
#'
2020
#' @param fn Function to estimate parameters for
21-
#' @param par_inits Initial values for parameters
21+
#' @param par_inits Initial values for parameters. This can either be a numeric vector
22+
#' of initial values (which will be used for all chains), a list of numeric vectors (of length
23+
#' equal to the number of chains), a function taking a single argument (the chain ID) and
24+
#' returning a numeric vector of initial values, or NULL (in which case Stan will
25+
#' generate initial values automatically).
2226
#' (must be specified if `n_pars` is NULL)
2327
#' @param n_pars Number of parameters to estimate
2428
#' (must be specified if `par_inits` is NULL)
@@ -104,7 +108,7 @@ stan_variational <- function(fn, par_inits = NULL, n_pars = NULL, additional_arg
104108
args <- build_stan_call(method = "variational",
105109
method_args = method_args,
106110
data_file = inputs$data_filepath,
107-
init = inputs$init_filepath,
111+
init = inputs$init_filepath[1],
108112
seed = seed,
109113
output_args = output)
110114

man/stan_diagnose.Rd

Lines changed: 5 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/stan_laplace.Rd

Lines changed: 5 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)