|
3 | 3 | #' Explains a random forest in a html document using plots created by randomForestExplainer |
4 | 4 | #' |
5 | 5 | #' @param forest A randomForest object created with the option localImp = TRUE |
| 6 | +#' @param path Path to write output html to |
6 | 7 | #' @param interactions Logical value: should variable interactions be considered (this may be time-consuming) |
7 | 8 | #' @param data The data frame on which forest was trained - necessary if interactions = TRUE |
8 | 9 | #' @param vars A character vector with variables with respect to which interactions will be considered if NULL then they will be selected using the important_variables() function |
9 | 10 | #' @param no_of_pred_plots The number of most frequent interactions of numeric variables to plot predictions for |
10 | 11 | #' @param pred_grid The number of points on the grid of plot_predict_interaction (decrease in case memory problems) |
11 | 12 | #' @param measures A character vector specifying the importance measures to be used for plotting ggpairs |
12 | 13 | #' |
13 | | -#' @return A html document in your working directory |
| 14 | +#' @return A html document. If path is not specified, this document will be "Your_forest_explained.html" in your working directory |
14 | 15 | #' |
15 | 16 | #' @import DT |
16 | 17 | #' |
|
21 | 22 | #' } |
22 | 23 | #' |
23 | 24 | #' @export |
24 | | -explain_forest <- function(forest, interactions = FALSE, data = NULL, vars = NULL, no_of_pred_plots = 3, pred_grid = 100, |
| 25 | +explain_forest <- function(forest, path = NULL, interactions = FALSE, data = NULL, vars = NULL, no_of_pred_plots = 3, pred_grid = 100, |
25 | 26 | measures = NULL){ |
| 27 | + if(is.null(path)) { |
| 28 | + directory <- getwd() |
| 29 | + path <- paste0(directory, "/Your_forest_explained.html") |
| 30 | + } |
26 | 31 | if(is.null(measures)){ |
27 | 32 | if("randomForest" %in% class(forest)){ |
28 | 33 | if(forest$type %in% c("classification", "unsupervised")){ |
@@ -53,12 +58,12 @@ explain_forest <- function(forest, interactions = FALSE, data = NULL, vars = NUL |
53 | 58 | environment$no_of_pred_plots <- no_of_pred_plots |
54 | 59 | environment$pred_grid <- pred_grid |
55 | 60 | environment$measures <- measures |
56 | | - directory <- getwd() |
| 61 | + |
57 | 62 | path_to_templates <- file.path(path.package("randomForestExplainer"), "templates") |
58 | 63 | template_name <- grep('explain_forest_template.rmd', list.files(path_to_templates), |
59 | 64 | ignore.case = TRUE, value = TRUE) |
60 | 65 |
|
61 | 66 | rmarkdown::render(file.path(path_to_templates, template_name), |
62 | | - "html_document", output_file = paste0(directory, "/Your_forest_explained.html"), |
| 67 | + "html_document", output_file = path, |
63 | 68 | envir = environment) |
64 | 69 | } |
0 commit comments