Skip to content

Commit d152ad3

Browse files
[BIG] feat: new allocator logic and outputs #600
* Feat: all-new one-pager for `robyn_allocator()` showing initial, bounded and less-bounded scenarios, using the last month's worth of data by default. Relevant changes from previous versions:
  - Initial spend is now the mean over the selected date range, no longer the non-zero mean.
  - Deprecated the "max_response_expected_spend" scenario.
  - Carryover information is now provided in the curves.
  - The user is informed when the budget is topped out and can't be fully allocated.
  - Added mROAS / mCPA for a better understanding of the allocation.
* Feat: `robyn_response()` now requires a date or date range for adstocking (last period by default) and accepts single or multiple values to return different use cases and scenarios.
* Feat: new `transform_adstock()` exported wrapper function.
* Feat: added NRMSE validation on the test set.
* Feat: added prophet monthly component.
* Fix: added the correct solID for fixed hyperparameters (not 1_1_1).
* Recode: reduced the size of `xDecompVec` on `OutputCollect` to only pareto-front models.
* Recode: got rid of the "ggcorrplot" and "rPref" package dependencies.
* Docs: added blueprint link to demo.R.

---------

Co-authored: @gufengzhou @laresbernardo
1 parent 619765d commit d152ad3

24 files changed

+1423
-815
lines changed

R/DESCRIPTION

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
Package: Robyn
22
Type: Package
33
Title: Semi-Automated Marketing Mix Modeling (MMM) from Meta Marketing Science
4-
Version: 3.9.1.9000
4+
Version: 3.10.0.9000
55
Authors@R: c(
66
person("Gufeng", "Zhou", , "gufeng@meta.com", c("aut")),
77
person("Leonel", "Sentana", , "leonelsentana@meta.com", c("aut")),
@@ -28,7 +28,6 @@ Imports:
2828
patchwork,
2929
prophet,
3030
reticulate,
31-
rPref,
3231
stringr,
3332
tidyr
3433
Suggests:

R/NAMESPACE

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ export(robyn_train)
3838
export(robyn_update)
3939
export(robyn_write)
4040
export(saturation_hill)
41+
export(transform_adstock)
4142
export(ts_validation)
4243
import(ggplot2)
4344
importFrom(doParallel,registerDoParallel)
@@ -91,6 +92,7 @@ importFrom(lares,clusterKmeans)
9192
importFrom(lares,formatNum)
9293
importFrom(lares,freqs)
9394
importFrom(lares,glued)
95+
importFrom(lares,num_abbr)
9496
importFrom(lares,ohse)
9597
importFrom(lares,removenacols)
9698
importFrom(lares,scale_x_abbr)
@@ -114,8 +116,6 @@ importFrom(prophet,add_regressor)
114116
importFrom(prophet,add_seasonality)
115117
importFrom(prophet,fit.prophet)
116118
importFrom(prophet,prophet)
117-
importFrom(rPref,low)
118-
importFrom(rPref,psel)
119119
importFrom(reticulate,conda_create)
120120
importFrom(reticulate,conda_install)
121121
importFrom(reticulate,import)

R/R/allocator.R

Lines changed: 440 additions & 256 deletions
Large diffs are not rendered by default.

R/R/calibration.R

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,18 +51,14 @@ robyn_calibrate <- function(calibration_input,
5151
## 1. Adstock
5252
if (adstock == "geometric") {
5353
theta <- hypParamSam[paste0(get_channels[l_chn], "_thetas")][[1]][[1]]
54-
x_list <- adstock_geometric(x = m_calib, theta = theta)
55-
} else if (adstock == "weibull_cdf") {
56-
shape <- hypParamSam[paste0(get_channels[l_chn], "_shapes")][[1]][[1]]
57-
scale <- hypParamSam[paste0(get_channels[l_chn], "_scales")][[1]][[1]]
58-
x_list <- adstock_weibull(x = m_calib, shape = shape, scale = scale, windlen = length(m), type = "cdf")
59-
} else if (adstock == "weibull_pdf") {
54+
}
55+
if (grepl("weibull", adstock)) {
6056
shape <- hypParamSam[paste0(get_channels[l_chn], "_shapes")][[1]][[1]]
6157
scale <- hypParamSam[paste0(get_channels[l_chn], "_scales")][[1]][[1]]
62-
x_list <- adstock_weibull(x = m_calib, shape = shape, scale = scale, windlen = length(m), type = "pdf")
6358
}
64-
m_calib_total_adst <- dt_modAdstocked[calib_pos, get_channels[l_chn]][[1]]
59+
x_list <- transform_adstock(m_calib, adstock, theta = theta, shape = shape, scale = scale)
6560
m_calib_imme_adst <- x_list$x_decayed
61+
m_calib_total_adst <- dt_modAdstocked[calib_pos, get_channels[l_chn]][[1]]
6662
m_calib_hist_adst <- m_calib_total_adst - m_calib_imme_adst
6763
# Adapt for weibull_pdf with lags
6864
m_calib_imme_adst[m_calib_hist_adst < 0] <- m_calib_total_adst[m_calib_hist_adst < 0]

R/R/checks.R

Lines changed: 140 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55

66
############# Auxiliary non-exported functions #############
77

8-
opts_pnd <- c("positive", "negative", "default")
9-
other_hyps <- c("lambda", "train_size")
10-
hyps_name <- c("thetas", "shapes", "scales", "alphas", "gammas")
8+
OPTS_PDN <- c("positive", "negative", "default")
9+
HYPS_NAMES <- c("thetas", "shapes", "scales", "alphas", "gammas")
10+
HYPS_OTHERS <- c("lambda", "train_size")
11+
LEGACY_PARAMS <- c("cores", "iterations", "trials", "intercept_sign", "nevergrad_algo")
1112

1213
check_nas <- function(df) {
1314
name <- deparse(substitute(df))
@@ -172,8 +173,8 @@ check_prophet <- function(dt_holidays, prophet_country, prophet_vars, prophet_si
172173
if (is.null(prophet_signs)) {
173174
prophet_signs <- rep("default", length(prophet_vars))
174175
}
175-
if (!all(prophet_signs %in% opts_pnd)) {
176-
stop("Allowed values for 'prophet_signs' are: ", paste(opts_pnd, collapse = ", "))
176+
if (!all(prophet_signs %in% OPTS_PDN)) {
177+
stop("Allowed values for 'prophet_signs' are: ", paste(OPTS_PDN, collapse = ", "))
177178
}
178179
if (length(prophet_signs) != length(prophet_vars)) {
179180
stop("'prophet_signs' must have same length as 'prophet_vars'")
@@ -185,8 +186,8 @@ check_prophet <- function(dt_holidays, prophet_country, prophet_vars, prophet_si
185186
check_context <- function(dt_input, context_vars, context_signs) {
186187
if (!is.null(context_vars)) {
187188
if (is.null(context_signs)) context_signs <- rep("default", length(context_vars))
188-
if (!all(context_signs %in% opts_pnd)) {
189-
stop("Allowed values for 'context_signs' are: ", paste(opts_pnd, collapse = ", "))
189+
if (!all(context_signs %in% OPTS_PDN)) {
190+
stop("Allowed values for 'context_signs' are: ", paste(OPTS_PDN, collapse = ", "))
190191
}
191192
if (length(context_signs) != length(context_vars)) {
192193
stop("Input 'context_signs' must have same length as 'context_vars'")
@@ -235,8 +236,8 @@ check_paidmedia <- function(dt_input, paid_media_vars, paid_media_signs, paid_me
235236
if (is.null(paid_media_signs)) {
236237
paid_media_signs <- rep("positive", mediaVarCount)
237238
}
238-
if (!all(paid_media_signs %in% opts_pnd)) {
239-
stop("Allowed values for 'paid_media_signs' are: ", paste(opts_pnd, collapse = ", "))
239+
if (!all(paid_media_signs %in% OPTS_PDN)) {
240+
stop("Allowed values for 'paid_media_signs' are: ", paste(OPTS_PDN, collapse = ", "))
240241
}
241242
if (length(paid_media_signs) == 1) {
242243
paid_media_signs <- rep(paid_media_signs, length(paid_media_vars))
@@ -281,8 +282,8 @@ check_organicvars <- function(dt_input, organic_vars, organic_signs) {
281282
organic_signs <- rep("positive", length(organic_vars))
282283
# message("'organic_signs' were not provided. Using 'positive'")
283284
}
284-
if (!all(organic_signs %in% opts_pnd)) {
285-
stop("Allowed values for 'organic_signs' are: ", paste(opts_pnd, collapse = ", "))
285+
if (!all(organic_signs %in% OPTS_PDN)) {
286+
stop("Allowed values for 'organic_signs' are: ", paste(OPTS_PDN, collapse = ", "))
286287
}
287288
if (length(organic_signs) != length(organic_vars)) {
288289
stop("Input 'organic_signs' must have same length as 'organic_vars'")
@@ -444,10 +445,10 @@ check_hyperparameters <- function(hyperparameters = NULL, adstock = NULL,
444445
ref_hyp_name_spend <- hyper_names(adstock, all_media = paid_media_spends)
445446
ref_hyp_name_expo <- hyper_names(adstock, all_media = exposure_vars)
446447
ref_hyp_name_org <- hyper_names(adstock, all_media = organic_vars)
447-
ref_hyp_name_other <- get_hyp_names[get_hyp_names %in% other_hyps]
448-
# Excluding lambda (first other_hyps) given its range is not customizable
449-
ref_all_media <- sort(c(ref_hyp_name_spend, ref_hyp_name_org, other_hyps))
450-
all_ref_names <- c(ref_hyp_name_spend, ref_hyp_name_expo, ref_hyp_name_org, other_hyps)
448+
ref_hyp_name_other <- get_hyp_names[get_hyp_names %in% HYPS_OTHERS]
449+
# Excluding lambda (first HYPS_OTHERS) given its range is not customizable
450+
ref_all_media <- sort(c(ref_hyp_name_spend, ref_hyp_name_org, HYPS_OTHERS))
451+
all_ref_names <- c(ref_hyp_name_spend, ref_hyp_name_expo, ref_hyp_name_org, HYPS_OTHERS)
451452
all_ref_names <- all_ref_names[order(all_ref_names)]
452453
if (!all(get_hyp_names %in% all_ref_names)) {
453454
wrong_hyp_names <- get_hyp_names[which(!(get_hyp_names %in% all_ref_names))]
@@ -717,7 +718,7 @@ check_hyper_fixed <- function(InputCollect, dt_hyper_fixed, add_penalty_factor)
717718
# Adstock hyper-parameters
718719
hypParamSamName <- hyper_names(adstock = InputCollect$adstock, all_media = InputCollect$all_media)
719720
# Add lambda and other hyper-parameters manually
720-
hypParamSamName <- c(hypParamSamName, other_hyps)
721+
hypParamSamName <- c(hypParamSamName, HYPS_OTHERS)
721722
# Add penalty factor hyper-parameters names
722723
if (add_penalty_factor) {
723724
for_penalty <- names(select(InputCollect$dt_mod, -.data$ds, -.data$dep_var))
@@ -774,8 +775,7 @@ check_class <- function(x, object) {
774775
}
775776

776777
check_allocator <- function(OutputCollect, select_model, paid_media_spends, scenario,
777-
channel_constr_low, channel_constr_up,
778-
expected_spend, expected_spend_days, constr_mode) {
778+
channel_constr_low, channel_constr_up, constr_mode) {
779779
dt_hyppar <- OutputCollect$resultHypParam[OutputCollect$resultHypParam$solID == select_model, ]
780780
if (!(select_model %in% OutputCollect$allSolutions)) {
781781
stop(
@@ -792,11 +792,10 @@ check_allocator <- function(OutputCollect, select_model, paid_media_spends, scen
792792
if (any(channel_constr_up > 5)) {
793793
warning("Inputs 'channel_constr_up' > 5 might cause unrealistic allocation")
794794
}
795-
opts <- c("max_historical_response", "max_response_expected_spend")
795+
opts <- "max_historical_response" # Deprecated: max_response_expected_spend
796796
if (!(scenario %in% opts)) {
797797
stop("Input 'scenario' must be one of: ", paste(opts, collapse = ", "))
798798
}
799-
800799
if (length(channel_constr_low) != 1 && length(channel_constr_low) != length(paid_media_spends)) {
801800
stop(paste(
802801
"Input 'channel_constr_low' have to contain either only 1",
@@ -809,35 +808,144 @@ check_allocator <- function(OutputCollect, select_model, paid_media_spends, scen
809808
"value or have same length as 'InputCollect$paid_media_spends':", length(paid_media_spends)
810809
))
811810
}
812-
813-
if ("max_response_expected_spend" %in% scenario) {
814-
if (any(is.null(expected_spend), is.null(expected_spend_days))) {
815-
stop("When scenario = 'max_response_expected_spend', expected_spend and expected_spend_days must be provided")
816-
}
817-
}
818811
opts <- c("eq", "ineq")
819812
if (!(constr_mode %in% opts)) {
820813
stop("Input 'constr_mode' must be one of: ", paste(opts, collapse = ", "))
821814
}
822815
}
823816

824-
check_metric_value <- function(metric_value, media_metric) {
817+
# Classify a single media variable as "spend", "exposure" or "organic".
#
# Args:
#   metric_name: name of the media variable to classify; must be length 1.
#   paid_media_spends: names that classify as "spend" (checked first).
#   paid_media_vars: only used to build the error message.
#   exposure_vars: names that classify as "exposure" (checked second).
#   organic_vars: names that classify as "organic" (checked last).
#
# Returns: one of "spend", "exposure", "organic".
# Raises: an error when 'metric_name' is not a single value found in any of
#   paid_media_spends, exposure_vars or organic_vars.
check_metric_type <- function(metric_name, paid_media_spends, paid_media_vars, exposure_vars, organic_vars) {
  # The length guard must run BEFORE `%in%`: `%in%` is vectorised, and a
  # length-0 or length > 1 left-hand side of `&&` is an error since R 4.3,
  # which would hide the informative message below.
  is_single <- length(metric_name) == 1
  if (is_single && metric_name %in% paid_media_spends) {
    metric_type <- "spend"
  } else if (is_single && metric_name %in% exposure_vars) {
    metric_type <- "exposure"
  } else if (is_single && metric_name %in% organic_vars) {
    metric_type <- "organic"
  } else {
    stop(paste(
      "Invalid 'metric_name' input. It must be any media variable from",
      "paid_media_spends (spend), paid_media_vars (exposure),",
      # Collapse so a vector input yields one message, not one per element
      "or organic_vars (organic); NOT:", paste(metric_name, collapse = ", "),
      paste("\n- paid_media_spends:", v2t(paid_media_spends, quotes = FALSE)),
      paste("\n- paid_media_vars:", v2t(paid_media_vars, quotes = FALSE)),
      paste("\n- organic_vars:", v2t(organic_vars, quotes = FALSE))
    ))
  }
  return(metric_type)
}
836+
837+
# Resolve a 'date_range' specification into concrete dates and their
# positions within 'all_dates'.
#
# Args:
#   date_range: NULL (auto-pick), "all", "last" / "last_n", a single date,
#     two dates (treated as from-to), or 3+ dates (must map to consecutive
#     positions of 'all_dates'). Dates may be Date or anything as.Date()
#     understands.
#   all_dates: vector of available dates to match against.
#   dayInterval: number of days between observations; required only when
#     'date_range' is NULL.
#   quiet: when TRUE, suppress informative messages and warnings.
#   is_allocator: when TRUE and 'date_range' is NULL, default to roughly the
#     last 30 days of data instead of only the last period.
#   ...: ignored.
#
# Returns: list(date_range_updated = <matched dates>,
#               metric_loc = <integer positions in all_dates>).
check_metric_dates <- function(date_range = NULL, all_dates, dayInterval = NULL, quiet = FALSE, is_allocator = FALSE, ...) {
  format_error <- "Input 'date_range' must have date format '2023-01-01' or use 'last_n'"
  # Position of the date in 'all_dates' closest to a single date 'd'
  closest_loc <- function(d) which.min(abs(as.numeric(d - all_dates)))

  ## default using latest 30 days / 4 weeks / 1 month for spend level
  if (is.null(date_range)) {
    if (is.null(dayInterval)) stop("Input 'date_range' or 'dayInterval' must be defined")
    if (!is_allocator) {
      date_range <- "last_1"
    } else {
      if (!is.numeric(dayInterval) || length(dayInterval) != 1 ||
        !is.finite(dayInterval) || dayInterval <= 0) {
        stop("Input 'dayInterval' must be a single positive number")
      }
      # ~30 days of data: 1 -> 30, 7 -> 4, 30/31 -> 1. Also covers other
      # intervals (e.g. 14 -> 2, 28 -> 1) that used to yield "last_NA".
      date_range <- paste0("last_", max(1, round(30 / dayInterval)))
    }
    if (!quiet) message(sprintf("Automatically picked date_range = '%s'", date_range))
  }

  if (grepl("last|all", date_range[1])) {
    ## Using last_n as date_range range
    if ("all" %in% date_range) date_range <- paste0("last_", length(all_dates))
    get_n <- 1L
    if (grepl("_", date_range[1])) {
      # A non-numeric suffix coerces to NA; reported right below
      get_n <- suppressWarnings(as.integer(gsub("last_", "", date_range[1])))
    }
    if (is.na(get_n) || get_n < 1) stop(format_error)
    date_range <- tail(all_dates, get_n)
    date_range_loc <- which(all_dates %in% date_range)
    date_range_updated <- all_dates[date_range_loc]
  } else {
    ## Using dates as date_range range
    # as.Date() throws on unparseable input; catch it so the user gets the
    # intended format hint instead of a low-level conversion error.
    date_range <- tryCatch(
      as.Date(date_range, origin = "1970-01-01"),
      error = function(e) NULL
    )
    if (is.null(date_range) || anyNA(date_range)) stop(format_error)

    if (length(date_range) == 1) {
      ## Using only 1 date
      if (all(date_range %in% all_dates)) {
        date_range_updated <- date_range
        date_range_loc <- which(all_dates == date_range)
        if (!quiet) message("Using ds '", date_range_updated, "' as the response period")
      } else {
        date_range_loc <- closest_loc(date_range)
        date_range_updated <- all_dates[date_range_loc]
        if (!quiet) warning("Input 'date_range' (", date_range, ") has no match. Picking closest date: ", date_range_updated)
      }
    } else if (length(date_range) == 2) {
      ## Using two dates as "from-to" date range
      date_range_loc <- unlist(lapply(as.list(date_range), closest_loc))
      date_range_loc <- date_range_loc[1]:date_range_loc[2]
      date_range_updated <- all_dates[date_range_loc]
      if (!quiet && !all(date_range %in% date_range_updated)) {
        warning(paste(
          "At least one date in 'date_range' input do not match any date.",
          "Picking closest dates for range:", paste(range(date_range_updated), collapse = ":")
        ))
      }
    } else {
      ## Manually inputting each date
      all_matched <- all(date_range %in% all_dates)
      if (all_matched) {
        date_range_loc <- which(all_dates %in% date_range)
      } else {
        date_range_loc <- unlist(lapply(as.list(date_range), closest_loc))
      }
      # diff() is used (not lag()) so the check does not depend on whether
      # dplyr::lag or stats::lag is the one in scope.
      if (!all(diff(date_range_loc) == 1)) {
        stop("Input 'date_range' needs to have sequential dates")
      }
      date_range_updated <- all_dates[date_range_loc]
      # Only warn when a date actually had to be replaced by its neighbour
      if (!quiet && !all_matched) {
        warning(
          "At least one date in 'date_range' do not match ds. Picking closest date: ",
          paste(date_range_updated, collapse = ", ")
        )
      }
    }
  }
  return(list(
    date_range_updated = date_range_updated,
    metric_loc = date_range_loc
  ))
}
913+
914+
# Validate 'metric_value' and write it into the selected positions of
# 'all_values'.
#
# Args:
#   metric_value: NULL/NaN (keep the values already in 'all_values'), a single
#     non-negative number (split evenly across the selected periods), or one
#     non-negative number per selected period.
#   metric_name: name of the metric; only used in error messages.
#   all_values: full vector of values for the metric.
#   metric_loc: positions in 'all_values' to overwrite.
#
# Returns: list(metric_value_updated = <values used for metric_loc>,
#               all_values_updated = <all_values with metric_loc replaced>).
check_metric_value <- function(metric_value, metric_name, all_values, metric_loc) {
  get_n <- length(metric_loc)
  # NaN is treated the same as "not provided"
  if (any(is.nan(metric_value))) metric_value <- NULL
  if (is.null(metric_value)) {
    metric_value_updated <- all_values[metric_loc]
  } else {
    if (!is.numeric(metric_value)) {
      stop(sprintf(
        "Input 'metric_value' for %s (%s) must be a numerical value\n", metric_name, toString(metric_value)
      ))
    }
    # Without this guard an NA reaches `if (any(NA < 0))` and fails with
    # "missing value where TRUE/FALSE needed".
    if (anyNA(metric_value)) {
      stop(sprintf(
        "Input 'metric_value' for %s must not contain missing values\n", metric_name
      ))
    }
    if (any(metric_value < 0)) {
      stop(sprintf(
        "Input 'metric_value' for %s must be positive\n", metric_name
      ))
    }
    if (get_n > 1 && length(metric_value) == 1) {
      # A single total is spread evenly over the selected periods
      metric_value_updated <- rep(metric_value / get_n, get_n)
    } else {
      if (length(metric_value) != get_n) {
        stop("robyn_response metric_value & date_range must have same length\n")
      }
      metric_value_updated <- metric_value
    }
  }
  all_values_updated <- all_values
  all_values_updated[metric_loc] <- metric_value_updated
  return(list(
    metric_value_updated = metric_value_updated,
    all_values_updated = all_values_updated
  ))
}
838948

839-
LEGACY_PARAMS <- c("cores", "iterations", "trials", "intercept_sign", "nevergrad_algo")
840-
841949
check_legacy_input <- function(InputCollect,
842950
cores = NULL, iterations = NULL, trials = NULL,
843951
intercept_sign = NULL, nevergrad_algo = NULL) {

R/R/exports.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ robyn_save <- function(InputCollect,
3535
)
3636

3737
# Nice and tidy table format for hyper-parameters
38-
regex <- paste(paste0("_", hyps_name), collapse = "|")
38+
regex <- paste(paste0("_", HYPS_NAMES), collapse = "|")
3939
hyps <- filter(OutputCollect$resultHypParam, .data$solID == select_model) %>%
40-
select(contains(hyps_name)) %>%
40+
select(contains(HYPS_NAMES)) %>%
4141
tidyr::gather() %>%
4242
tidyr::separate(.data$key,
4343
into = c("channel", "none"),

R/R/imports.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
#' @importFrom ggridges geom_density_ridges geom_density_ridges_gradient
2929
#' @importFrom glmnet cv.glmnet glmnet
3030
#' @importFrom jsonlite fromJSON toJSON write_json read_json
31-
#' @importFrom lares check_opts clusterKmeans formatNum freqs glued ohse removenacols
31+
#' @importFrom lares check_opts clusterKmeans formatNum freqs glued num_abbr ohse removenacols
3232
#' theme_lares `%>%` scale_x_abbr scale_x_percent scale_y_percent scale_y_abbr try_require v2t
3333
#' @importFrom lubridate is.Date day floor_date
3434
#' @importFrom minpack.lm nlsLM
@@ -38,7 +38,6 @@
3838
#' @importFrom prophet add_regressor add_seasonality fit.prophet prophet
3939
#' @importFrom reticulate tuple use_condaenv import conda_create conda_install py_module_available
4040
#' virtualenv_create py_install use_virtualenv
41-
#' @importFrom rPref low psel
4241
#' @importFrom stats AIC BIC coef complete.cases dgamma dnorm end lm model.matrix na.omit
4342
#' nls.control median qt sd predict pweibull dweibull quantile qunif reorder rnorm start setNames
4443
#' @importFrom stringr str_count str_detect str_remove str_split str_which str_extract str_replace

R/R/inputs.R

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -517,12 +517,12 @@ Adstock: {x$adstock}
517517
hyper_names <- function(adstock, all_media) {
518518
adstock <- check_adstock(adstock)
519519
if (adstock == "geometric") {
520-
local_name <- sort(apply(expand.grid(all_media, hyps_name[
521-
grepl("thetas|alphas|gammas", hyps_name)
520+
local_name <- sort(apply(expand.grid(all_media, HYPS_NAMES[
521+
grepl("thetas|alphas|gammas", HYPS_NAMES)
522522
]), 1, paste, collapse = "_"))
523523
} else if (adstock %in% c("weibull_cdf", "weibull_pdf")) {
524-
local_name <- sort(apply(expand.grid(all_media, hyps_name[
525-
grepl("shapes|scales|alphas|gammas", hyps_name)
524+
local_name <- sort(apply(expand.grid(all_media, HYPS_NAMES[
525+
grepl("shapes|scales|alphas|gammas", HYPS_NAMES)
526526
]), 1, paste, collapse = "_"))
527527
}
528528
return(local_name)
@@ -831,7 +831,7 @@ prophet_decomp <- function(dt_transform, dt_holidays,
831831
)
832832
}
833833
mod <- fit.prophet(modelRecurrence, dt_regressors)
834-
forecastRecurrence <- predict(mod, dt_regressors)
834+
forecastRecurrence <- predict(mod, dt_regressors) # prophet::prophet_plot_components(modelRecurrence, forecastRecurrence)
835835
}
836836

837837
these <- seq_along(unlist(recurrence[, 1]))

0 commit comments

Comments
 (0)