Skip to content

Commit dbea66d

Browse files
committed
n_synth default with given evidence
1 parent f9074b8 commit dbea66d

File tree

3 files changed

+21
-6
lines changed

3 files changed

+21
-6
lines changed

R/forde.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ forde <- function(
280280
dt[NA_share == 1, c('min', 'max') := .(fifelse(is.infinite(min), min_emp, min),
281281
fifelse(is.infinite(max), max_emp, max))]
282282
dt[, c("min_emp", "max_emp") := NULL]
283-
dt[NA_share == 1, mu := (max - min) / 2]
283+
dt[NA_share == 1, mu := (max + min) / 2]
284284
dt[is.na(sigma), sigma := 0]
285285
if (any(dt[, sigma == 0])) {
286286
dt[, new_min := fifelse(!is.finite(min), min(value, na.rm = TRUE), min), by = variable]

R/shortcut_functions.R

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,12 @@ darf <- function(x, query = NULL, ...) {
7373
#'
7474
#' @param x Input data. Integer variables are recoded as ordered factors with
7575
#' a warning. See Details.
76-
#' @param n_synth Number of synthetic samples to generate. Is set to \code{nrow(x)} if
77-
#' \code{NULL}.
76+
#' @param n_synth Number of synthetic samples to generate for unconditional
77+
#' generation with no \code{evidence} given.
78+
#' Number of synthetic samples to generate per \code{evidence} row if \code{evidence}
79+
#' is provided.
80+
#' If \code{NULL}, defaults to \code{nrow(x)} if no \code{evidence} is provided and to
81+
#' \code{1} otherwise.
7882
#' @param ... Extra parameters to be passed to \code{adversarial_rf}, \code{forde}
7983
#' and \code{forge}.
8084
#'
@@ -122,7 +126,14 @@ rarf <- function(x, n_synth = NULL, ...) {
122126

123127
if (!("params" %in% names(forde_args))) params <- do.call(forde, c(arf = list(arf), x = list(x), forde_args))
124128

125-
if(is.null(n_synth)) n_synth <- nrow(x)
129+
if(is.null(n_synth)) {
130+
if (is.null(forge_args$evidence))
131+
n_synth <- nrow(x)
132+
else {
133+
n_synth <- 1
134+
}
135+
}
136+
126137
do.call(forge, c(params = list(params),
127138
n_synth = list(n_synth),
128139
forge_args))

man/rarf.Rd

Lines changed: 6 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)