Skip to content

Commit 4ad9693

Browse files
committed
unsupervised randomForest should be supported by all functions except plot_predict_interaction
1 parent 1804b4c commit 4ad9693

File tree

3 files changed

+9
-5
lines changed

3 files changed

+9
-5
lines changed

R/explain_forest.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ explain_forest <- function(forest, interactions = FALSE, data = NULL, vars = NUL
2525
measures = NULL){
2626
if(is.null(measures)){
2727
if("randomForest" %in% class(forest)){
28-
if(forest$type == "classification"){
28+
if(forest$type %in% c("classification", "unsupervised")){
2929
measures <- c("mean_min_depth", "accuracy_decrease", "gini_decrease", "no_of_nodes", "times_a_root")
3030
} else{
3131
measures <- c("mean_min_depth", "mse_increase", "node_purity_increase", "no_of_nodes", "times_a_root")
@@ -36,7 +36,7 @@ explain_forest <- function(forest, interactions = FALSE, data = NULL, vars = NUL
3636
}
3737
if("randomForest" %in% class(forest) && dim(forest$importance)[2] == 1){
3838
stop(paste("Your forest does not contain information on local importance so",
39-
ifelse(forest$type == "classification", "accuracy_decrease", "mse_increase"),
39+
ifelse(forest$type %in% c("classification", "unsupervised"), "accuracy_decrease", "mse_increase"),
4040
"measure cannot be extracted.",
4141
"To add it regrow the forest with the option localImp = TRUE and run this function again."))
4242
}

R/measure_importance.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ measure_no_of_nodes_ranger <- function(forest_table){
3131
# Extract randomForest variable importance measures
3232
# randomForest
3333
measure_vimp <- function(forest, only_nonlocal = FALSE){
34-
if(forest$type == "classification"){
34+
if(forest$type %in% c("classification", "unsupervised")){
3535
if(dim(forest$importance)[2] == 1){
3636
if(only_nonlocal == FALSE){
3737
print("Warning: your forest does not contain information on local importance so 'accuracy_decrease' measure cannot be extracted. To add it regrow the forest with the option localImp = TRUE and run this function again.")
@@ -129,7 +129,7 @@ measure_importance <- function(forest, mean_sample = "top_trees", measures = NUL
129129
measure_importance.randomForest <- function(forest, mean_sample = "top_trees", measures = NULL){
130130
tree <- NULL; `split var` <- NULL; depth <- NULL
131131
if(is.null(measures)){
132-
if(forest$type == "classification"){
132+
if(forest$type %in% c("classification", "unsupervised")){
133133
measures <- c("mean_min_depth", "no_of_nodes", "accuracy_decrease",
134134
"gini_decrease", "no_of_trees", "times_a_root", "p_value")
135135
} else if(forest$type =="regression"){
@@ -159,7 +159,7 @@ measure_importance.randomForest <- function(forest, mean_sample = "top_trees", m
159159
importance_frame <- merge(importance_frame, measure_no_of_nodes(forest_table), all = TRUE)
160160
importance_frame[is.na(importance_frame$no_of_nodes), "no_of_nodes"] <- 0
161161
}
162-
if(forest$type == "classification"){
162+
if(forest$type %in% c("classification", "unsupervised")){
163163
vimp <- c("accuracy_decrease", "gini_decrease")
164164
} else if(forest$type =="regression"){
165165
vimp <- c("mse_increase", "node_purity_increase")

R/min_depth_interactions.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,10 @@ plot_predict_interaction.randomForest <- function(forest, data, variable1, varia
303303
main = paste0("Prediction of the forest for different values of ",
304304
paste0(variable1, paste0(" and ", variable2))),
305305
time = NULL){
306+
if (forest$type == "unsupervised") {
307+
warning("plot_predict_interaction cannot be performed on unsupervised random forests.")
308+
return(NULL)
309+
}
306310
newdata <- expand.grid(seq(min(data[[variable1]]), max(data[[variable1]]), length.out = grid),
307311
seq(min(data[[variable2]]), max(data[[variable2]]), length.out = grid))
308312
colnames(newdata) <- c(variable1, variable2)

0 commit comments

Comments
 (0)