Skip to content
Open
32 changes: 18 additions & 14 deletions R/draws.R
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,7 @@ draws.condmean <- function(data, data_ice = NULL, vars, method, ncores = 1, quie
stop("Invalid Method type")
}
sample_opts <- append(sample_opts, extra_opts)
x <- do.call(get_draws_mle, sample_opts)
return(x)
return(do.call(get_draws_mle, sample_opts))
}


Expand Down Expand Up @@ -313,12 +312,22 @@ get_draws_mle <- function(
max_sample_attempts
)
)

# If the first sample is on the original dataset we want to retain the model fit so
# we can provide it back to end users. To acomplish this we generate an environment
# for the model fit to be saved into by the model fitting function
model_env <- if (first_sample_orig) {
new.env(parent = emptyenv())
} else {
NULL
}

time_taken <- system.time({
initial_sample <- get_mmrm_sample(
ids = longdata$ids,
longdata = longdata,
method = method
method = method,
model_env = model_env
)
})

Expand All @@ -335,16 +344,9 @@ get_draws_mle <- function(
)
}


cl <- get_cluster(ncores)
mmrm_sample <- encap_get_mmrm_sample(cl, longdata, method)


# browser()
# get_mmrm_sample
# mmrm_sample(ids)
# clusterEvalQ(cl, fit_mmrm)

samples <- list()
n_failed_samples <- 0
logger <- progressLogger$new(n_target_samples, quiet = quiet)
Expand Down Expand Up @@ -394,6 +396,7 @@ get_draws_mle <- function(
samples = sample_list(samples),
data = longdata,
formula = longdata$formula,
fit = model_env$model,
n_failures = n_failed_samples
)
return(ret)
Expand All @@ -411,10 +414,10 @@ get_draws_mle <- function(
#' @param longdata R6 `longdata` object containing all relevant input data information.
#' @param method A `method` object as generated by either
#' [method_approxbayes()] or [method_condmean()].
#' @param ... Additional arguments passed onto [fit_mmrm()]
#'
#' @inherit sample_single return
get_mmrm_sample <- function(ids, longdata, method) {

get_mmrm_sample <- function(ids, longdata, method, ...) {
vars <- longdata$vars
dat <- longdata$get_data(ids, nmar.rm = TRUE, na.rm = TRUE)
model_df <- as_model_df(
Expand All @@ -430,7 +433,8 @@ get_mmrm_sample <- function(ids, longdata, method) {
group = dat[[vars$group]],
cov_struct = method$covariance,
REML = method$REML,
same_cov = method$same_cov
same_cov = method$same_cov,
...
)

if (sample$failed) {
Expand Down Expand Up @@ -681,7 +685,7 @@ validate.draws <- function(x, ...) {
has_class(x$samples, "sample_list"),
validate(x$samples),
is.null(x$n_failures) | is.numeric(x$n_failures),
is.null(x$fit) | has_class(x$fit, "stanfit"),
is.null(x$fit) | has_class(x$fit, "stanfit") | has_class(x$fit, 'mmrm'),
has_class(x$formula, "formula")
)
}
Expand Down
31 changes: 20 additions & 11 deletions R/mmrm.R
Original file line number Diff line number Diff line change
Expand Up @@ -167,19 +167,23 @@ extract_params <- function(fit) {
#' @param cov_struct a character value. Specifies which covariance structure to use. Must be one of
#' `"us"`, `"toep"`, `"cs"` or `"ar1"`
#' @param REML logical. Specifies whether restricted maximum likelihood should be used
#' @param same_cov logical. Used to specify if a shared or individual covariance matrix should be
#' used per `group`
#' @param same_cov logical. Used to specify if a shared or individual covariance matrix should be used
#' per `group`
#' @param model_env an environment or NULL. If an environment is provided then the model fit object
#' will be recorded into it.
#' @name fit_mmrm
#'
#'
fit_mmrm <- function(designmat,
outcome,
subjid,
visit,
group,
cov_struct = c("us", "toep", "cs", "ar1"),
REML = TRUE,
same_cov = TRUE) {
fit_mmrm <- function(
designmat,
outcome,
subjid,
visit,
group,
cov_struct = c("us", "toep", "cs", "ar1"),
REML = TRUE,
same_cov = TRUE,
model_env = NULL
) {
dat_mmrm <- as_mmrm_df(
designmat = designmat,
outcome = outcome,
Expand All @@ -203,6 +207,11 @@ fit_mmrm <- function(designmat,
if (fit$failed) {
return(fit)
}

# If the user provided an environment to record the model in then provide them the model
if (is.environment(model_env)) {
model_env$model <- fit
}

# extract regression coefficients and covariance matrices
params <- extract_params(fit)
Expand Down
10 changes: 7 additions & 3 deletions man/fit_mmrm.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion man/get_mmrm_sample.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion tests/testthat/test-fullusage.R
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ test_that("Basic Usage - Bayesian", {



test_that("Basic Usage - Condmean", {
test_that("Basic Usage - Condmean", {

skip_if_not(is_full_test())

Expand Down