Skip to content

Commit 485c7d4

Browse files
committed
save memory in adversarial loop, fix #40
1 parent 2353854 commit 485c7d4

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

R/adversarial_rf.R

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -155,16 +155,16 @@ adversarial_rf <- function(
155155
# Create synthetic data by sampling from intra-leaf marginals
156156
nodeIDs <- stats::predict(rf0, x_real, type = 'terminalNodes')$predictions
157157
tmp <- data.table('tree' = rep(seq_len(num_trees), each = n),
158-
'leaf' = as.vector(nodeIDs))
159-
x_real_dt <- do.call(rbind, lapply(seq_len(num_trees), function(b) {
160-
cbind(x_real, tmp[tree == b])
158+
'leaf' = as.integer(nodeIDs))
159+
tmp2 <- tmp[sample(.N, n, replace = TRUE)]
160+
tmp2 <- unique(tmp2[, cnt := .N, by = .(tree, leaf)])
161+
draw_from <- rbindlist(lapply(seq_len(num_trees), function(b) {
162+
x_real_b <- cbind(x_real, tmp[tree == b])
163+
x_real_b[, factor_cols] <- lapply(x_real_b[, factor_cols, drop = FALSE], as.numeric)
164+
merge(tmp2, x_real_b, by = c('tree', 'leaf'),
165+
sort = FALSE)[, N := .N, by = .(tree, leaf)]
161166
}))
162-
x_real_dt[, factor_cols] <- lapply(x_real_dt[, factor_cols, drop = FALSE], as.numeric)
163-
tmp <- tmp[sample(.N, n, replace = TRUE)]
164-
tmp <- unique(tmp[, cnt := .N, by = .(tree, leaf)])
165-
draw_from <- merge(tmp, x_real_dt, by = c('tree', 'leaf'),
166-
sort = FALSE)[, N := .N, by = .(tree, leaf)]
167-
rm(nodeIDs, tmp, x_real_dt)
167+
rm(nodeIDs, tmp, tmp2)
168168
draw_params_within <- unique(draw_from, by = c('tree','leaf'))[, .(cnt, N)]
169169
adj_absolut_col <- rep(c(0, draw_params_within[-.N, cumsum(N)]),
170170
times = draw_params_within$cnt)

0 commit comments

Comments
 (0)