Skip to content

Commit 3bc8422

Browse files
fix: refresh hyps check + use data available in json + refresh hyps + upper constraints fix when higher than mean (#974)
* fix: refresh hyps check #960 + use data available in json
* fix: update based on gz's comments
* fix: fixed penalties and other fixed hyps on refreshing models
* fix: refresh plot when chain is broken + feat: new bounds_freedom parameter to overwrite default calculation
* fix: import and store original model when not in original plot_dir
* recode: applied styler::tidyverse_style() to clean code for CRAN
* fix: paid_media_total calc
* fix: print ExportedModel only when available
* fix: deal with negative trend
  - negative trend is not interpretable for MMM
  - force negative coef when trend is negative to get positive decomp
* fix: upper constraint issue on BA for target_efficiency and weibull adstock
  feat: instead of Inf, use channel_constr_up, which by default is 10 for target_efficiency
* fix: reverse wrong bounds update in refresh_hyps
  The refactoring of initBounds & listOutputPrev in refresh_hyps was wrong in 774c18d
* recode: apply styler::tidyverse_style()

---------

Co-authored-by: gufengzhou <gufengzhou@gmail.com>
1 parent a10572b commit 3bc8422

File tree

15 files changed

+245
-139
lines changed

15 files changed

+245
-139
lines changed

R/DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
Package: Robyn
22
Type: Package
33
Title: Semi-Automated Marketing Mix Modeling (MMM) from Meta Marketing Science
4-
Version: 3.10.7.9000
4+
Version: 3.10.7.9001
55
Authors@R: c(
66
person("Gufeng", "Zhou", , "gufeng@meta.com", c("cre","aut")),
77
person("Bernardo", "Lares", , "laresbernardo@gmail.com", c("aut")),

R/R/allocator.R

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ robyn_allocator <- function(robyn_object = NULL,
168168
if (is.null(channel_constr_up)) {
169169
channel_constr_up <- case_when(
170170
scenario == "max_response" ~ 2,
171-
scenario == "target_efficiency" ~ Inf
171+
scenario == "target_efficiency" ~ 10
172172
)
173173
}
174174
if (length(channel_constr_low) == 1) channel_constr_low <- rep(channel_constr_low, length(paid_media_spends))
@@ -271,8 +271,8 @@ robyn_allocator <- function(robyn_object = NULL,
271271
select_build = select_build,
272272
select_model = select_model,
273273
metric_name = mediaSpendSorted[i],
274-
#metric_value = initSpendUnit[i] * simulation_period[i],
275-
#date_range = date_range,
274+
# metric_value = initSpendUnit[i] * simulation_period[i],
275+
# date_range = date_range,
276276
dt_hyppar = OutputCollect$resultHypParam,
277277
dt_coef = OutputCollect$xDecompAgg,
278278
InputCollect = InputCollect,
@@ -478,14 +478,13 @@ robyn_allocator <- function(robyn_object = NULL,
478478

479479
if (scenario == "target_efficiency") {
480480
## bounded optimisation
481-
total_response <- sum(OutputCollect$xDecompAgg$xDecompAgg)
482481
nlsMod <- nloptr::nloptr(
483482
x0 = x0,
484483
eval_f = eval_f,
485484
eval_g_eq = if (constr_mode == "eq") eval_g_eq_effi else NULL,
486485
eval_g_ineq = if (constr_mode == "ineq") eval_g_eq_effi else NULL,
487486
lb = lb,
488-
ub = rep(total_response, length(ub)),
487+
ub = x0 * channel_constr_up[1], # Large enough, but not infinite (customizable)
489488
opts = list(
490489
"algorithm" = "NLOPT_LD_AUGLAG",
491490
"xtol_rel" = 1.0e-10,
@@ -501,7 +500,7 @@ robyn_allocator <- function(robyn_object = NULL,
501500
eval_g_eq = if (constr_mode == "eq") eval_g_eq_effi else NULL,
502501
eval_g_ineq = if (constr_mode == "ineq") eval_g_eq_effi else NULL,
503502
lb = lb,
504-
ub = rep(total_response, length(ub)),
503+
ub = x0 * channel_constr_up[1], # Large enough, but not infinite (customizable)
505504
opts = list(
506505
"algorithm" = "NLOPT_LD_AUGLAG",
507506
"xtol_rel" = 1.0e-10,

R/R/auxiliary.R

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,15 +77,20 @@ baseline_vars <- function(InputCollect, baseline_level) {
7777
stopifnot(length(baseline_level) == 1)
7878
stopifnot(baseline_level %in% 0:5)
7979
x <- ""
80-
if (baseline_level >= 1)
80+
if (baseline_level >= 1) {
8181
x <- c(x, "(Intercept)", "intercept")
82-
if (baseline_level >= 2)
82+
}
83+
if (baseline_level >= 2) {
8384
x <- c(x, "trend")
84-
if (baseline_level >= 3)
85+
}
86+
if (baseline_level >= 3) {
8587
x <- unique(c(x, InputCollect$prophet_vars))
86-
if (baseline_level >= 4)
88+
}
89+
if (baseline_level >= 4) {
8790
x <- c(x, InputCollect$context_vars)
88-
if (baseline_level >= 5)
91+
}
92+
if (baseline_level >= 5) {
8993
x <- c(x, InputCollect$organic_vars)
94+
}
9095
return(x)
9196
}

R/R/checks.R

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,8 @@ check_hyperparameters <- function(hyperparameters = NULL, adstock = NULL,
499499
# Adding penalty variations to the dictionary
500500
if (any(grepl("_penalty", paste0(get_hyp_names)))) {
501501
ref_hyp_name_penalties <- paste0(
502-
c(paid_media_spends, organic_vars, prophet_vars, contextual_vars), "_penalty")
502+
c(paid_media_spends, organic_vars, prophet_vars, contextual_vars), "_penalty"
503+
)
503504
all_ref_names <- c(all_ref_names, ref_hyp_name_penalties)
504505
} else {
505506
ref_hyp_name_penalties <- NULL
@@ -928,7 +929,7 @@ check_metric_dates <- function(date_range = NULL, all_dates, dayInterval = NULL,
928929
# dayInterval >= 30 & dayInterval <= 31 ~ 1,
929930
# ))
930931
# }
931-
date_range = "all"
932+
date_range <- "all"
932933
if (!quiet) message(sprintf("Automatically picked date_range = '%s'", date_range))
933934
}
934935
if (grepl("last|all", date_range[1])) {

R/R/clusters.R

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,9 @@ robyn_clusters <- function(input, dep_var_type,
122122
dim_red = dim_red, quiet = TRUE, seed = seed
123123
)
124124
)
125-
cls$df <- group_by(cls$df, .data$cluster) %>% mutate(n = n()) %>% ungroup()
125+
cls$df <- group_by(cls$df, .data$cluster) %>%
126+
mutate(n = n()) %>%
127+
ungroup()
126128

127129
# Select top models by minimum (weighted) distance to zero
128130
all_paid <- setdiff(names(cls$df), c(ignore, "cluster"))

R/R/inputs.R

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ robyn_inputs <- function(dt_input = NULL,
181181
json <- robyn_read(json_file, step = 1, ...)
182182
if (is.null(dt_input)) {
183183
if ("raw_data" %in% names(json[["Extras"]])) {
184-
dt_input <- json[["Extras"]]$raw_data
184+
dt_input <- as_tibble(json[["Extras"]]$raw_data)
185185
} else {
186186
stop("Must provide 'dt_input' input; 'dt_holidays' input optional")
187187
}
@@ -204,7 +204,8 @@ robyn_inputs <- function(dt_input = NULL,
204204
dt_input, dt_holidays,
205205
dep_var, date_var,
206206
context_vars, paid_media_spends,
207-
organic_vars)
207+
organic_vars
208+
)
208209

209210
## Check for NA and all negative values
210211
dt_input <- check_allneg(dt_input)
@@ -254,9 +255,7 @@ robyn_inputs <- function(dt_input = NULL,
254255

255256
## Check window_start & window_end (and transform parameters/data)
256257
windows <- check_windows(dt_input, date_var, all_media, window_start, window_end)
257-
258258
if (TRUE) {
259-
dt_input <- windows$dt_input
260259
window_start <- windows$window_start
261260
rollingWindowStartWhich <- windows$rollingWindowStartWhich
262261
refreshAddedStart <- windows$refreshAddedStart
@@ -283,9 +282,14 @@ robyn_inputs <- function(dt_input = NULL,
283282
check_novar(select(dt_input, -all_of(unused_vars)))
284283

285284
# Calculate total media spend used to model
286-
paid_media_total <- dt_input[
287-
rollingWindowEndWhich:rollingWindowLength, ] %>%
288-
select(paid_media_vars) %>% sum()
285+
paid_media_total <- dt_input %>%
286+
mutate(temp_date = dt_input[[date_var]]) %>%
287+
filter(
288+
.data$temp_date >= window_start,
289+
.data$temp_date <= window_end
290+
) %>%
291+
select(all_of(paid_media_spends)) %>%
292+
sum()
289293

290294
## Collect input
291295
InputCollect <- list(
@@ -320,7 +324,7 @@ robyn_inputs <- function(dt_input = NULL,
320324
window_end = window_end,
321325
rollingWindowEndWhich = rollingWindowEndWhich,
322326
rollingWindowLength = rollingWindowLength,
323-
totalObservations = nrow(dt_input),
327+
totalObservations = nrow(windows$dt_input),
324328
refreshAddedStart = refreshAddedStart,
325329
adstock = adstock,
326330
hyperparameters = hyperparameters,
@@ -411,7 +415,7 @@ print.robyn_inputs <- function(x, ...) {
411415
mod_vars <- paste(setdiff(names(x$dt_mod), c("ds", "dep_var")), collapse = ", ")
412416
print(glued(
413417
"
414-
Total Observations: {nrow(x$dt_input)} ({x$intervalType}s)
418+
Total Observations: {x$totalObservations} ({x$intervalType}s)
415419
Input Table Columns ({ncol(x$dt_input)}):
416420
Date: {x$date_var}
417421
Dependent: {x$dep_var} [{x$dep_var_type}]
@@ -434,8 +438,10 @@ Adstock: {x$adstock}
434438
windows = paste(x$window_start, x$window_end, sep = ":"),
435439
custom_params = if (length(x$custom_params) > 0) paste("\n", flatten_hyps(x$custom_params)) else "None",
436440
prophet = if (length(x$prophet_vars) > 0) {
437-
sprintf("%s on %s", paste(x$prophet_vars, collapse = ", "),
438-
ifelse(!is.null(x$prophet_country), x$prophet_country, "data"))
441+
sprintf(
442+
"%s on %s", paste(x$prophet_vars, collapse = ", "),
443+
ifelse(!is.null(x$prophet_country), x$prophet_country, "data")
444+
)
439445
} else {
440446
"\033[0;31mDeactivated\033[0m"
441447
},
@@ -503,6 +509,7 @@ Adstock: {x$adstock}
503509
#' Accepts "geometric", "weibull_cdf" or "weibull_pdf"
504510
#' @param all_media Character vector. Default to \code{InputCollect$all_media}.
505511
#' Includes \code{InputCollect$paid_media_spends} and \code{InputCollect$organic_vars}.
512+
#' @param all_vars Used to check the penalties inputs, especially for refreshing models.
506513
#' @examples
507514
#' \donttest{
508515
#' media <- c("facebook_S", "print_S", "tv_S")
@@ -540,7 +547,7 @@ Adstock: {x$adstock}
540547
#' }
541548
#' @return Character vector. Names of hyper-parameters that should be defined.
542549
#' @export
543-
hyper_names <- function(adstock, all_media) {
550+
hyper_names <- function(adstock, all_media, all_vars = NULL) {
544551
adstock <- check_adstock(adstock)
545552
if (adstock == "geometric") {
546553
local_name <- sort(apply(expand.grid(all_media, HYPS_NAMES[
@@ -551,6 +558,9 @@ hyper_names <- function(adstock, all_media) {
551558
grepl("shapes|scales|alphas|gammas", HYPS_NAMES)
552559
]), 1, paste, collapse = "_"))
553560
}
561+
if (!is.null(all_vars)) {
562+
local_name <- sort(c(local_name, paste0(all_vars, "_penalty")))
563+
}
554564
return(local_name)
555565
}
556566

R/R/json.R

Lines changed: 55 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ robyn_write <- function(InputCollect,
6161

6262
# ExportedModel JSON
6363
if (!is.null(OutputCollect)) {
64-
6564
# Modeling associated data
6665
collect <- list()
6766
collect$ts_validation <- OutputCollect$OutputModels$ts_validation
@@ -71,7 +70,8 @@ robyn_write <- function(InputCollect,
7170
collect$outputs_time <- sprintf("%s min", attr(OutputCollect, "runTime"))
7271
collect$total_time <- sprintf(
7372
"%s min", attr(OutputCollect, "runTime") +
74-
attr(OutputCollect$OutputModels, "runTime"))
73+
attr(OutputCollect$OutputModels, "runTime")
74+
)
7575
collect$total_iters <- OutputCollect$OutputModels$iterations *
7676
OutputCollect$OutputModels$trials
7777
collect$conv_msg <- gsub("\\:.*", "", OutputCollect$OutputModels$convergence$conv_msg)
@@ -94,11 +94,14 @@ robyn_write <- function(InputCollect,
9494
outputs$performance <- df %>%
9595
filter(.data$rn %in% InputCollect$paid_media_spends) %>%
9696
group_by(.data$solID) %>%
97-
summarise(metric = perf_metric,
98-
performance = ifelse(
99-
perf_metric == "ROAS",
100-
sum(.data$xDecompAgg) / sum(.data$total_spend),
101-
sum(.data$total_spend) / sum(.data$xDecompAgg)), .groups = "drop")
97+
summarise(
98+
metric = perf_metric,
99+
performance = ifelse(
100+
perf_metric == "ROAS",
101+
sum(.data$xDecompAgg) / sum(.data$total_spend),
102+
sum(.data$total_spend) / sum(.data$xDecompAgg)
103+
), .groups = "drop"
104+
)
102105
outputs$summary <- df %>%
103106
mutate(
104107
metric = perf_metric,
@@ -136,7 +139,7 @@ robyn_write <- function(InputCollect,
136139

137140
extras <- list(...)
138141
if (isTRUE(add_data) & !"raw_data" %in% names(extras)) {
139-
extras[["raw_data"]] <- InputCollect$dt_input
142+
extras[["raw_data"]] <- as_tibble(InputCollect$dt_input)
140143
}
141144
if (length(extras) > 0) {
142145
ret[["Extras"]] <- extras
@@ -153,7 +156,8 @@ robyn_write <- function(InputCollect,
153156
if (!all(c("solID", "cluster") %in% names(pareto_df))) {
154157
warning(paste(
155158
"Input 'pareto_df' is not a valid data.frame;",
156-
"must contain 'solID' and 'cluster' columns."))
159+
"must contain 'solID' and 'cluster' columns."
160+
))
157161
} else {
158162
all_c <- unique(pareto_df$cluster)
159163
pareto_df <- lapply(all_c, function(x) {
@@ -192,7 +196,8 @@ print.robyn_write <- function(x, ...) {
192196
"\n\nModel's Performance and Errors:\n {performance}{errors}",
193197
performance = ifelse("performance" %in% names(x$ExportedModel), sprintf(
194198
"Total Model %s = %s\n ",
195-
x$ExportedModel$performance$metric, signif(x$ExportedModel$performance$performance, 4)), ""),
199+
x$ExportedModel$performance$metric, signif(x$ExportedModel$performance$performance, 4)
200+
), ""),
196201
errors = paste(
197202
sprintf(
198203
"Adj.R2 (train): %s",
@@ -204,34 +209,36 @@ print.robyn_write <- function(x, ...) {
204209
)
205210
))
206211

207-
print(glued("\n\nSummary Values on Selected Model:"))
212+
if ("ExportedModel" %in% names(x)) {
213+
print(glued("\n\nSummary Values on Selected Model:"))
208214

209-
print(x$ExportedModel$summary %>%
210-
select(-contains("boot"), -contains("ci_")) %>%
211-
dplyr::rename_at("performance", list(~ ifelse(x$InputCollect$dep_var_type == "revenue", "ROAS", "CPA"))) %>%
212-
mutate(decompPer = formatNum(100 * .data$decompPer, pos = "%")) %>%
213-
dplyr::mutate_if(is.numeric, function(x) ifelse(!is.infinite(x), x, 0)) %>%
214-
dplyr::mutate_if(is.numeric, function(x) formatNum(x, 4, abbr = TRUE)) %>%
215-
replace(., . == "NA", "-") %>% as.data.frame())
215+
print(x$ExportedModel$summary %>%
216+
select(-contains("boot"), -contains("ci_")) %>%
217+
dplyr::rename_at("performance", list(~ ifelse(x$InputCollect$dep_var_type == "revenue", "ROAS", "CPA"))) %>%
218+
mutate(decompPer = formatNum(100 * .data$decompPer, pos = "%")) %>%
219+
dplyr::mutate_if(is.numeric, function(x) ifelse(!is.infinite(x), x, 0)) %>%
220+
dplyr::mutate_if(is.numeric, function(x) formatNum(x, 4, abbr = TRUE)) %>%
221+
replace(., . == "NA", "-") %>% as.data.frame())
216222

217-
print(glued(
218-
"\n\nHyper-parameters:\n Adstock: {x$InputCollect$adstock}"
219-
))
223+
print(glued(
224+
"\n\nHyper-parameters:\n Adstock: {x$InputCollect$adstock}"
225+
))
220226

221-
# Nice and tidy table format for hyper-parameters
222-
HYPS_NAMES <- c(HYPS_NAMES, "penalty")
223-
regex <- paste(paste0("_", HYPS_NAMES), collapse = "|")
224-
hyper_df <- as.data.frame(x$ExportedModel$hyper_values) %>%
225-
select(-contains("lambda"), -any_of(HYPS_OTHERS)) %>%
226-
tidyr::gather() %>%
227-
tidyr::separate(.data$key,
228-
into = c("channel", "none"),
229-
sep = regex, remove = FALSE
230-
) %>%
231-
mutate(hyperparameter = gsub("^.*_", "", .data$key)) %>%
232-
select(.data$channel, .data$hyperparameter, .data$value) %>%
233-
tidyr::spread(key = "hyperparameter", value = "value")
234-
print(hyper_df)
227+
# Nice and tidy table format for hyper-parameters
228+
HYPS_NAMES <- c(HYPS_NAMES, "penalty")
229+
regex <- paste(paste0("_", HYPS_NAMES), collapse = "|")
230+
hyper_df <- as.data.frame(x$ExportedModel$hyper_values) %>%
231+
select(-contains("lambda"), -any_of(HYPS_OTHERS)) %>%
232+
tidyr::gather() %>%
233+
tidyr::separate(.data$key,
234+
into = c("channel", "none"),
235+
sep = regex, remove = FALSE
236+
) %>%
237+
mutate(hyperparameter = gsub("^.*_", "", .data$key)) %>%
238+
select(.data$channel, .data$hyperparameter, .data$value) %>%
239+
tidyr::spread(key = "hyperparameter", value = "value")
240+
print(hyper_df)
241+
}
235242
}
236243

237244

@@ -342,7 +349,9 @@ robyn_recreate <- function(json_file, quiet = FALSE, ...) {
342349
quiet = quiet,
343350
...
344351
)
345-
} else OutputCollect <- NULL
352+
} else {
353+
OutputCollect <- NULL
354+
}
346355
} else {
347356
# Use case: skip feature engineering when InputCollect is provided
348357
InputCollect <- args[["InputCollect"]]
@@ -373,7 +382,8 @@ robyn_chain <- function(json_file) {
373382
temp <- list.files(plot_folder)
374383
mods <- unique(temp[
375384
(startsWith(temp, "RobynModel") | grepl("\\.json+$", temp)) &
376-
grepl("^[^_]*_[^_]*_[^_]*$", temp)])
385+
grepl("^[^_]*_[^_]*_[^_]*$", temp)
386+
])
377387
avlb <- gsub("RobynModel-|\\.json", "", mods)
378388
if (length(ids) == length(mods)) {
379389
chain <- rep_len(chain, length(mods))
@@ -394,7 +404,14 @@ robyn_chain <- function(json_file) {
394404
filename <- mods[avlb == ids[i]]
395405
json_new <- robyn_read(filename, quiet = TRUE)
396406
} else {
397-
message("Skipping chain. File can't be found: ", filename)
407+
last_try <- gsub(chain[1], "", filename)
408+
if (file.exists(last_try)) {
409+
json_new <- robyn_read(last_try, quiet = TRUE)
410+
message("Stored original model in new file: ", filename)
411+
jsonlite::write_json(json_new, filename, pretty = TRUE)
412+
} else {
413+
message("Skipping chain. File can't be found: ", filename)
414+
}
398415
}
399416
}
400417
}

0 commit comments

Comments
 (0)