|
| 1 | +# Copyright (c) Meta Platforms, Inc. and its affiliates. |
| 2 | + |
| 3 | +# This source code is licensed under the MIT license found in the |
| 4 | +# LICENSE file in the root directory of this source tree. |
| 5 | + |
| 6 | +#################################################################### |
#' Reduce number of models based on ROI clusters and minimum combined errors
#'
#' The \code{robyn_clusters()} function uses output from \code{robyn_run()},
#' to reduce the amount of models and help the user pick up the best (lowest
#' combined error) of different kinds (clusters) of models.
#'
#' @inheritParams lares::clusterKmeans
#' @inheritParams hyper_names
#' @inheritParams robyn_outputs
#' @param input \code{robyn_export()}'s output or \code{pareto_aggregated.csv} results.
#' @param k Integer or "auto". Number of clusters to build; must end up between
#' 3 and 30. If kept in "auto", k is the largest number of clusters for which
#' adding that cluster still reduced the WSS (relative to the first WSS value)
#' by more than 5\%, with a minimum of 3.
#' @param limit Integer. Top N results per cluster.
#' @param weights Vector, size 3. How much should each error weight?
#' Order: nrmse, decomp.rssd, mape. The higher the value, the closer it will be scaled
#' to origin. Each value will be normalized so they all sum 1.
#' @param export Export plots into local files? Files are written to
#' \code{input$plot_folder} when \code{input} is a \code{robyn_outputs} object and
#' \code{all_media} was not provided; otherwise to the current working directory.
#' @param ... Additional parameters passed to \code{lares::clusterKmeans()}.
#' @return List with the clustered data (\code{data}), the number of clusters
#' (\code{n_clusters}), the weights used (\code{errors_weights}), the WSS plot
#' (\code{wss}), the correlations plot (\code{corrs}), the clusters' means
#' (\code{clusters_means}), the dimensionality reduction results
#' (\code{clusters_PCA}, \code{clusters_tSNE}), the selected models
#' (\code{models}) and their plots (\code{plot_models_errors},
#' \code{plot_models_rois}).
#' @author Bernardo Lares (bernardolares@@fb.com)
#' @examples
#' \dontrun{
#' cls <- robyn_clusters(input = OutputCollect,
#'                       all_media = InputCollect$all_media,
#'                       k = 3, limit = 2,
#'                       weights = c(1, 1, 1.5))
#' }
#' @export
robyn_clusters <- function(input, all_media = NULL, k = "auto", limit = 1,
                           weights = rep(1, 3), dim_red = "PCA",
                           quiet = FALSE, export = FALSE,
                           ...) {

  # Default export location. Defined up-front so that `export = TRUE` also
  # works when `input` is a plain data.frame (no plot folder available).
  path <- paste0(getwd(), "/")

  if (inherits(input, "robyn_outputs")) {
    if (is.null(all_media)) {
      # Media names: columns of mediaVecCollect after the first one and
      # before the "type" column
      aux <- colnames(input$mediaVecCollect)
      all_media <- aux[-c(1, which(aux == "type"):length(aux))]
      path <- input$plot_folder
    }
    # Pareto and ROI data
    rois <- input$xDecompAgg
    df <- .prepare_roi(rois, all_media = all_media)
  } else {
    # Check the type first so names() is only inspected on a data.frame
    is_valid_df <- is.data.frame(input) &&
      all(c("solID", "mape", "nrmse", "decomp.rssd") %in% names(input))
    if (!is_valid_df) {
      stop(paste(
        "You must run robyn_export(..., clusters = TRUE) or",
        "pass a valid data.frame (sames as pareto_aggregated.csv output)",
        "in order to use robyn_clusters()"
      ))
    }
    df <- .prepare_roi(input, all_media)
  }

  # Columns that must not be used as clustering features
  ignore <- c("solID", "mape", "decomp.rssd", "nrmse", "pareto")

  # Auto K selected by less than 5% WSS variance (convergence)
  min_clusters <- 3
  limit_clusters <- min(nrow(df) - 1, 30)
  if ("auto" %in% k) {
    cls <- tryCatch(
      clusterKmeans(df, k = NULL, limit = limit_clusters, ignore = ignore,
                    dim_red = dim_red, quiet = TRUE, ...),
      error = function(err) {
        # Fail right here with the underlying reason instead of carrying a
        # NULL result into the k selection below
        stop("Couldn't automatically create clusters: ",
             conditionMessage(err), call. = FALSE)
      }
    )
    min_var <- 0.05
    candidates <- cls$nclusters %>%
      mutate(pareto = .data$wss / .data$wss[1],
             dif = lag(.data$pareto) - .data$pareto) %>%
      filter(.data$dif > min_var) %>%
      pull("n")
    # No step improved WSS by more than min_var: fall back to the minimum
    k <- if (length(candidates) > 0) max(candidates) else min_clusters
    if (k < min_clusters) k <- min_clusters
    if (!quiet) message(sprintf(
      ">> Auto selected k = %s (clusters) based on minimum WSS variance of %s%%",
      k, min_var * 100))
  }

  # Build clusters
  stopifnot(k %in% min_clusters:30)
  cls <- clusterKmeans(df, k, limit = limit_clusters, ignore = ignore,
                       dim_red = dim_red, quiet = TRUE, ...)

  # Select top models by minimum (weighted) distance to zero
  # NOTE(review): .clusters_df() returns weight / normalized error, so this
  # `error` grows as the underlying errors grow, and .crit_proc() sorts it in
  # descending order. Confirm this selects the intended (lowest error) models.
  top_sols <- .clusters_df(cls$df, weights) %>%
    mutate(error = (.data$nrmse^2 + .data$decomp.rssd^2 + .data$mape^2)^-(1 / 2)) %>%
    .crit_proc(limit)

  output <- list(
    # Data and parameters
    data = mutate(cls$df, top_sol = .data$solID %in% top_sols$solID),
    n_clusters = k,
    errors_weights = weights,
    # Within Groups Sum of Squares Plot
    wss = cls$nclusters_plot,
    # Grouped correlations per cluster
    corrs = cls$correlations + labs(title = "ROI Top Correlations by Cluster", subtitle = NULL),
    # Mean ROI per cluster
    clusters_means = cls$means,
    # Dim reduction clusters
    clusters_PCA = cls[["PCA"]],
    clusters_tSNE = cls[["tSNE"]],
    # Top Clusters
    models = top_sols,
    plot_models_errors = .plot_topsols_errors(df, top_sols, limit, weights),
    plot_models_rois = .plot_topsols_rois(top_sols, all_media, limit)
  )

  if (export) {
    fwrite(output$data, file = paste0(path, "pareto_clusters.csv"))
    ggsave(paste0(path, "pareto_clusters_wss.png"), plot = output$wss, dpi = 500, width = 5, height = 4)
    ggsave(paste0(path, "pareto_clusters_corr.png"), plot = output$corrs, dpi = 500, width = 7, height = 5)
    db <- wrap_plots(output$plot_models_rois, output$plot_models_errors)
    ggsave(paste0(path, "pareto_clusters_detail.png"), plot = db, dpi = 600, width = 12, height = 9)
  }

  return(output)

}
| 123 | + |
| 124 | + |
# Build the clustering input: one row per solID with one total-ROI column per
# media variable plus the error metrics (nrmse, decomp.rssd, mape).
# x: long data.frame (xDecompAgg or pareto_aggregated.csv) with columns
#    solID, rn, roi_total and the three error metrics.
# all_media: names of the media variables to keep as ROI columns.
.prepare_roi <- function(x, all_media) {
  # Checks all_media against the variable names available in x$rn
  check_opts(all_media, unique(x$rn))
  keep_cols <- c("solID", all_media)
  roi_wide <- x %>%
    pivot_wider(id_cols = "solID", names_from = "rn", values_from = "roi_total") %>%
    removenacols(all = FALSE) %>%
    select(any_of(keep_cols))
  # One row of error metrics per solution
  sol_errors <- distinct(x, .data$solID, .data$nrmse, .data$decomp.rssd, .data$mape)
  ungroup(left_join(roi_wide, sol_errors, "solID"))
}
| 135 | + |
# Rescale x linearly so its minimum maps to 0 and its maximum to 1.
# When all values are equal the denominator is 0 and the result is NaN.
.min_max_norm <- function(x) {
  lowest <- min(x)
  highest <- max(x)
  (x - lowest) / (highest - lowest)
}
| 137 | + |
# Turn the three error columns into weighted, comparable scores and group the
# result by cluster.
# df: data.frame with columns nrmse, decomp.rssd, mape and cluster.
# balance: length-3 weights (nrmse, decomp.rssd, mape); rescaled to sum 1.
.clusters_df <- function(df, balance = rep(1, 3)) {
  stopifnot(length(balance) == 3)
  shares <- balance / sum(balance)
  # Each error is min-max normalized (so they are comparable) and then its
  # weight is divided by the normalized value
  scored <- mutate(
    df,
    nrmse = shares[1] / .min_max_norm(.data$nrmse),
    decomp.rssd = shares[2] / .min_max_norm(.data$decomp.rssd),
    mape = shares[3] / .min_max_norm(.data$mape)
  )
  # Missing values (including NaN from a constant error column) become 0
  scored[is.na(scored)] <- 0
  group_by(scored, .data$cluster)
}
| 158 | + |
# Keep the first `limit` rows of each group after sorting by cluster and by
# descending `error`, and add their within-group position as `rank`.
# df: data.frame with columns cluster and error; per-cluster results rely on
#     df already being grouped by cluster.
# limit: number of rows to keep per group.
# Returns df with cluster and rank as the first two columns.
# NOTE(review): rows with the LARGEST `error` are ranked first; confirm that
# matches how the caller defines `error`.
.crit_proc <- function(df, limit) {
  df %>%
    arrange(.data$cluster, desc(.data$error)) %>%
    # seq_len() keeps zero rows for limit = 0, whereas 1:0 would select row 1
    slice(seq_len(limit)) %>%
    mutate(rank = row_number()) %>%
    # Column names as strings: `.data$` inside select() is deprecated
    select("cluster", "rank", everything())
}
| 165 | + |
# Scatter plot of NRMSE vs DECOMP.RSSD for every model in df, highlighting and
# labelling ("[cluster.rank]") the models present in top_sols.
# df: data.frame with solID, nrmse and decomp.rssd.
# top_sols: selected models; its first three columns are joined to df by solID.
# limit: number of models per cluster (used in the title only).
# balance: length-3 error weights, shown as percentages in the caption.
.plot_topsols_errors <- function(df, top_sols, limit = 1, balance = rep(1, 3)) {
  weights_pct <- round(100 * (balance / sum(balance)))
  caption_txt <- sprintf(
    "Weights: NRMSE %s%%, DECOMP.RSSD %s%%, MAPE %s%%",
    weights_pct[1], weights_pct[2], weights_pct[3]
  )

  # Models that are not in top_sols get NA cluster/rank after the join
  tagged <- left_join(df, select(top_sols, 1:3), "solID")
  tagged <- mutate(
    tagged,
    alpha = ifelse(is.na(.data$cluster), 0.5, 1),
    label = ifelse(!is.na(.data$cluster), sprintf(
      "[%s.%s]", .data$cluster, .data$rank
    ), NA)
  )

  ggplot(tagged, aes(x = .data$nrmse, y = .data$decomp.rssd)) +
    geom_point(aes(colour = .data$cluster, alpha = .data$alpha)) +
    geom_text(aes(label = .data$label), na.rm = TRUE, hjust = -0.3) +
    guides(alpha = "none", colour = "none") +
    labs(
      title = paste("Selecting Top", limit, "Performing Models by Cluster"),
      subtitle = "Based on minimum (weighted) distance to origin",
      x = "NRMSE", y = "DECOMP.RSSD",
      caption = caption_txt
    ) +
    theme_lares()
}
| 190 | + |
# Horizontal bar chart of ROI per media variable, one facet per selected model
# (facet label: "[cluster.rank]" plus the solID).
# top_sols: selected models with cluster, rank, solID and the media columns.
# all_media: strings used to pick the media columns (matched with contains()).
# limit: number of models per cluster (used in the title only).
.plot_topsols_rois <- function(top_sols, all_media, limit = 1) {
  labelled <- mutate(
    top_sols,
    label = sprintf("[%s.%s]\n%s", .data$cluster, .data$rank, .data$solID)
  )
  # Long format: one row per model and media column
  long_rois <- tidyr::gather(labelled, "media", "roi", contains(all_media))

  ggplot(long_rois, aes(x = .data$media, y = .data$roi)) +
    facet_grid(.data$label ~ .) +
    geom_col() +
    coord_flip() +
    labs(
      title = paste("ROIs on Top", limit, "Performing Models"),
      x = NULL, y = "ROI per Media"
    ) +
    theme_lares()
}
0 commit comments