Skip to content

Commit efb45d4

Browse files
laresbernardo and Gufeng Zhou authored
feat: time series validation - v3.9.0 (#545)
- **Feat**: new time series validation via dynamic time-series train/val/test splits, with Adjusted R2 and NRMSE metrics reported for each group feature. We are adding an additional `train_size` hyperparameter to set the size of the training set, which by default iterates over the range 0.5-0.8. Since it is a hyperparameter, you can change the range or fix the value manually. Turn this feature on or off using the new `ts_validation` parameter on `robyn_run()`; the default is `FALSE` for now. This is an important step towards the upcoming forecasting function.
- **Feat**: new `ts_validation()` function to plot time-series validation and convergence results. The plot is generated and exported by default when `ts_validation = TRUE` and `export = TRUE`, creating the `ts_validation_plot.png` file.
- **Fix**: updated Adjusted R2 calculation for time-series validation to use the same denominator.
- **Fix**: results are no longer sorted by lowest error, so that iteration results keep their actual order.
- **Feat**: added prophet monthly component to enrich decomposition results.
- **Fix**: correct solID (not 1_1_1) for models recreated from fixed hyperparameters.
- **Recode**: reduced the size of `xDecompVec` on `OutputCollect` by keeping pareto-front models only.
- **Docs**: adapted standard inputs for window start and end to include more data (3 years by default).

Co-authored-by: Gufeng Zhou <gufengzhou@fb.com>
1 parent 3298ff8 commit efb45d4

21 files changed

+539
-211
lines changed

R/DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
Package: Robyn
22
Type: Package
33
Title: Semi-Automated Marketing Mix Modeling (MMM) from Meta Marketing Science
4-
Version: 3.8.2
4+
Version: 3.9.0
55
Authors@R: c(
66
person("Gufeng", "Zhou", , "gufeng@meta.com", c("aut")),
77
person("Leonel", "Sentana", , "leonelsentana@meta.com", c("aut")),

R/NAMESPACE

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ export(robyn_train)
3838
export(robyn_update)
3939
export(robyn_write)
4040
export(saturation_hill)
41+
export(ts_validation)
4142
import(ggplot2)
4243
importFrom(doParallel,registerDoParallel)
4344
importFrom(doParallel,stopImplicitCluster)
@@ -108,6 +109,7 @@ importFrom(patchwork,plot_annotation)
108109
importFrom(patchwork,plot_layout)
109110
importFrom(patchwork,wrap_plots)
110111
importFrom(prophet,add_regressor)
112+
importFrom(prophet,add_seasonality)
111113
importFrom(prophet,fit.prophet)
112114
importFrom(prophet,prophet)
113115
importFrom(rPref,low)

R/R/auxiliary.R

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,22 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# Calculate R-squared
7-
get_rsq <- function(true, predicted, p = NULL, df.int = NULL) {
7+
get_rsq <- function(true, predicted, p = NULL, df.int = NULL, n_train = NULL) {
88
sse <- sum((predicted - true)^2)
99
sst <- sum((true - mean(true))^2)
10-
rsq <- 1 - sse / sst
10+
rsq <- 1 - sse / sst # rsq interpreted as variance explained
11+
rsq_out <- rsq
1112
if (!is.null(p) && !is.null(df.int)) {
12-
n <- length(true)
13+
if (!is.null(n_train)) {
14+
n <- n_train # for oos dataset, use n from train set for adj. rsq
15+
} else {
16+
n <- length(true)
17+
}
1318
rdf <- n - p - 1
14-
rsq <- 1 - (1 - rsq) * ((n - df.int) / rdf)
19+
rsq_adj <- 1 - (1 - rsq) * ((n - df.int) / rdf)
20+
rsq_out <- rsq_adj
1521
}
16-
return(rsq)
22+
return(rsq_out)
1723
}
1824

1925
# Robyn colors

R/R/checks.R

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
############# Auxiliary non-exported functions #############
77

88
opts_pnd <- c("positive", "negative", "default")
9+
other_hyps <- c("lambda", "train_size")
10+
hyps_name <- c("thetas", "shapes", "scales", "alphas", "gammas")
911

1012
check_nas <- function(df) {
1113
name <- deparse(substitute(df))
@@ -151,7 +153,7 @@ check_prophet <- function(dt_holidays, prophet_country, prophet_vars, prophet_si
151153
if (is.null(dt_holidays) || is.null(prophet_vars)) {
152154
return(invisible(NULL))
153155
} else {
154-
opts <- c("trend", "season", "weekday", "holiday")
156+
opts <- c("trend", "season", "monthly", "weekday", "holiday")
155157
if (!all(prophet_vars %in% opts)) {
156158
stop("Allowed values for 'prophet_vars' are: ", paste(opts, collapse = ", "))
157159
}
@@ -426,7 +428,10 @@ check_hyperparameters <- function(hyperparameters = NULL, adstock = NULL,
426428
"robyn_inputs(InputCollect = InputCollect, hyperparameters = ...)"
427429
))
428430
} else {
429-
hyperparameters <- hyperparameters[which(!names(hyperparameters) %in% "lambda")]
431+
# Non-adstock hyperparameters check
432+
check_train_size(hyperparameters)
433+
# Adstock hyperparameters check
434+
hyperparameters <- hyperparameters[which(!names(hyperparameters) %in% other_hyps)]
430435
hyperparameters_ordered <- hyperparameters[order(names(hyperparameters))]
431436
get_hyp_names <- names(hyperparameters_ordered)
432437
ref_hyp_name_spend <- hyper_names(adstock, all_media = paid_media_spends)
@@ -470,6 +475,17 @@ check_hyperparameters <- function(hyperparameters = NULL, adstock = NULL,
470475
}
471476
}
472477

478+
check_train_size <- function(hyps) {
479+
if ("train_size" %in% names(hyps)) {
480+
if (!length(hyps$train_size) %in% 1:2) {
481+
stop("Hyperparameter 'train_size' must be length 1 (fixed) or 2 (range)")
482+
}
483+
if (any(hyps$train_size <= 0.1) || any(hyps$train_size > 1)) {
484+
stop("Hyperparameter 'train_size' values must be defined between 0.1 and 1")
485+
}
486+
}
487+
}
488+
473489
check_hyper_limits <- function(hyperparameters, hyper) {
474490
hyper_which <- which(endsWith(names(hyperparameters), hyper))
475491
if (length(hyper_which) == 0) {
@@ -692,8 +708,8 @@ check_hyper_fixed <- function(InputCollect, dt_hyper_fixed, add_penalty_factor)
692708
hyper_fixed <- !is.null(dt_hyper_fixed)
693709
# Adstock hyper-parameters
694710
hypParamSamName <- hyper_names(adstock = InputCollect$adstock, all_media = InputCollect$all_media)
695-
# Add lambda hyper-parameter
696-
hypParamSamName <- c(hypParamSamName, "lambda")
711+
# Add lambda and other hyper-parameters manually
712+
hypParamSamName <- c(hypParamSamName, other_hyps)
697713
# Add penalty factor hyper-parameters names
698714
if (add_penalty_factor) {
699715
for_penalty <- names(select(InputCollect$dt_mod, -.data$ds, -.data$dep_var))
@@ -734,10 +750,14 @@ check_init_msg <- function(InputCollect, cores) {
734750
"Using", InputCollect$adstock, "adstocking with",
735751
length(InputCollect$hyper_updated), "hyperparameters", det
736752
)
737-
if (check_parallel()) {
738-
message(paste(base, "on", cores, "cores"))
753+
if (cores == 1) {
754+
message(paste(base, "with no parallel computation"))
739755
} else {
740-
message(paste(base, "on 1 core (Windows fallback)"))
756+
if (check_parallel()) {
757+
message(paste(base, "on", cores, "cores"))
758+
} else {
759+
message(paste(base, "on 1 core (Windows fallback)"))
760+
}
741761
}
742762
}
743763

R/R/clusters.R

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -186,14 +186,6 @@ confidence_calcs <- function(xDecompAgg, cls, all_paid, dep_var_type, k, boot_n
186186
ci_low <- ifelse(boot_res$ci[1] < 0, 0, boot_res$ci[1])
187187
ci_up <- boot_res$ci[2]
188188

189-
## Experiment with gamma distribution fitting
190-
# mod_gamma <- nloptr(x0 = c(1, 1), eval_f = gamma_mle, lb = c(0,0),
191-
# x = unlist(df_chn$roi_total),
192-
# opts = list(algorithm = "NLOPT_LN_SBPLX", maxeval = 1e5))
193-
# gamma_params <- mod_gamma$solution
194-
# g_low = qgamma(0.025, shape=gamma_params[[1]], scale= gamma_params[[2]])
195-
# g_up = qgamma(0.975, shape=gamma_params[[1]], scale= gamma_params[[2]])
196-
197189
# Collect loop results
198190
chn_collect[[i]] <- df_chn %>%
199191
mutate(
@@ -283,13 +275,6 @@ errors_scores <- function(df, balance = rep(1, 3)) {
283275
return(scores)
284276
}
285277

286-
# gamma_mle <- function(params, x) {
287-
# gamma_shape <- params[[1]]
288-
# gamma_scale <- params[[2]]
289-
# # Negative log-likelihood
290-
# return(-sum(dgamma(x, shape = gamma_shape, scale = gamma_scale, log = TRUE)))
291-
# }
292-
293278
# ROIs data.frame for clustering (from xDecompAgg or pareto_aggregated.csv)
294279
.prepare_df <- function(x, all_media, dep_var_type) {
295280
check_opts(all_media, unique(x$rn))

R/R/convergence.R

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,8 @@ robyn_converge <- function(OutputModels, n_cuts = 20, sd_qtref = 3, med_lowb = 2
4242
stopifnot(n_cuts > min(c(sd_qtref, med_lowb)) + 1)
4343

4444
# Gather all trials
45-
get_lists <- as.logical(grepl("trial", names(OutputModels)) * unlist(lapply(OutputModels, is.list)))
46-
OutModels <- OutputModels[get_lists]
47-
for (i in seq_along(OutModels)) {
48-
if (i == 1) df <- data.frame()
49-
temp <- OutModels[[i]]$resultCollect$resultHypParam %>% mutate(trial = i)
50-
df <- rbind(df, temp)
51-
}
45+
get_trials <- which(names(OutputModels) %in% paste0("trial", seq(OutputModels$trials)))
46+
df <- bind_rows(lapply(OutputModels[get_trials], function(x) x$resultCollect$resultHypParam))
5247
calibrated <- isTRUE(sum(df$mape) > 0)
5348

5449
# Calculate deciles
@@ -199,3 +194,37 @@ robyn_converge <- function(OutputModels, n_cuts = 20, sd_qtref = 3, med_lowb = 2
199194

200195
return(invisible(cvg_out))
201196
}
197+
198+
test_cvg <- function() {
199+
# Experiment with gamma distribution fitting
200+
gamma_mle <- function(params, x) {
201+
gamma_shape <- params[[1]]
202+
gamma_scale <- params[[2]]
203+
# Negative log-likelihood
204+
return(-sum(dgamma(x, shape = gamma_shape, scale = gamma_scale, log = TRUE)))
205+
}
206+
f_geo <- function(a, r, n) {
207+
for (i in 2:n) a[i] <- a[i - 1] * r
208+
return(a)
209+
}
210+
seq_nrmse <- f_geo(5, 0.7, 100)
211+
df_nrmse <- data.frame(x = 1:100, y = seq_nrmse, type = "true")
212+
mod_gamma <- nloptr(
213+
x0 = c(1, 1), eval_f = gamma_mle, lb = c(0, 0),
214+
x = seq_nrmse,
215+
opts = list(algorithm = "NLOPT_LN_SBPLX", maxeval = 1e5)
216+
)
217+
gamma_params <- mod_gamma$solution
218+
seq_nrmse_gam <- 1 / dgamma(seq_nrmse, shape = gamma_params[[1]], scale = gamma_params[[2]])
219+
seq_nrmse_gam <- seq_nrmse_gam / (max(seq_nrmse_gam) - min(seq_nrmse_gam))
220+
seq_nrmse_gam <- max(seq_nrmse) * seq_nrmse_gam
221+
range(seq_nrmse_gam)
222+
range(seq_nrmse)
223+
df_nrmse_gam <- data.frame(x = 1:100, y = seq_nrmse_gam, type = "pred")
224+
df_nrmse <- bind_rows(df_nrmse, df_nrmse_gam)
225+
p <- ggplot(df_nrmse, aes(.data$x, .data$y, color = .data$type)) +
226+
geom_line()
227+
return(p)
228+
# g_low = qgamma(0.025, shape=gamma_params[[1]], scale= gamma_params[[2]])
229+
# g_up = qgamma(0.975, shape=gamma_params[[1]], scale= gamma_params[[2]])
230+
}

R/R/exports.R

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ robyn_save <- function(InputCollect,
3535
)
3636

3737
# Nice and tidy table format for hyper-parameters
38-
hyps_name <- c("thetas", "shapes", "scales", "alphas", "gammas")
3938
regex <- paste(paste0("_", hyps_name), collapse = "|")
4039
hyps <- filter(OutputCollect$resultHypParam, .data$solID == select_model) %>%
4140
select(contains(hyps_name)) %>%
@@ -129,7 +128,13 @@ print.robyn_save <- function(x, ...) {
129128
print(glued(
130129
"\n\nModel's Performance and Errors:\n {errors}",
131130
errors = paste(
132-
"R2 (train):", signif(x$errors$rsq_train, 4),
131+
sprintf(
132+
"R2 (%s): %s)",
133+
ifelse(!isTRUE(x$ExportedModel$ts_validation), "train", "test"),
134+
ifelse(!isTRUE(x$ExportedModel$ts_validation),
135+
signif(x$errors$rsq_train, 4), signif(x$errors$rsq_test, 4)
136+
)
137+
),
133138
"| NRMSE =", signif(x$errors$nrmse, 4),
134139
"| DECOMP.RSSD =", signif(x$errors$decomp.rssd, 4),
135140
"| MAPE =", signif(x$errors$mape, 4)
@@ -144,7 +149,7 @@ print.robyn_save <- function(x, ...) {
144149
replace(., . == "NA", "-") %>% as.data.frame())
145150

146151
print(glued(
147-
"\n\nHyper-parameters for channel transformations:\n Adstock: {x$adstock}"
152+
"\n\nHyper-parameters:\n Adstock: {x$adstock}"
148153
))
149154

150155
print(as.data.frame(x$hyper_df))

R/R/imports.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
#' @importFrom nloptr nloptr
3636
#' @importFrom parallel detectCores
3737
#' @importFrom patchwork guide_area plot_layout plot_annotation wrap_plots
38-
#' @importFrom prophet add_regressor fit.prophet prophet
38+
#' @importFrom prophet add_regressor add_seasonality fit.prophet prophet
3939
#' @importFrom reticulate tuple use_condaenv import conda_create conda_install py_module_available
4040
#' virtualenv_create py_install use_virtualenv
4141
#' @importFrom rPref low psel

R/R/inputs.R

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,7 @@ Adstock: {x$adstock}
418418
},
419419
hyps = if (!is.null(x$hyperparameters)) {
420420
glued(
421-
"Hyper-parameters for channel transformations:\n{flatten_hyps(x$hyperparameters)}"
421+
"Hyper-parameters ranges:\n{flatten_hyps(x$hyperparameters)}"
422422
)
423423
} else {
424424
paste("Hyper-parameters:", "\033[0;31mNot set yet\033[0m")
@@ -514,11 +514,14 @@ Adstock: {x$adstock}
514514
#' @export
515515
hyper_names <- function(adstock, all_media) {
516516
adstock <- check_adstock(adstock)
517-
global_name <- c("thetas", "shapes", "scales", "alphas", "gammas", "lambdas")
518517
if (adstock == "geometric") {
519-
local_name <- sort(apply(expand.grid(all_media, global_name[grepl("thetas|alphas|gammas", global_name)]), 1, paste, collapse = "_"))
518+
local_name <- sort(apply(expand.grid(all_media, hyps_name[
519+
grepl("thetas|alphas|gammas", hyps_name)
520+
]), 1, paste, collapse = "_"))
520521
} else if (adstock %in% c("weibull_cdf", "weibull_pdf")) {
521-
local_name <- sort(apply(expand.grid(all_media, global_name[grepl("shapes|scales|alphas|gammas", global_name)]), 1, paste, collapse = "_"))
522+
local_name <- sort(apply(expand.grid(all_media, hyps_name[
523+
grepl("shapes|scales|alphas|gammas", hyps_name)
524+
]), 1, paste, collapse = "_"))
522525
}
523526
return(local_name)
524527
}
@@ -769,6 +772,7 @@ prophet_decomp <- function(dt_transform, dt_holidays,
769772
use_trend <- "trend" %in% prophet_vars
770773
use_holiday <- "holiday" %in% prophet_vars
771774
use_season <- "season" %in% prophet_vars | "yearly.seasonality" %in% prophet_vars
775+
use_monthly <- "monthly" %in% prophet_vars
772776
use_weekday <- "weekday" %in% prophet_vars | "weekly.seasonality" %in% prophet_vars
773777

774778
dt_regressors <- bind_cols(recurrence, select(
@@ -791,6 +795,12 @@ prophet_decomp <- function(dt_transform, dt_holidays,
791795
custom_params$yearly.seasonality <- custom_params$weekly.seasonality <- NULL
792796
prophet_params <- append(prophet_params, custom_params)
793797
modelRecurrence <- do.call(prophet, as.list(prophet_params))
798+
if (use_monthly) {
799+
modelRecurrence <- add_seasonality(
800+
modelRecurrence,
801+
name = "monthly", period = 30.5, fourier.order = 5
802+
)
803+
}
794804

795805
# dt_regressors <<- dt_regressors
796806
# modelRecurrence <<- modelRecurrence
@@ -821,12 +831,13 @@ prophet_decomp <- function(dt_transform, dt_holidays,
821831
# dt_regressors <<- dt_regressors
822832
}
823833
mod <- fit.prophet(modelRecurrence, dt_regressors)
824-
forecastRecurrence <- predict(mod, dt_regressors)
834+
forecastRecurrence <- predict(mod, dt_regressors) # prophet::prophet_plot_components(modelRecurrence, forecastRecurrence)
825835
}
826836

827837
these <- seq_along(unlist(recurrence[, 1]))
828838
if (use_trend) dt_transform$trend <- forecastRecurrence$trend[these]
829839
if (use_season) dt_transform$season <- forecastRecurrence$yearly[these]
840+
if (use_monthly) dt_transform$monthly <- forecastRecurrence$monthly[these]
830841
if (use_weekday) dt_transform$weekday <- forecastRecurrence$weekly[these]
831842
if (use_holiday) dt_transform$holiday <- forecastRecurrence$holidays[these]
832843
return(dt_transform)

0 commit comments

Comments
 (0)