@@ -39,11 +39,11 @@ conditional_depth_ranger <- function(frame, vars){
3939 df <- frame [begin : nrow(frame ), setdiff(names(frame ), setdiff(vars , j ))]
4040 df [[j ]][1 ] <- 0
4141 for (k in 2 : nrow(df )){
42- if (length(df [df [, " leftChild" ] == as.numeric(df [k , " number" ]) |
43- df [, " rightChild" ] == as.numeric(df [k , " number" ]), j ]) != 0 ){
42+ if (length(df [( ! is.na( df [, " leftChild" ]) & df [, " leftChild " ] == as.numeric(df [k , " number" ]) ) |
43+ ( ! is.na( df [, " rightChild" ]) & df [, " rightChild " ] == as.numeric(df [k , " number" ]) ), j ]) != 0 ){
4444 df [k , j ] <-
45- df [df [, " leftChild" ] == as.numeric(df [k , " number" ]) |
46- df [, " rightChild" ] == as.numeric(df [k , " number" ]), j ] + 1
45+ df [( ! is.na( df [, " leftChild" ]) & df [, " leftChild " ] == as.numeric(df [k , " number" ]) ) |
46+ ( ! is.na( df [, " rightChild" ]) & df [, " rightChild " ] == as.numeric(df [k , " number" ]) ), j ] + 1
4747 }
4848 }
4949 frame [begin : nrow(frame ), setdiff(names(frame ), setdiff(vars , j ))] <- df
@@ -84,8 +84,8 @@ min_depth_interactions_values <- function(forest, vars){
8484min_depth_interactions_values_ranger <- function (forest , vars ){
8585 `.` <- NULL ; .SD <- NULL ; tree <- NULL ; splitvarName <- NULL
8686 interactions_frame <-
87- lapply(1 : forest $ num.trees , function (i ) ranger :: treeInfo(forest , k = i , labelVar = T ) %> %
88- calculate_tree_depth () %> % cbind(. , tree = i , number = 1 : nrow(. ))) %> %
87+ lapply(1 : forest $ num.trees , function (i ) ranger :: treeInfo(forest , tree = i ) %> %
88+ calculate_tree_depth_ranger () %> % cbind(. , tree = i , number = 1 : nrow(. ))) %> %
8989 data.table :: rbindlist() %> % as.data.frame()
9090 interactions_frame [vars ] <- as.numeric(NA )
9191 interactions_frame <-
@@ -95,10 +95,10 @@ min_depth_interactions_values_ranger <- function(forest, vars){
9595 mean_tree_depth [mean_tree_depth == - Inf ] <- NA
9696 mean_tree_depth <- colMeans(mean_tree_depth [, vars ], na.rm = TRUE )
9797 min_depth_interactions_frame <-
98- interactions_frame %> % dplyr :: group_by(tree , `split var` ) %> %
98+ interactions_frame %> % dplyr :: group_by(tree , splitvarName ) %> %
9999 dplyr :: summarize_at(vars , funs(min(. , na.rm = TRUE ))) %> % as.data.frame()
100100 min_depth_interactions_frame [min_depth_interactions_frame == Inf ] <- NA
101- min_depth_interactions_frame <- min_depth_interactions_frame [! is.na(min_depth_interactions_frame $ `split var` ), ]
101+ min_depth_interactions_frame <- min_depth_interactions_frame [! is.na(min_depth_interactions_frame $ splitvarName ), ]
102102 colnames(min_depth_interactions_frame )[2 ] <- " variable"
103103 min_depth_interactions_frame [, - c(1 : 2 )] <- min_depth_interactions_frame [, - c(1 : 2 )] - 1
104104 return (list (min_depth_interactions_frame , mean_tree_depth ))
@@ -120,7 +120,8 @@ min_depth_interactions_values_ranger <- function(forest, vars){
120120# ' min_depth_interactions(forest, c("Petal.Width", "Petal.Length"))
121121# '
122122# ' @export
123- min_depth_interactions <- function (){
123+ min_depth_interactions <- function (forest , vars = important_variables(measure_importance(forest )),
124+ mean_sample = " top_trees" , uncond_mean_sample = mean_sample ){
124125 UseMethod(" min_depth_interactions" )
125126}
126127
@@ -217,7 +218,7 @@ min_depth_interactions.ranger <- function(forest, vars = important_variables(mea
217218 dplyr :: summarize(min(depth ))
218219 colnames(min_depth_frame ) <- c(" tree" , " variable" , " minimal_depth" )
219220 min_depth_frame <- as.data.frame(min_depth_frame [! is.na(min_depth_frame $ variable ),])
220- importance_frame <- get_min_depth_means_ranger (min_depth_frame , min_depth_count(min_depth_frame ), uncond_mean_sample )
221+ importance_frame <- get_min_depth_means (min_depth_frame , min_depth_count(min_depth_frame ), uncond_mean_sample )
221222 colnames(importance_frame )[2 ] <- " uncond_mean_min_depth"
222223 interactions_frame <- merge(interactions_frame , importance_frame )
223224}
@@ -273,6 +274,9 @@ plot_min_depth_interactions <- function(interactions_frame, k = 30,
273274# ' @param variable2 A character string with the name a numerical predictor that will on Y-axis
274275# ' @param grid The number of points on the one-dimensional grid on x and y-axis
275276# ' @param main A string to be used as title of the plot
277+ # ' @param time A numeric value specifying the time at which to predict survival probability, only
278+ # ' applies to survival forests. If not specified, the time closest to predicted median survival
279+ # ' time is used
276280# '
277281# ' @return A ggplot2 object
278282# '
@@ -285,7 +289,8 @@ plot_min_depth_interactions <- function(interactions_frame, k = 30,
285289# ' @export
286290plot_predict_interaction <- function (forest , data , variable1 , variable2 , grid = 100 ,
287291 main = paste0(" Prediction of the forest for different values of " ,
288- paste0(variable1 , paste0(" and " , variable2 )))){
292+ paste0(variable1 , paste0(" and " , variable2 ))),
293+ time = NULL ){
289294 UseMethod(" plot_predict_interaction" )
290295}
291296
@@ -296,7 +301,8 @@ plot_predict_interaction <- function(forest, data, variable1, variable2, grid =
296301# ' @export
297302plot_predict_interaction.randomForest <- function (forest , data , variable1 , variable2 , grid = 100 ,
298303 main = paste0(" Prediction of the forest for different values of " ,
299- paste0(variable1 , paste0(" and " , variable2 )))){
304+ paste0(variable1 , paste0(" and " , variable2 ))),
305+ time = NULL ){
300306 newdata <- expand.grid(seq(min(data [[variable1 ]]), max(data [[variable1 ]]), length.out = grid ),
301307 seq(min(data [[variable2 ]]), max(data [[variable2 ]]), length.out = grid ))
302308 colnames(newdata ) <- c(variable1 , variable2 )
@@ -343,7 +349,8 @@ plot_predict_interaction.randomForest <- function(forest, data, variable1, varia
343349# ' @export
344350plot_predict_interaction.ranger <- function (forest , data , variable1 , variable2 , grid = 100 ,
345351 main = paste0(" Prediction of the forest for different values of " ,
346- paste0(variable1 , paste0(" and " , variable2 )))){
352+ paste0(variable1 , paste0(" and " , variable2 ))),
353+ time = NULL ){
347354 newdata <- expand.grid(seq(min(data [[variable1 ]]), max(data [[variable1 ]]), length.out = grid ),
348355 seq(min(data [[variable2 ]]), max(data [[variable2 ]]), length.out = grid ))
349356 colnames(newdata ) <- c(variable1 , variable2 )
@@ -358,7 +365,7 @@ plot_predict_interaction.ranger <- function(forest, data, variable1, variable2,
358365 newdata [[i ]] <- data [[i ]][sample(1 : n , nrow(newdata ), replace = TRUE )]
359366 }
360367 if (forest $ treetype == " Regression" ){
361- newdata $ prediction <- predict(forest , newdata , type = " response" )
368+ newdata $ prediction <- predict(forest , newdata , type = " response" )$ predictions
362369 plot <- ggplot(newdata , aes_string(x = variable1 , y = variable2 , fill = " prediction" )) +
363370 geom_raster() + theme_bw() +
364371 scale_fill_gradient2(midpoint = min(newdata $ prediction ) + 0.5 * (max(newdata $ prediction ) - min(newdata $ prediction )),
@@ -377,8 +384,23 @@ plot_predict_interaction.ranger <- function(forest, data, variable1, variable2,
377384 geom_raster() + theme_bw() + facet_wrap(~ variable ) +
378385 scale_fill_gradient2(midpoint = min(newdata $ prediction ) + 0.5 * (max(newdata $ prediction ) - min(newdata $ prediction )),
379386 low = " blue" , high = " red" )
380- } else if (forest $ treetype == " Classification" ) {
387+ } else if (forest $ treetype == " Classification" ){
381388 stop(" Ranger forest for classification needs to be generated by ranger(..., probability = TRUE)." )
389+ } else if (forest $ treetype == " Survival" ){
390+ pred <- predict(forest , newdata , type = " response" )
391+ if (is.null(time )) {
392+ time <- pred $ unique.death.times [which.min(abs(colMeans(pred $ survival , na.rm = TRUE ) - 0.5 ))]
393+ message(sprintf(" Using unique death time %s which is the closest to predicted median survival time." , time ))
394+ } else if (! time %in% pred $ unique.death.times ) {
395+ new_time <- pred $ unique.death.times [which.min(abs(pred $ unique.death.times - time ))]
396+ message(sprintf(" Using closest unique death time %s instead of %s." , new_time , time ))
397+ time <- new_time
398+ }
399+ newdata $ prediction <- pred $ survival [, pred $ unique.death.times == time , drop = TRUE ]
400+ plot <- ggplot(newdata , aes_string(x = variable1 , y = variable2 , fill = " prediction" )) +
401+ geom_raster() + theme_bw() +
402+ scale_fill_gradient2(midpoint = min(newdata $ prediction ) + 0.5 * (max(newdata $ prediction ) - min(newdata $ prediction )),
403+ low = " blue" , high = " red" )
382404 } else {
383405 stop(sprintf(" Ranger forest type '%s' is currently not supported." , forest $ treetype ))
384406 }
0 commit comments