Skip to content

Commit 8355a59

Browse files
committed
feat: option to use var for train/test split
1 parent 95f167c commit 8355a59

File tree

3 files changed

+43
-6
lines changed

3 files changed

+43
-6
lines changed

R/per_sd_metrics.R

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#' @param prs_col Name of PRS column (character).
55
#' @param seed A random number to be set as a seed for the training and testing sampling to be reproducible.
66
#' @param recipe_var A recipe from recipes::recipe() object, to be provided optionally, in case the recipe originally included in the model (status ~ prs + scaled_centered_pc(1:10) + age_analysis) is not what you need for your PRS.
7+
#' @param split_var Name of column with two categories that define split between training and testing sets (character).
78
#'
89
#' @return A list with OR, AUC, delta AUC, and ROC curve for AUC with and without PRS.
910
#' @importFrom stats anova
@@ -56,7 +57,13 @@
5657
#' per_sd_metrics(data_mock, prs_col_mock, seed)
5758
#'
5859

59-
per_sd_metrics <- function(dataset, prs_col, seed, recipe_var = NULL) {
60+
per_sd_metrics <- function(
61+
dataset,
62+
prs_col,
63+
seed,
64+
recipe_var = NULL,
65+
split_var = NULL
66+
) {
6067
stopifnot(
6168
is.double(dataset |> dplyr::pull({{ prs_col }})),
6269
is.factor(dataset |> dplyr::pull(status))
@@ -67,10 +74,23 @@ per_sd_metrics <- function(dataset, prs_col, seed, recipe_var = NULL) {
6774
# Set PRS name with "norm" because we use the normalized version for the analyses
6875
norm_prs <- paste0("norm_", prs_col)
6976

70-
split <- rsample::initial_split(dataset, strata = status, prop = 0.75)
71-
72-
train <- rsample::training(split)
73-
test <- rsample::testing(split)
77+
if (is.null(split_var)) {
78+
split <- rsample::initial_split(dataset, strata = status, prop = 0.75)
79+
80+
train <- rsample::training(split)
81+
test <- rsample::testing(split)
82+
} else {
83+
# Set category with more entries as the training set
84+
n_split_var <- dataset |>
85+
group_by(get(split_var)) |>
86+
count() |>
87+
arrange(desc(n))
88+
train_var <- head(n_split_var, n = 1) |> pull(`get(split_var)`)
89+
test_var <- setdiff(n_split_var$`get(split_var)`, train_var)
90+
91+
train <- dataset |> filter(get(split_var) == train_var)
92+
test <- dataset |> filter(get(split_var) == test_var)
93+
}
7494

7595
control_stats_train <- get_control_stats(train) |> as.data.frame()
7696

man/per_sd_metrics.Rd

Lines changed: 3 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-per_sd_metrics.R

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,3 +129,18 @@ test_that("per_sd_metrics returns a list with non-empty variables when an extern
129129
# Test
130130
expect_length(res, 12)
131131
})
132+
133+
test_that("per_sd_metrics function returns full list result when provided a split variable", {
134+
data_mock <- setup_mock_df()
135+
136+
# Run
137+
res <- per_sd_metrics(
138+
dataset = data_mock,
139+
prs_col = "prs_test",
140+
seed = 82,
141+
split_var = "version"
142+
)
143+
144+
# Test
145+
expect_length(res, 12)
146+
})

0 commit comments

Comments
 (0)