@@ -155,16 +155,16 @@ adversarial_rf <- function(
155155 # Create synthetic data by sampling from intra-leaf marginals
156156 nodeIDs <- stats :: predict(rf0 , x_real , type = ' terminalNodes' )$ predictions
157157 tmp <- data.table(' tree' = rep(seq_len(num_trees ), each = n ),
158- ' leaf' = as.vector(nodeIDs ))
159- x_real_dt <- do.call(rbind , lapply(seq_len(num_trees ), function (b ) {
160- cbind(x_real , tmp [tree == b ])
158+ ' leaf' = as.integer(nodeIDs ))
159+ tmp2 <- tmp [sample(.N , n , replace = TRUE )]
160+ tmp2 <- unique(tmp2 [, cnt : = .N , by = .(tree , leaf )])
161+ draw_from <- rbindlist(lapply(seq_len(num_trees ), function (b ) {
162+ x_real_b <- cbind(x_real , tmp [tree == b ])
163+ x_real_b [, factor_cols ] <- lapply(x_real_b [, factor_cols , drop = FALSE ], as.numeric )
164+ merge(tmp2 , x_real_b , by = c(' tree' , ' leaf' ),
165+ sort = FALSE )[, N : = .N , by = .(tree , leaf )]
161166 }))
162- x_real_dt [, factor_cols ] <- lapply(x_real_dt [, factor_cols , drop = FALSE ], as.numeric )
163- tmp <- tmp [sample(.N , n , replace = TRUE )]
164- tmp <- unique(tmp [, cnt : = .N , by = .(tree , leaf )])
165- draw_from <- merge(tmp , x_real_dt , by = c(' tree' , ' leaf' ),
166- sort = FALSE )[, N : = .N , by = .(tree , leaf )]
167- rm(nodeIDs , tmp , x_real_dt )
167+ rm(nodeIDs , tmp , tmp2 )
168168 draw_params_within <- unique(draw_from , by = c(' tree' ,' leaf' ))[, .(cnt , N )]
169169 adj_absolut_col <- rep(c(0 , draw_params_within [- .N , cumsum(N )]),
170170 times = draw_params_within $ cnt )
0 commit comments