diff --git a/R/adversarial_rf.R b/R/adversarial_rf.R index db47eeae..82e55f7b 100644 --- a/R/adversarial_rf.R +++ b/R/adversarial_rf.R @@ -14,7 +14,7 @@ #' @param early_stop Terminate loop if performance fails to improve from one #' round to the next? #' @param prune Impose \code{min_node_size} by pruning? -#' @param verbose Print discriminator accuracy after each round? +#' @param verbose Print discriminator accuracy after each round? Will also show additional warnings. #' @param parallel Compute in parallel? Must register backend beforehand, e.g. #' via \code{doParallel} or \code{doFuture}; see examples. #' @param ... Extra parameters to be passed to \code{ranger}. @@ -118,7 +118,7 @@ adversarial_rf <- function( i <- b <- cnt <- obs <- tree <- leaf <- N <- . <- NULL # Prep data - x_real <- prep_x(x) + x_real <- prep_x(x, verbose) n <- nrow(x_real) d <- ncol(x_real) factor_cols <- sapply(x_real, is.factor) diff --git a/R/utils.R b/R/utils.R index e2ca1368..ded1d942 100644 --- a/R/utils.R +++ b/R/utils.R @@ -82,9 +82,10 @@ which.max.random <- function(x) { #' This function prepares input data for ARFs. #' #' @param x Input data.frame. +#' @param verbose Show warning if recoding integers? #' @keywords internal -prep_x <- function(x) { +prep_x <- function(x, verbose = TRUE) { # Reclass all non-numeric features as factors x <- as.data.frame(x) idx_char <- sapply(x, is.character) @@ -102,16 +103,20 @@ prep_x <- function(x) { idx_integer[j] & length(unique(x[[j]])) > 5 }) if (any(to_numeric)) { - warning('Recoding integers with more than 5 unique values as numeric. ', + if (verbose) { + warning('Recoding integers with more than 5 unique values as numeric. ', 'To override this behavior, explicitly code these variables as factors.') + } x[, to_numeric] <- lapply(x[, to_numeric, drop = FALSE], as.numeric) } to_factor <- sapply(seq_len(ncol(x)), function(j) { idx_integer[j] & length(unique(x[[j]])) < 6 }) if (any(to_factor)) { - warning('Recoding integers with fewer than 6 unique values as ordered factors. ', + if (verbose) { + warning('Recoding integers with fewer than 6 unique values as ordered factors. ', 'To override this behavior, explicitly code these variables as numeric.') + } x[, to_factor] <- lapply(which(to_factor), function(j) { lvls <- sort(unique(x[[j]])) factor(x[[j]], levels = lvls, ordered = TRUE) diff --git a/man/adversarial_rf.Rd b/man/adversarial_rf.Rd index b8df22d0..febd49ec 100644 --- a/man/adversarial_rf.Rd +++ b/man/adversarial_rf.Rd @@ -37,7 +37,7 @@ round to the next?} \item{prune}{Impose \code{min_node_size} by pruning?} -\item{verbose}{Print discriminator accuracy after each round?} +\item{verbose}{Print discriminator accuracy after each round? Will also show additional warnings.} \item{parallel}{Compute in parallel? Must register backend beforehand, e.g. via \code{doParallel} or \code{doFuture}; see examples.} diff --git a/man/prep_x.Rd b/man/prep_x.Rd index e35f229c..8ca1fc45 100644 --- a/man/prep_x.Rd +++ b/man/prep_x.Rd @@ -4,10 +4,12 @@ \alias{prep_x} \title{Preprocess input data} \usage{ -prep_x(x) +prep_x(x, verbose = TRUE) } \arguments{ \item{x}{Input data.frame.} + +\item{verbose}{Show warning if recoding integers?} } \description{ This function prepares input data for ARFs. diff --git a/tests/testthat/test-return_types.R b/tests/testthat/test-return_types.R index fbb471c3..88d2cd6a 100644 --- a/tests/testthat/test-return_types.R +++ b/tests/testthat/test-return_types.R @@ -72,7 +72,7 @@ test_that("FORGE returns correct column types", { factor = factor(sample(letters[1:5], n, replace = TRUE)), logical = (sample(0:1, n, replace = TRUE) == 1)) - expect_warning(arf <- adversarial_rf(dat, num_trees = 2, verbose = FALSE, parallel = FALSE)) + arf <- adversarial_rf(dat, num_trees = 2, verbose = FALSE, parallel = FALSE) psi <- forde(arf, dat, parallel = FALSE) # with round = TRUE