Skip to content

Commit b2857e3

Browse files
authored
Merge pull request #3 from InsightRX/add-nhanes-sampling
Add NHANES covariate sampling method
2 parents 0711c86 + a8a66b6 commit b2857e3

20 files changed

+1271
-13
lines changed

DESCRIPTION

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ Imports:
2323
tidyr,
2424
tidyselect
2525
Suggests:
26+
nhanesA,
2627
testthat (>= 3.2.0),
2728
withr
2829
Remotes:

NAMESPACE

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
export(add_nominal_timepoints)
44
export(create_demographics_table)
5+
export(download_nhanes_cache)
56
export(get_nominal_timepoints)
67
export(get_route_from_data_column)
78
export(reformat_data)
@@ -13,6 +14,7 @@ export(sample_covariates)
1314
export(sample_covariates_bootstrap)
1415
export(sample_covariates_mice)
1516
export(sample_covariates_mvtnorm)
17+
export(sample_covariates_nhanes)
1618
importFrom(dplyr,"%>%")
1719
importFrom(irxutils,"%<=%")
1820
importFrom(irxutils,"%>=%")

R/download_nhanes_cache.R

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
#' Download and cache NHANES data locally
2+
#'
3+
#' Downloads all tables within the specified NHANES data groups for one or
4+
#' more survey years, merges them into a single data frame per year, and
5+
#' saves it as an RDS file. The resulting cache is used automatically by
6+
#' [sample_covariates_nhanes()].
7+
#'
8+
#' @param groups character vector of NHANES data groups to download. Valid
9+
#' values: `"DEMO"` (Demographics), `"LAB"` (Laboratory), `"EXAM"`
10+
#' (Examination), `"Q"` (Questionnaire), `"DIET"` (Dietary).
11+
#' Defaults to `c("DEMO", "LAB", "EXAM")`.
12+
#' @param years character vector of NHANES survey cycles, e.g.
13+
#' `c("2015-2016", "2017-2018")`. Defaults to `"2017-2018"`.
14+
#' @param path directory where merged RDS files will be saved. Created
15+
#' automatically if it does not exist. Defaults to the package-level cache
16+
#' directory returned by `nhanes_default_cache_dir()`.
17+
#' @param overwrite logical. If `FALSE` (default), a year that already has a
18+
#' merged RDS in `path` is skipped. Set to `TRUE` to re-download and
19+
#' overwrite.
20+
#' @param ... additional arguments (currently unused)
21+
#'
22+
#' @details
23+
#' Each survey year is saved as a single file `nhanes_<year>.rds`
24+
#' (e.g. `nhanes_2017-2018.rds`) containing all variables from all downloaded
25+
#' tables, merged on the SEQN respondent sequence number. Tables with
26+
#' multiple rows per subject (e.g. dietary recall) are skipped automatically.
27+
#'
28+
#' Requires the `nhanesA` package.
29+
#'
30+
#' @returns the `path` directory, invisibly.
31+
#'
32+
#' @export
33+
download_nhanes_cache <- function(
34+
groups = c("DEMO", "LAB", "EXAM"),
35+
years = "2017-2018",
36+
path = nhanes_default_cache_dir(),
37+
overwrite = FALSE,
38+
...
39+
) {
40+
if (!requireNamespace("nhanesA", quietly = TRUE)) {
41+
stop(
42+
"Package 'nhanesA' is required to download NHANES data. ",
43+
"Install it with: install.packages('nhanesA')",
44+
call. = FALSE
45+
)
46+
}
47+
48+
if (!dir.exists(path)) {
49+
dir.create(path, recursive = TRUE)
50+
}
51+
52+
for (year in years) {
53+
nhanes_year_suffix(year) # validates year; errors on unsupported values
54+
out_file <- file.path(path, paste0("nhanes_", year, ".rds"))
55+
56+
if (file.exists(out_file) && !overwrite) {
57+
message("Skipping ", year, " (cache exists; use overwrite = TRUE to re-download)")
58+
next
59+
}
60+
61+
year_end <- as.integer(sub(".*-", "", year))
62+
table_list <- list()
63+
64+
for (group in groups) {
65+
message("Fetching table list for group ", group, ", year ", year, " ...")
66+
tbl_info <- nhanesA::nhanesTables(group, year_end)
67+
# nhanesTables() returns a data.frame; table names are in Data.File.Name
68+
tbl_names <- tbl_info[["Data.File.Name"]]
69+
70+
for (tbl_name in tbl_names) {
71+
message(" Downloading ", tbl_name, " ...")
72+
tbl_data <- tryCatch(
73+
nhanesA::nhanes(tbl_name),
74+
error = function(e) {
75+
message(" Skipping ", tbl_name, ": ", conditionMessage(e))
76+
NULL
77+
}
78+
)
79+
if (is.null(tbl_data) || !"SEQN" %in% names(tbl_data)) next
80+
if (anyDuplicated(tbl_data[["SEQN"]]) > 0) {
81+
message(" Skipping ", tbl_name, " (multiple rows per subject)")
82+
next
83+
}
84+
table_list[[tbl_name]] <- tbl_data
85+
}
86+
}
87+
88+
if (length(table_list) == 0) {
89+
warning("No tables downloaded for year ", year, "; skipping.")
90+
next
91+
}
92+
93+
message("Merging ", length(table_list), " tables for ", year, " ...")
94+
merged <- Reduce(
95+
function(a, b) dplyr::full_join(a, b, by = "SEQN"),
96+
table_list
97+
)
98+
99+
saveRDS(merged, out_file)
100+
message("Saved merged NHANES data (", nrow(merged), " subjects, ",
101+
ncol(merged), " variables) to ", out_file)
102+
}
103+
104+
invisible(path)
105+
}

R/sample_covariates.R

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
#' Sample covariates using a variety of methods
2-
#'
3-
#' @param method sampling method, one of `mvtnorm`, `bootstrap`, or `mice`.
4-
#' E.g. `list(AGE = c(60, 80), WT = c(70, 100))`.
5-
#'
2+
#'
3+
#' @param method sampling method, one of `mvtnorm`, `bootstrap`, `mice`, or
4+
#' `nhanes`. E.g. `list(AGE = c(60, 80), WT = c(70, 100))`.
5+
#' @param seed integer random seed passed to [set.seed()] for reproducibility.
6+
#' Default `NULL` does not set a seed.
67
#' @param ... arguments passed to lower-level function(s).
7-
#'
8+
#'
89
#' @returns data.frame with covariates in each column
910
#'
1011
#' @export
1112
sample_covariates <- function(
12-
method = c("mvtnorm", "mice", "bootstrap"),
13+
method = c("mvtnorm", "mice", "bootstrap", "nhanes"),
14+
seed = NULL,
1315
...
1416
) {
1517
method <- rlang::arg_match(method)
16-
do.call(paste0("sample_covariates_", method), args = list(...))
18+
do.call(paste0("sample_covariates_", method), args = list(seed = seed, ...))
1719
}

R/sample_covariates_bootstrap.R

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
#' Sample covariates using bootstrap
22
#'
33
#' @inheritParams sample_covariates_mice
4-
#'
4+
#' @param na.rm logical. If `TRUE` (default), rows with `NA` in any column are
5+
#' dropped before sampling.
6+
#'
57
#' @returns a data.frame with the simulated covariates, with `n_subjects`
68
#' rows and `p` columns
79
#'
@@ -10,8 +12,14 @@ sample_covariates_bootstrap <- function(
1012
data,
1113
n_subjects = nrow(data),
1214
conditional = NULL,
15+
seed = NULL,
16+
na.rm = TRUE,
1317
...
1418
) {
19+
if (!is.null(seed)) set.seed(seed)
20+
if (na.rm) {
21+
data <- data[stats::complete.cases(data), , drop = FALSE]
22+
}
1523
if(!is.null(conditional)) {
1624
for(key in names(conditional)) {
1725
data <- dplyr::filter(

R/sample_covariates_mice.R

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@
1212
#' `list("WT" = c(40, 60), "BMI" = c(15, 25))`.
1313
#' @param cont_method method used to predict continuous covariates within mice,
1414
#' default is `pmm`.
15-
#' @param replicates number of multiple imputations replicates to sample.
15+
#' @param replicates number of multiple imputations replicates to sample.
1616
#' Default is 1.
17+
#' @param seed integer random seed passed to [set.seed()] for reproducibility.
18+
#' Default `NULL` does not set a seed.
1719
#' @param ... additional arguments passed to `mice::mice()` function
1820
#'
1921
#' @details missing values in `data` must be coded as NA
@@ -29,8 +31,10 @@ sample_covariates_mice <- function(
2931
n_subjects = nrow(data),
3032
cont_method = "pmm",
3133
replicates = 1,
34+
seed = NULL,
3235
...
3336
) {
37+
if (!is.null(seed)) set.seed(seed)
3438

3539
# names of continuous covariates
3640
cont_covs <- setdiff(names(data), cat_covs)

R/sample_covariates_mvtnorm.R

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
#' subjects in the data.
1010
#' @param exponential sample from exponential distribution? Default `FALSE`.
1111
#' @param conditional description...
12+
#' @param seed integer random seed passed to [set.seed()] for reproducibility.
13+
#' Default `NULL` does not set a seed.
1214
#' @param ... additional arguments passed to `mvrnorm()` function
1315
#'
1416
#' @returns a data.frame with the simulated covariates, with `n_subjects`
@@ -23,9 +25,11 @@ sample_covariates_mvtnorm <- function(
2325
n_subjects = nrow(data),
2426
exponential = FALSE,
2527
conditional = NULL,
28+
seed = NULL,
2629
...
2730
) {
28-
31+
if (!is.null(seed)) set.seed(seed)
32+
2933
if(!is.null(conditional)) {
3034
for(key in names(conditional)) {
3135
data <- dplyr::filter(

0 commit comments

Comments
 (0)