Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: irxforge
Title: Forging data for pharmacometric analyses
Version: 0.0.0.9000
Version: 0.0.0.9001
Authors@R: c(
person("Ron", "Keizer", email = "ron@insight-rx.com", role = c("cre", "aut")),
person("Michael", "McCarthy", email = "michael.mccarthy@insight-rx.com", role = "ctb"),
Expand Down
90 changes: 74 additions & 16 deletions R/sample_covariates_mvtnorm.R
Original file line number Diff line number Diff line change
@@ -1,13 +1,29 @@
#' Sample covariates from multivariate normal distributions
#'
#' Samples from a multivariate normal distribution either derived from observed
#' data or specified directly via `means` plus a covariance matrix (`sigma`) or
#' standard deviations (`sd`).
#'
#' @param data data.frame (n x p) containing the original, observed,
#' time-invariant covariates (ID should not be included) that will be used to
#' inform the imputation.
#' inform the imputation. Can be `NULL` when `means` and either `sigma` or
#' `sd` are supplied directly. Ignored with a warning when `means` is also
#' provided.
#' @param means named numeric vector of means for each covariate. When supplied,
#' the distribution is specified directly and `data` is ignored. Must be
#' supplied together with either `sigma` or `sd` when `data` is `NULL`.
#' @param sigma named numeric matrix (p x p) giving the full covariance matrix.
#' Takes precedence over `sd` when both are provided. Column and row names
#' must match the names in `means`.
#' @param sd named numeric vector of standard deviations. Used to construct a
#' diagonal covariance matrix (`diag(sd^2)`) when `sigma` is not provided.
#' Names must match the names in `means`.
#' @param cat_covs character vector containing the names of the categorical
#' covariates in orgCovs.
#' @param n_subjects number of simulated subjects, default is the number of
#' subjects in the data.
#' @param n_subjects number of simulated subjects. Defaults to `nrow(data)`
#' when `data` is provided; required (no default) when `data` is `NULL`.
#' @param exponential sample from exponential distribution? Default `FALSE`.
#' Only applies when means/covariance are derived from `data`.
#' @param conditional description...
#' @param seed integer random seed passed to [set.seed()] for reproducibility.
#' Default `NULL` does not set a seed.
Expand All @@ -20,21 +36,63 @@
#'
#' @export
sample_covariates_mvtnorm <- function(
data,
data = NULL,
means = NULL,
sigma = NULL,
sd = NULL,
cat_covs = NULL,
n_subjects = nrow(data),
n_subjects = if (!is.null(data)) nrow(data) else stop("`n_subjects` must be specified when `data` is NULL."),
exponential = FALSE,
conditional = NULL,
seed = NULL,
...
) {
if (!is.null(seed)) set.seed(seed)

## Branch: distribution specified directly via means + sigma/sd
if (!is.null(means)) {
if (!is.null(data)) {
warning("`data` is ignored when `means` is provided.")
}
if (is.null(sigma) && is.null(sd)) {
stop("When `means` is supplied, either `sigma` or `sd` must also be provided.")
}
if (!is.null(sigma)) {
cov_mat <- sigma
} else {
if (length(sd) != length(means)) {
stop("`sd` must have the same length as `means`.")
}
cov_mat <- diag(sd^2)
if (!is.null(names(sd))) {
rownames(cov_mat) <- names(sd)
colnames(cov_mat) <- names(sd)
} else if (!is.null(names(means))) {
rownames(cov_mat) <- names(means)
colnames(cov_mat) <- names(means)
}
}
out <- mvtnorm::rmvnorm(
n_subjects,
mean = means,
sigma = cov_mat,
...
) |>
as.data.frame()
if (!is.null(names(means))) names(out) <- names(means)
return(out)
}

## Branch: derive distribution from data
if (is.null(data)) {
stop("Either `data` or `means` with `sigma`/`sd` must be provided.")
}

if(!is.null(conditional)) {
for(key in names(conditional)) {
data <- dplyr::filter(
data,
.data[[key]] >= min(conditional[[key]]) &
.data[[key]] >= min(conditional[[key]]) &
.data[[key]] <= max(conditional[[key]])
)
}
Expand All @@ -44,33 +102,33 @@ sample_covariates_mvtnorm <- function(
# FIXME: This code does nothing currently... is this function intended to
# work with categorical covariates? or only continuous? If the latter, how
# do we handle categorical?
cont_covs <- setdiff(names(data), cat_covs)
cont_covs <- setdiff(names(data), cat_covs)
miss_vars <- names(data)[colSums(is.na(data)) > 0]

## Get distribution and sample
if(exponential) {
# FIXME: This fails if there are zeroes or negative numbers. Should add some
# safety rails.
means <- apply(data, 2, function(x) mean(log(x)))
cov_mat <- stats::cov(log(data))
data_means <- apply(data, 2, function(x) mean(log(x)))
cov_mat <- stats::cov(log(data))
out <- mvtnorm::rmvnorm(
n_subjects,
mean = means,
n_subjects,
mean = data_means,
sigma = cov_mat
) |>
exp() |>
as.data.frame()
} else {
means <- apply(data, 2, mean)
data_means <- apply(data, 2, mean)
cov_mat <- stats::cov(data)
out <- mvtnorm::rmvnorm(
n_subjects,
mean = means,
n_subjects,
mean = data_means,
sigma = cov_mat
) |>
as.data.frame()
}

if (tibble::is_tibble(data)) out <- tibble::as_tibble(out)
out
out
}
35 changes: 28 additions & 7 deletions man/sample_covariates_mvtnorm.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

78 changes: 78 additions & 0 deletions tests/testthat/test-sample_covariates_mvtnorm.R
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,81 @@ test_that("different seeds produce different output", {
out2 <- sample_covariates_mvtnorm(dat, n_subjects = 20, seed = 2)
expect_false(identical(out1, out2))
})

# --- Direct distribution specification (means + sigma / sds) ---

test_that("means + sigma samples correct number of rows and columns", {
mu <- c(x = 10, y = 20)
S <- matrix(c(4, 1, 1, 9), nrow = 2, dimnames = list(c("x","y"), c("x","y")))
out <- sample_covariates_mvtnorm(means = mu, sigma = S, n_subjects = 100)
expect_equal(nrow(out), 100)
expect_equal(ncol(out), 2)
expect_named(out, c("x", "y"))
})

test_that("means + sigma samples near the specified mean (large n)", {
mu <- c(AGE = 40, WT = 70)
S <- diag(c(25, 100))
out <- sample_covariates_mvtnorm(means = mu, sigma = S, n_subjects = 5000, seed = 42)
expect_equal(mean(out$AGE), 40, tolerance = 1)
expect_equal(mean(out$WT), 70, tolerance = 2)
})

test_that("means + sd constructs diagonal covariance and samples correctly", {
mu <- c(x = 5, y = 50)
sds <- c(x = 1, y = 10)
out <- sample_covariates_mvtnorm(means = mu, sd = sds, n_subjects = 5000, seed = 7)
expect_named(out, c("x", "y"))
expect_equal(mean(out$x), 5, tolerance = 0.2)
expect_equal(mean(out$y), 50, tolerance = 2)
expect_equal(sd(out$x), 1, tolerance = 0.1)
expect_equal(sd(out$y), 10, tolerance = 1)
})

test_that("means + sd: length mismatch raises error", {
expect_error(
sample_covariates_mvtnorm(means = c(a = 1, b = 2), sd = c(1, 2, 3), n_subjects = 10),
"`sd` must have the same length"
)
})

test_that("means without sigma or sd raises error", {
expect_error(
sample_covariates_mvtnorm(means = c(x = 1), n_subjects = 10),
"either `sigma` or `sd` must also be provided"
)
})

test_that("data = NULL without means raises error", {
expect_error(
sample_covariates_mvtnorm(data = NULL, n_subjects = 10),
"Either `data` or `means`"
)
})

test_that("warning is issued when both data and means are provided", {
dat <- data.frame(x = rnorm(50), y = rnorm(50))
mu <- c(x = 0, y = 0)
S <- diag(2)
expect_warning(
sample_covariates_mvtnorm(data = dat, means = mu, sigma = S, n_subjects = 10),
"`data` is ignored"
)
})

test_that("n_subjects is required when data is NULL", {
mu <- c(x = 0)
S <- matrix(1)
expect_error(
sample_covariates_mvtnorm(means = mu, sigma = S),
"n_subjects.*must be specified"
)
})

test_that("seed produces reproducible output with means + sigma", {
mu <- c(x = 0, y = 1)
S <- diag(2)
out1 <- sample_covariates_mvtnorm(means = mu, sigma = S, n_subjects = 30, seed = 5)
out2 <- sample_covariates_mvtnorm(means = mu, sigma = S, n_subjects = 30, seed = 5)
expect_equal(out1, out2)
})
Loading