Skip to content

Commit 4b3d829

Browse files
Merge pull request #283 from facebookexperimental/lares
Clustering and functions split for exporting results
2 parents 86202d5 + 20981cb commit 4b3d829

24 files changed

+1692
-1043
lines changed

R/DESCRIPTION

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
Package: Robyn
22
Type: Package
33
Title: Automated Marketing Mix Modeling (MMM) Open Source Beta Project from Facebook Marketing Science
4-
Version: 3.4.12
4+
Version: 3.5.0
55
Authors@R: c(
66
person("Gufeng", "Zhou", , "gufeng@fb.com", c("aut")),
77
person("Leonel", "Sentana", , "leonelsentana@fb.com", c("aut")),
88
person("Igor", "Skokan", , "igorskokan@fb.com", c("aut")),
9-
person("Bernardo", "Lares", , "bernardolares@fb.com", c("cre")))
9+
person("Bernardo", "Lares", , "bernardolares@fb.com", c("cre","aut")))
1010
Maintainer: Gufeng Zhou <gufeng@fb.com>, Bernardo Lares <bernardolares@fb.com>
1111
Description: Automated Marketing Mix Modeling (MMM) package that aims to reduce human bias by means of ridge regression and evolutionary algorithms, enables actionable decision making providing a budget allocator and diminishing returns curves and allows ground-truth calibration to account for causation.
1212
Depends:
@@ -15,18 +15,21 @@ Imports:
1515
data.table,
1616
doParallel,
1717
doRNG,
18+
dplyr,
1819
foreach,
1920
ggplot2,
2021
ggridges,
2122
glmnet,
23+
lares,
2224
lubridate,
2325
minpack.lm,
2426
nloptr,
2527
patchwork,
2628
prophet,
2729
reticulate,
2830
rPref,
29-
stringr
31+
stringr,
32+
tidyr
3033
Suggests:
3134
shiny
3235
Config/reticulate:

R/NAMESPACE

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,26 +8,57 @@ export(mic_men)
88
export(plot_adstock)
99
export(plot_saturation)
1010
export(robyn_allocator)
11+
export(robyn_clusters)
12+
export(robyn_csv)
1113
export(robyn_engineering)
1214
export(robyn_inputs)
1315
export(robyn_mmm)
16+
export(robyn_onepagers)
17+
export(robyn_outputs)
18+
export(robyn_plots)
1419
export(robyn_refresh)
1520
export(robyn_response)
1621
export(robyn_run)
1722
export(robyn_save)
23+
export(robyn_train)
1824
export(saturation_hill)
1925
import(data.table)
2026
import(ggplot2)
2127
importFrom(doParallel,registerDoParallel)
2228
importFrom(doParallel,stopImplicitCluster)
2329
importFrom(doRNG,"%dorng%")
30+
importFrom(dplyr,any_of)
31+
importFrom(dplyr,arrange)
32+
importFrom(dplyr,as_tibble)
33+
importFrom(dplyr,bind_rows)
34+
importFrom(dplyr,contains)
35+
importFrom(dplyr,desc)
36+
importFrom(dplyr,distinct)
37+
importFrom(dplyr,everything)
38+
importFrom(dplyr,filter)
39+
importFrom(dplyr,group_by)
40+
importFrom(dplyr,lag)
41+
importFrom(dplyr,left_join)
42+
importFrom(dplyr,mutate)
43+
importFrom(dplyr,pull)
44+
importFrom(dplyr,row_number)
45+
importFrom(dplyr,select)
46+
importFrom(dplyr,slice)
47+
importFrom(dplyr,ungroup)
2448
importFrom(foreach,"%dopar%")
2549
importFrom(foreach,foreach)
2650
importFrom(foreach,getDoParWorkers)
2751
importFrom(foreach,registerDoSEQ)
2852
importFrom(ggridges,geom_density_ridges)
2953
importFrom(glmnet,cv.glmnet)
3054
importFrom(glmnet,glmnet)
55+
importFrom(lares,`%>%`)
56+
importFrom(lares,check_opts)
57+
importFrom(lares,clusterKmeans)
58+
importFrom(lares,formatNum)
59+
importFrom(lares,freqs)
60+
importFrom(lares,removenacols)
61+
importFrom(lares,theme_lares)
3162
importFrom(lubridate,day)
3263
importFrom(lubridate,floor_date)
3364
importFrom(lubridate,is.Date)
@@ -71,7 +102,10 @@ importFrom(stringr,str_extract)
71102
importFrom(stringr,str_remove)
72103
importFrom(stringr,str_replace)
73104
importFrom(stringr,str_which)
105+
importFrom(tidyr,pivot_longer)
106+
importFrom(tidyr,pivot_wider)
74107
importFrom(utils,askYesNo)
108+
importFrom(utils,flush.console)
75109
importFrom(utils,head)
76110
importFrom(utils,setTxtProgressBar)
77111
importFrom(utils,txtProgressBar)

R/R/allocator.R

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#' variable spends that maximizes the total media response.
1313
#'
1414
#' @inheritParams robyn_run
15+
#' @inheritParams robyn_outputs
1516
#' @param robyn_object Character. Path of the \code{Robyn.RDS} object
1617
#' that contains all previous modeling information.
1718
#' @param select_build Integer. Default to the latest model build. \code{select_build = 0}
@@ -480,8 +481,6 @@ robyn_allocator <- function(robyn_object = NULL,
480481
)
481482
}
482483

483-
# print(nlsMod)
484-
485484
## collect output
486485

487486
dt_bestModel <- dt_bestCoef[, .(rn, mean_spend, xDecompAgg, roi_total, roi_mean)][order(rank(rn))]
@@ -511,7 +510,6 @@ robyn_allocator <- function(robyn_object = NULL,
511510
)
512511

513512
dt_optimOut[, optmResponseUnitTotalLift := (optmResponseUnitTotal / initResponseUnitTotal) - 1]
514-
# print(dt_optimOut)
515513

516514
## plot allocator results
517515

R/R/checks.R

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -468,3 +468,52 @@ check_calibconstr <- function(calibration_constraint, iterations, trials, calibr
468468
}
469469
return(calibration_constraint)
470470
}
471+
472+
check_hyper_fixed <- function(InputCollect, dt_hyper_fixed) {
  # Hyperparameters are considered "fixed" when the hyperparameters list has
  # collapsed to a single element (reproducing a previously fitted model).
  # Note: length() is scalar, so the original all() wrapper was redundant.
  hyper_fixed <- length(InputCollect$hyperparameters) == 1
  if (hyper_fixed && is.null(dt_hyper_fixed)) {
    stop(paste("hyperparameters can't be all fixed for hyperparameter optimisation.",
               "If you want to get old model result, please provide only 1 model / 1 row from",
               "OutputCollect$resultHypParam or pareto_hyperparameters.csv from previous runs"))
  }
  if (!is.null(dt_hyper_fixed)) {
    ## Run robyn_mmm if using old model result tables
    dt_hyper_fixed <- as.data.table(dt_hyper_fixed)
    # Exactly one model (row) must be provided when reproducing a previous run
    if (nrow(dt_hyper_fixed) != 1) {
      stop(paste("Provide only 1 model / 1 row from OutputCollect$resultHypParam or",
                 "pareto_hyperparameters.csv from previous runs"))
    }
    # The fixed table must contain every expected hyperparameter plus "lambda"
    hypParamSamName <- hyper_names(adstock = InputCollect$adstock, all_media = InputCollect$all_media)
    if (!all(c(hypParamSamName, "lambda") %in% names(dt_hyper_fixed))) {
      stop(paste("dt_hyper_fixed is provided with wrong input.",
                 "Please provide the table OutputCollect$resultHypParam from previous runs or",
                 "pareto_hyperparameters.csv with desired model ID"))
    }
  }
  return(hyper_fixed)
}
495+
496+
# Enable parallelisation of main modelling loop for MacOS and Linux only
check_parallel <- function() .Platform$OS.type == "unix"
# ggplot doesn't work with process forking on MacOS; however it works fine on Linux and Windows
check_parallel_plot <- function() Sys.info()[["sysname"]] != "Darwin"
500+
501+
# Report the cross-validation setup, noting whether parallel cores are usable
# on this platform (Windows falls back to a single core).
check_parallel_msg <- function(InputCollect) {
  n_hyps <- length(InputCollect$hyperparameters)
  cores_txt <- if (check_parallel()) {
    paste("on", InputCollect$cores, "cores")
  } else {
    "on 1 core (Windows fallback)"
  }
  message(paste(
    "Using", InputCollect$adstock, "adstocking with", n_hyps,
    "hyperparameters & 10-fold ridge x-validation", cores_txt
  ))
}
516+
517+
# Stop unless `object` carries every class listed in `x`.
check_class <- function(x, object) {
  if (!all(x %in% class(object))) stop(sprintf("Input object must be class %s", x))
}

R/R/clusters.R

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
# Copyright (c) Meta Platforms, Inc. and its affiliates.
2+
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
####################################################################
7+
#' Reduce number of models based on ROI clusters and minimum combined errors
8+
#'
9+
#' The \code{robyn_clusters()} function uses output from \code{robyn_run()},
10+
#' to reduce the amount of models and help the user pick up the best (lowest
11+
#' combined error) of different kinds (clusters) of models.
12+
#'
13+
#' @inheritParams lares::clusterKmeans
14+
#' @inheritParams hyper_names
15+
#' @inheritParams robyn_outputs
16+
#' @param input \code{robyn_outputs()}'s output or \code{pareto_aggregated.csv} results.
17+
#' @param limit Integer. Top N results per cluster. If kept in "auto", will select k
18+
#' as the cluster in which the WSS variance was less than 5\%.
19+
#' @param weights Vector, size 3. How much should each error weight?
20+
#' Order: nrmse, decomp.rssd, mape. The higher the value, the closer it will be scaled
21+
#' to origin. Each value will be normalized so they all sum to 1.
22+
#' @param export Export plots into local files?
23+
#' @param ... Additional parameters passed to \code{lares::clusterKmeans()}.
24+
#' @author Bernardo Lares (bernardolares@@fb.com)
25+
#' @examples
26+
#' \dontrun{
27+
#' cls <- robyn_clusters(input = OutputCollect,
28+
#' all_media = InputCollect$all_media,
29+
#' k = 3, limit = 2,
30+
#' weights = c(1, 1, 1.5))
31+
#' }
32+
#' @export
33+
robyn_clusters <- function(input, all_media = NULL, k = "auto", limit = 1,
                           weights = rep(1, 3), dim_red = "PCA",
                           quiet = FALSE, export = FALSE,
                           ...) {

  # Default export path. Fix: `path` was previously left undefined when `input`
  # was a plain data.frame, which broke the export block below.
  path <- paste0(getwd(), "/")

  if ("robyn_outputs" %in% class(input)) {
    if (is.null(all_media)) {
      # Derive media names from mediaVecCollect: every column between the
      # first one and the "type" column
      aux <- colnames(input$mediaVecCollect)
      all_media <- aux[-c(1, which(aux == "type"):length(aux))]
      path <- input$plot_folder
    }
    # Pareto and ROI data
    rois <- input$xDecompAgg
    df <- .prepare_roi(rois, all_media = all_media)
  } else {
    if (all(c("solID", "mape", "nrmse", "decomp.rssd") %in% names(input)) && is.data.frame(input)) {
      df <- .prepare_roi(input, all_media)
    } else {
      stop(paste(
        "You must run robyn_export(..., clusters = TRUE) or",
        "pass a valid data.frame (sames as pareto_aggregated.csv output)",
        "in order to use robyn_clusters()"
      ))
    }
  }

  # Columns that must not be fed into the clustering algorithm
  ignore <- c("solID", "mape", "decomp.rssd", "nrmse", "pareto")

  # Auto K selected by less than 5% WSS variance (convergence)
  min_clusters <- 3
  limit_clusters <- min(nrow(df) - 1, 30)
  if ("auto" %in% k) {
    cls <- tryCatch({
      clusterKmeans(df, k = NULL, limit = limit_clusters, ignore = ignore, dim_red = dim_red, quiet = TRUE, ...)
    }, error = function(err) {
      message(paste("Couldn't automatically create clusters:", err))
      return(NULL)
    })
    # Fix: guard restored (was commented out); without it, a failed
    # clusterKmeans() call above still led to a hard error right below.
    if (is.null(cls)) return(NULL)
    min_var <- 0.05
    # Pick the largest k whose WSS improvement over the previous k still
    # exceeds min_var, i.e. where the elbow curve has not yet flattened
    k <- cls$nclusters %>%
      mutate(pareto = .data$wss / .data$wss[1],
             dif = lag(.data$pareto) - .data$pareto) %>%
      filter(.data$dif > min_var) %>% pull(.data$n) %>% max(.)
    if (k < min_clusters) k <- min_clusters
    if (!quiet) message(sprintf(
      ">> Auto selected k = %s (clusters) based on minimum WSS variance of %s%%",
      k, min_var * 100))
  }

  # Build clusters
  stopifnot(k %in% min_clusters:30)
  cls <- clusterKmeans(df, k, limit = limit_clusters, ignore = ignore, dim_red = dim_red, quiet = TRUE, ...)

  # Select top models by minimum (weighted) distance to zero
  top_sols <- .clusters_df(cls$df, weights) %>%
    mutate(error = (.data$nrmse^2 + .data$decomp.rssd^2 + .data$mape^2)^-(1 / 2)) %>%
    .crit_proc(limit)

  output <- list(
    # Data and parameters
    data = mutate(cls$df, top_sol = .data$solID %in% top_sols$solID),
    n_clusters = k,
    errors_weights = weights,
    # Within Groups Sum of Squares Plot
    wss = cls$nclusters_plot,
    # Grouped correlations per cluster
    corrs = cls$correlations + labs(title = "ROI Top Correlations by Cluster", subtitle = NULL),
    # Mean ROI per cluster
    clusters_means = cls$means,
    # Dim reduction clusters
    clusters_PCA = cls[["PCA"]],
    clusters_tSNE = cls[["tSNE"]],
    # Top Clusters
    models = top_sols,
    plot_models_errors = .plot_topsols_errors(df, top_sols, limit, weights),
    plot_models_rois = .plot_topsols_rois(top_sols, all_media, limit)
  )

  if (export) {
    fwrite(output$data, file = paste0(path, "pareto_clusters.csv"))
    ggsave(paste0(path, "pareto_clusters_wss.png"), plot = output$wss, dpi = 500, width = 5, height = 4)
    ggsave(paste0(path, "pareto_clusters_corr.png"), plot = output$corrs, dpi = 500, width = 7, height = 5)
    db <- wrap_plots(output$plot_models_rois, output$plot_models_errors)
    ggsave(paste0(path, "pareto_clusters_detail.png"), plot = db, dpi = 600, width = 12, height = 9)
  }

  return(output)
}
123+
124+
125+
# ROIs data.frame for clustering (from xDecompAgg or pareto_aggregated.csv)
.prepare_roi <- function(x, all_media) {
  # Validate that every requested media name exists in the results table
  check_opts(all_media, unique(x$rn))
  # One error row per model
  errors <- distinct(x, .data$solID, .data$nrmse, .data$decomp.rssd, .data$mape)
  # One row per solID with one ROI column per media channel; drop all-NA
  # columns, keep only the media of interest, then attach the error metrics
  rois <- x %>%
    pivot_wider(id_cols = "solID", names_from = "rn", values_from = "roi_total") %>%
    removenacols(all = FALSE) %>%
    select(any_of(c("solID", all_media))) %>%
    left_join(errors, "solID") %>%
    ungroup()
  return(rois)
}
135+
136+
# Scale a numeric vector linearly onto [0, 1]
.min_max_norm <- function(x) {
  rng <- range(x)
  (x - rng[1]) / (rng[2] - rng[1])
}
137+
138+
# Build a per-cluster table of normalized, weighted, inverted errors.
# After this transform, larger values mean a better (smaller) original error.
.clusters_df <- function(df, balance = rep(1, 3)) {
  stopifnot(length(balance) == 3)
  # Normalize weights so they always sum to 1
  balance <- balance / sum(balance)
  out <- df %>%
    # Min-max normalize each error so they are comparable, then invert and
    # weight them in a single step (the three columns are independent)
    mutate(
      nrmse = balance[1] / .min_max_norm(.data$nrmse),
      decomp.rssd = balance[2] / .min_max_norm(.data$decomp.rssd),
      mape = balance[3] / .min_max_norm(.data$mape)
    ) %>%
    # NaN (e.g. a constant error column yields 0/0) becomes 0
    replace(., is.na(.), 0) %>%
    group_by(.data$cluster)
  return(out)
}
158+
159+
# Keep the best `limit` models within each cluster (by descending error
# score) and rank them, with cluster and rank as leading columns.
.crit_proc <- function(df, limit) {
  df %>%
    arrange(.data$cluster, desc(.data$error)) %>%
    slice(1:limit) %>%
    mutate(rank = row_number()) %>%
    select(.data$cluster, .data$rank, everything())
}
165+
166+
# Scatter all models on NRMSE vs DECOMP.RSSD, highlighting and labelling
# the winning models per cluster as "[cluster.rank]".
.plot_topsols_errors <- function(df, top_sols, limit = 1, balance = rep(1, 3)) {
  balance <- balance / sum(balance)
  # Join winners onto the full model set; non-winners get a faded alpha
  plot_df <- df %>%
    left_join(select(top_sols, 1:3), "solID") %>%
    mutate(
      alpha = ifelse(is.na(.data$cluster), 0.5, 1),
      label = ifelse(!is.na(.data$cluster), sprintf(
        "[%s.%s]", .data$cluster, .data$rank
      ), NA)
    )
  ggplot(plot_df, aes(x = .data$nrmse, y = .data$decomp.rssd)) +
    geom_point(aes(colour = .data$cluster, alpha = .data$alpha)) +
    geom_text(aes(label = .data$label), na.rm = TRUE, hjust = -0.3) +
    guides(alpha = "none", colour = "none") +
    labs(
      title = paste("Selecting Top", limit, "Performing Models by Cluster"),
      subtitle = "Based on minimum (weighted) distance to origin",
      x = "NRMSE", y = "DECOMP.RSSD",
      caption = sprintf(
        "Weights: NRMSE %s%%, DECOMP.RSSD %s%%, MAPE %s%%",
        round(100 * balance[1]), round(100 * balance[2]), round(100 * balance[3])
      )
    ) +
    theme_lares()
}
190+
191+
# Bar chart of per-media ROI for each winning model, one facet per model
# labelled "[cluster.rank]\nsolID".
.plot_topsols_rois <- function(top_sols, all_media, limit = 1) {
  top_sols %>%
    mutate(label = sprintf("[%s.%s]\n%s", .data$cluster, .data$rank, .data$solID)) %>%
    # Fix: gather(contains(all_media)) relied on substring matching with a
    # vector of patterns, which tidyselect has deprecated; the media columns
    # carry exact names (built by .prepare_roi), so reshape with the
    # non-superseded pivot_longer() + any_of() instead
    tidyr::pivot_longer(any_of(all_media), names_to = "media", values_to = "roi") %>%
    ggplot(aes(x = .data$media, y = .data$roi)) +
    facet_grid(.data$label ~ .) +
    geom_col() +
    coord_flip() +
    labs(
      title = paste("ROIs on Top", limit, "Performing Models"),
      x = NULL, y = "ROI per Media"
    ) +
    theme_lares()
}

0 commit comments

Comments
 (0)