Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ Imports:
tidyr,
tidyselect
Suggests:
nhanesA,
testthat (>= 3.2.0),
withr
Remotes:
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

export(add_nominal_timepoints)
export(create_demographics_table)
export(download_nhanes_cache)
export(get_nominal_timepoints)
export(get_route_from_data_column)
export(reformat_data)
Expand All @@ -13,6 +14,7 @@ export(sample_covariates)
export(sample_covariates_bootstrap)
export(sample_covariates_mice)
export(sample_covariates_mvtnorm)
export(sample_covariates_nhanes)
importFrom(dplyr,"%>%")
importFrom(irxutils,"%<=%")
importFrom(irxutils,"%>=%")
Expand Down
105 changes: 105 additions & 0 deletions R/download_nhanes_cache.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
#' Download and cache NHANES data locally
#'
#' Downloads all tables within the specified NHANES data groups for one or
#' more survey years, merges them into a single data frame per year, and
#' saves it as an RDS file. The resulting cache is used automatically by
#' [sample_covariates_nhanes()].
#'
#' @param groups character vector of NHANES data groups to download. Valid
#' values: `"DEMO"` (Demographics), `"LAB"` (Laboratory), `"EXAM"`
#' (Examination), `"Q"` (Questionnaire), `"DIET"` (Dietary).
#' Defaults to `c("DEMO", "LAB", "EXAM")`.
#' @param years character vector of NHANES survey cycles, e.g.
#' `c("2015-2016", "2017-2018")`. Defaults to `"2017-2018"`.
#' @param path directory where merged RDS files will be saved. Created
#' automatically if it does not exist. Defaults to the package-level cache
#' directory returned by `nhanes_default_cache_dir()`.
#' @param overwrite logical. If `FALSE` (default), a year that already has a
#' merged RDS in `path` is skipped. Set to `TRUE` to re-download and
#' overwrite.
#' @param ... additional arguments (currently unused)
#'
#' @details
#' Each survey year is saved as a single file `nhanes_<year>.rds`
#' (e.g. `nhanes_2017-2018.rds`) containing all variables from all downloaded
#' tables, merged on the SEQN respondent sequence number. Tables with
#' multiple rows per subject (e.g. dietary recall) are skipped automatically.
#'
#' Requires the `nhanesA` package.
#'
#' @returns the `path` directory, invisibly.
#'
#' @export
download_nhanes_cache <- function(
groups = c("DEMO", "LAB", "EXAM"),
years = "2017-2018",
path = nhanes_default_cache_dir(),
overwrite = FALSE,
...
) {
if (!requireNamespace("nhanesA", quietly = TRUE)) {
stop(
"Package 'nhanesA' is required to download NHANES data. ",
"Install it with: install.packages('nhanesA')",
call. = FALSE
)
}

if (!dir.exists(path)) {
dir.create(path, recursive = TRUE)
}

for (year in years) {
nhanes_year_suffix(year) # validates year; errors on unsupported values
out_file <- file.path(path, paste0("nhanes_", year, ".rds"))

if (file.exists(out_file) && !overwrite) {
message("Skipping ", year, " (cache exists; use overwrite = TRUE to re-download)")
next
}

year_end <- as.integer(sub(".*-", "", year))
table_list <- list()

for (group in groups) {
message("Fetching table list for group ", group, ", year ", year, " ...")
tbl_info <- nhanesA::nhanesTables(group, year_end)
# nhanesTables() returns a data.frame; table names are in Data.File.Name
tbl_names <- tbl_info[["Data.File.Name"]]

for (tbl_name in tbl_names) {
message(" Downloading ", tbl_name, " ...")
tbl_data <- tryCatch(
nhanesA::nhanes(tbl_name),
error = function(e) {
message(" Skipping ", tbl_name, ": ", conditionMessage(e))
NULL
}
)
if (is.null(tbl_data) || !"SEQN" %in% names(tbl_data)) next
if (anyDuplicated(tbl_data[["SEQN"]]) > 0) {
message(" Skipping ", tbl_name, " (multiple rows per subject)")
next
}
table_list[[tbl_name]] <- tbl_data
}
}

if (length(table_list) == 0) {
warning("No tables downloaded for year ", year, "; skipping.")
next
}

message("Merging ", length(table_list), " tables for ", year, " ...")
merged <- Reduce(
function(a, b) dplyr::full_join(a, b, by = "SEQN"),
table_list
)

saveRDS(merged, out_file)
message("Saved merged NHANES data (", nrow(merged), " subjects, ",
ncol(merged), " variables) to ", out_file)
}

invisible(path)
}
16 changes: 9 additions & 7 deletions R/sample_covariates.R
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
#' Sample covariates using a variety of methods
#'
#' @param method sampling method, one of `mvtnorm`, `bootstrap`, or `mice`.
#' E.g. `list(AGE = c(60, 80), WT = c(70, 100))`.
#'
#'
#' @param method sampling method, one of `mvtnorm`, `bootstrap`, `mice`, or
#' `nhanes`. E.g. `list(AGE = c(60, 80), WT = c(70, 100))`.
#' @param seed integer random seed passed to [set.seed()] for reproducibility.
#' Default `NULL` does not set a seed.
#' @param ... arguments passed to lower-level function(s).
#'
#'
#' @returns data.frame with covariates in each column
#'
#' @export
sample_covariates <- function(
method = c("mvtnorm", "mice", "bootstrap"),
method = c("mvtnorm", "mice", "bootstrap", "nhanes"),
seed = NULL,
...
) {
method <- rlang::arg_match(method)
do.call(paste0("sample_covariates_", method), args = list(...))
do.call(paste0("sample_covariates_", method), args = list(seed = seed, ...))
}
10 changes: 9 additions & 1 deletion R/sample_covariates_bootstrap.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#' Sample covariates using bootstrap
#'
#' @inheritParams sample_covariates_mice
#'
#' @param na.rm logical. If `TRUE` (default), rows with `NA` in any column are
#' dropped before sampling.
#'
#' @returns a data.frame with the simulated covariates, with `n_subjects`
#' rows and `p` columns
#'
Expand All @@ -10,8 +12,14 @@ sample_covariates_bootstrap <- function(
data,
n_subjects = nrow(data),
conditional = NULL,
seed = NULL,
na.rm = TRUE,
...
) {
if (!is.null(seed)) set.seed(seed)
if (na.rm) {
data <- data[stats::complete.cases(data), , drop = FALSE]
}
if(!is.null(conditional)) {
for(key in names(conditional)) {
data <- dplyr::filter(
Expand Down
6 changes: 5 additions & 1 deletion R/sample_covariates_mice.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
#' `list("WT" = c(40, 60), "BMI" = c(15, 25))`.
#' @param cont_method method used to predict continuous covariates within mice,
#' default is `pmm`.
#' @param replicates number of multiple imputations replicates to sample.
#' @param replicates number of multiple imputations replicates to sample.
#' Default is 1.
#' @param seed integer random seed passed to [set.seed()] for reproducibility.
#' Default `NULL` does not set a seed.
#' @param ... additional arguments passed to `mice::mice()` function
#'
#' @details missing values in `data` must be coded as NA
Expand All @@ -29,8 +31,10 @@ sample_covariates_mice <- function(
n_subjects = nrow(data),
cont_method = "pmm",
replicates = 1,
seed = NULL,
...
) {
if (!is.null(seed)) set.seed(seed)

# names of continuous covariates
cont_covs <- setdiff(names(data), cat_covs)
Expand Down
6 changes: 5 additions & 1 deletion R/sample_covariates_mvtnorm.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
#' subjects in the data.
#' @param exponential sample from exponential distribution? Default `FALSE`.
#' @param conditional description...
#' @param seed integer random seed passed to [set.seed()] for reproducibility.
#' Default `NULL` does not set a seed.
#' @param ... additional arguments passed to `mvrnorm()` function
#'
#' @returns a data.frame with the simulated covariates, with `n_subjects`
Expand All @@ -23,9 +25,11 @@ sample_covariates_mvtnorm <- function(
n_subjects = nrow(data),
exponential = FALSE,
conditional = NULL,
seed = NULL,
...
) {

if (!is.null(seed)) set.seed(seed)

if(!is.null(conditional)) {
for(key in names(conditional)) {
data <- dplyr::filter(
Expand Down
Loading