Skip to content

Commit 19bf071

Browse files
Merge branch 'main' into allocation_range
2 parents be4fe1a + a7406bf commit 19bf071

File tree

13 files changed

+178
-42
lines changed

13 files changed

+178
-42
lines changed

R/NAMESPACE

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
# Generated by roxygen2: do not edit by hand
22

33
S3method(plot,robyn_allocator)
4+
S3method(plot,robyn_save)
45
S3method(print,robyn_allocator)
56
S3method(print,robyn_inputs)
67
S3method(print,robyn_models)
78
S3method(print,robyn_outputs)
89
S3method(print,robyn_refresh)
10+
S3method(print,robyn_save)
911
export(adstock_geometric)
1012
export(adstock_weibull)
1113
export(hyper_limits)

R/R/clusters.R

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ robyn_clusters <- function(input, all_media = NULL, k = "auto", limit = 1,
112112
if (export) {
113113
fwrite(output$data, file = paste0(path, "pareto_clusters.csv"))
114114
ggsave(paste0(path, "pareto_clusters_wss.png"), plot = output$wss, dpi = 500, width = 5, height = 4)
115-
ggsave(paste0(path, "pareto_clusters_corr.png"), plot = output$corrs, dpi = 500, width = 7, height = 5)
115+
# ggsave(paste0(path, "pareto_clusters_corr.png"), plot = output$corrs, dpi = 500, width = 7, height = 5)
116116
db <- wrap_plots(output$plot_models_rois, output$plot_models_errors)
117117
ggsave(paste0(path, "pareto_clusters_detail.png"), plot = db, dpi = 600, width = 12, height = 9)
118118
}
@@ -121,6 +121,28 @@ robyn_clusters <- function(input, all_media = NULL, k = "auto", limit = 1,
121121

122122
}
123123

124+
# # Mean Media ROI by Cluster
125+
# df %>%
126+
# mutate(cluster = sprintf("Cluster %s", cls$df$cluster)) %>%
127+
# select(-.data$mape, -.data$decomp.rssd, -.data$nrmse, -.data$solID) %>%
128+
# group_by(.data$cluster) %>%
129+
# summarize_all(list(mean)) %>%
130+
# tidyr::pivot_longer(-one_of("cluster"), names_to = "media", values_to = "meanROI") %>%
131+
# ggplot(aes(y = reorder(.data$media, .data$meanROI), x = .data$meanROI)) +
132+
# facet_grid(.data$cluster~.) +
133+
# geom_col() + theme_lares() +
134+
# labs(title = "Mean Media ROI by Cluster",
135+
# x = "(Un-normalized) mean ROI within clsuter", y = NULL)
136+
# df %>%
137+
# mutate(cluster = sprintf("Cluster %s", cls$df$cluster)) %>%
138+
# select(-.data$solID, -.data$mape, -.data$decomp.rssd, -.data$nrmse) %>%
139+
# tidyr::pivot_longer(-one_of("cluster"), names_to = "media", values_to = "roi") %>%
140+
# ggplot(aes(y = reorder(.data$media, .data$roi), x = .data$roi)) +
141+
# facet_grid(.data$cluster~.) +
142+
# geom_boxplot() + theme_lares() +
143+
# labs(title = "Media ROI by Cluster",
144+
# x = "(Un-normalized) ROI", y = NULL)
145+
124146
# ROIs data.frame for clustering (from xDecompAgg or pareto_aggregated.csv)
125147
.prepare_roi <- function(x, all_media) {
126148
check_opts(all_media, unique(x$rn))
@@ -171,7 +193,7 @@ robyn_clusters <- function(input, all_media = NULL, k = "auto", limit = 1,
171193
balance <- balance / sum(balance)
172194
left_join(df, select(top_sols, 1:3), "solID") %>%
173195
mutate(
174-
alpha = ifelse(is.na(.data$cluster), 0.5, 1),
196+
alpha = ifelse(is.na(.data$cluster), 0.6, 1),
175197
label = ifelse(!is.na(.data$cluster), sprintf(
176198
"[%s.%s]", .data$cluster, .data$rank
177199
), NA)

R/R/inputs.R

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ robyn_inputs <- function(dt_input = NULL,
287287
adstock = adstock,
288288
hyperparameters = hyperparameters,
289289
calibration_input = calibration_input,
290-
...
290+
custom_params = list(...)
291291
)
292292

293293
if (!is.null(hyperparameters)) {
@@ -323,7 +323,6 @@ robyn_inputs <- function(dt_input = NULL,
323323
output <- robyn_engineering(InputCollect, ...)
324324
}
325325
}
326-
output$custom_params <- list(...)
327326
class(output) <- c("robyn_inputs", class(output))
328327
return(output)
329328
}
@@ -626,10 +625,9 @@ robyn_engineering <- function(x, ...) {
626625
#### Obtain prophet trend, seasonality and change-points
627626

628627
if (!is.null(InputCollect$prophet_vars) && length(InputCollect$prophet_vars) > 0) {
629-
custom_params <- list(...) # custom_params <- list()
630628
if (length(InputCollect[["custom_params"]]) > 0) {
631629
custom_params <- InputCollect[["custom_params"]]
632-
}
630+
} else custom_params <- list(...) # custom_params <- list()
633631
robyn_args <- setdiff(
634632
unique(c(names(as.list(args(robyn_run))),
635633
names(as.list(args(robyn_outputs))),
@@ -638,7 +636,7 @@ robyn_engineering <- function(x, ...) {
638636
c("", "..."))
639637
prophet_custom_args <- setdiff(names(custom_params), robyn_args)
640638
if (length(prophet_custom_args)>0)
641-
message(paste("Using custom prophet parameters:", paste(names(prophet_custom_args), collapse = ", ")))
639+
message(paste("Using custom prophet parameters:", paste(prophet_custom_args, collapse = ", ")))
642640
dt_transform <- prophet_decomp(
643641
dt_transform,
644642
dt_holidays = InputCollect$dt_holidays,
@@ -865,7 +863,7 @@ set_holidays <- function(dt_transform, dt_holidays, intervalType) {
865863
if (intervalType == "week") {
866864
weekStartInput <- lubridate::wday(dt_transform$ds[1], week_start = 1)
867865
if (!weekStartInput %in% c(1, 7)) stop("Week start has to be Monday or Sunday")
868-
dt_holidays$dsWeekStart <- floor_date(dt_holidays$ds, unit = "week", week_start = 1)
866+
dt_holidays$dsWeekStart <- floor_date(dt_holidays$ds, unit = "week", week_start = weekStartInput)
869867
holidays <- dt_holidays[, .(ds = dsWeekStart, holiday, country, year)]
870868
holidays <- holidays[, lapply(.SD, paste0, collapse = "#"), by = c("ds", "country", "year"), .SDcols = "holiday"]
871869
}

R/R/model.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -512,11 +512,11 @@ robyn_mmm <- function(InputCollect,
512512
} else if (adstock == "weibull_cdf") {
513513
shape <- hypParamSam[paste0(all_media[v], "_shapes")]
514514
scale <- hypParamSam[paste0(all_media[v], "_scales")]
515-
x_list <- adstock_weibull(x = m, shape = shape, scale = scale, windlen = rollingWindowLength, type = "cdf")
515+
x_list <- adstock_weibull(x = m, shape = shape, scale = scale, type = "cdf")
516516
} else if (adstock == "weibull_pdf") {
517517
shape <- hypParamSam[paste0(all_media[v], "_shapes")]
518518
scale <- hypParamSam[paste0(all_media[v], "_scales")]
519-
x_list <- adstock_weibull(x = m, shape = shape, scale = scale, windlen = rollingWindowLength, type = "pdf")
519+
x_list <- adstock_weibull(x = m, shape = shape, scale = scale, type = "pdf")
520520
}
521521
m_adstocked <- x_list$x_decayed
522522
mediaAdstocked[[v]] <- m_adstocked
@@ -1075,11 +1075,11 @@ robyn_response <- function(robyn_object = NULL,
10751075
} else if (adstock == "weibull_cdf") {
10761076
shape <- dt_hyppar[solID == select_model, get(paste0(hpm_name, "_shapes"))]
10771077
scale <- dt_hyppar[solID == select_model, get(paste0(hpm_name, "_scales"))]
1078-
x_list <- adstock_weibull(x = media_vec, shape = shape, scale = scale, windlen = InputCollect$rollingWindowLength, type = "cdf")
1078+
x_list <- adstock_weibull(x = media_vec, shape = shape, scale = scale, type = "cdf")
10791079
} else if (adstock == "weibull_pdf") {
10801080
shape <- dt_hyppar[solID == select_model, get(paste0(hpm_name, "_shapes"))]
10811081
scale <- dt_hyppar[solID == select_model, get(paste0(hpm_name, "_scales"))]
1082-
x_list <- adstock_weibull(x = media_vec, shape = shape, scale = scale, windlen = InputCollect$rollingWindowLength, type = "pdf")
1082+
x_list <- adstock_weibull(x = media_vec, shape = shape, scale = scale, type = "pdf")
10831083
}
10841084
m_adstocked <- x_list$x_decayed
10851085

R/R/pareto.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,11 +189,11 @@ robyn_pareto <- function(InputCollect, OutputModels, pareto_fronts, calibration_
189189
} else if (InputCollect$adstock == "weibull_cdf") {
190190
shape <- hypParam[paste0(InputCollect$all_media[med], "_shapes")]
191191
scale <- hypParam[paste0(InputCollect$all_media[med], "_scales")]
192-
x_list <- adstock_weibull(x = m, shape = shape, scale = scale, windlen = InputCollect$rollingWindowLength, type = "cdf")
192+
x_list <- adstock_weibull(x = m, shape = shape, scale = scale, type = "cdf")
193193
} else if (InputCollect$adstock == "weibull_pdf") {
194194
shape <- hypParam[paste0(InputCollect$all_media[med], "_shapes")]
195195
scale <- hypParam[paste0(InputCollect$all_media[med], "_scales")]
196-
x_list <- adstock_weibull(x = m, shape = shape, scale = scale, windlen = InputCollect$rollingWindowLength, type = "pdf")
196+
x_list <- adstock_weibull(x = m, shape = shape, scale = scale, type = "pdf")
197197
}
198198
m_adstocked <- x_list$x_decayed
199199
dt_transformAdstock[, (med_select) := m_adstocked]

R/R/plots.R

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -245,13 +245,13 @@ robyn_onepagers <- function(InputCollect, OutputCollect, select_model = NULL, qu
245245
all_plots <- list()
246246
cnt <- 0
247247

248-
for (pf in pareto_fronts_vec) { # pf = 1
248+
for (pf in pareto_fronts_vec) { # pf = pareto_fronts_vec[1]
249249

250250
plotMediaShare <- xDecompAgg[robynPareto == pf & rn %in% InputCollect$paid_media_spends]
251251
uniqueSol <- plotMediaShare[, unique(solID)]
252252

253253
# parallelResult <- for (sid in uniqueSol) { # sid = uniqueSol[1]
254-
parallelResult <- foreach(sid = uniqueSol) %dorng% {
254+
parallelResult <- foreach(sid = uniqueSol) %dorng% { # sid = uniqueSol[1]
255255
plotMediaShareLoop <- plotMediaShare[solID == sid]
256256
rsq_train_plot <- plotMediaShareLoop[, round(unique(rsq_train), 4)]
257257
nrmse_plot <- plotMediaShareLoop[, round(unique(nrmse), 4)]
@@ -429,6 +429,7 @@ robyn_onepagers <- function(InputCollect, OutputCollect, select_model = NULL, qu
429429
cnt <- cnt + 1
430430
setTxtProgressBar(pbplot, cnt)
431431
}
432+
return(all_plots)
432433
}
433434
if (!quiet & count_mod_out > 0) {
434435
cnt <- cnt + length(uniqueSol)
@@ -438,7 +439,7 @@ robyn_onepagers <- function(InputCollect, OutputCollect, select_model = NULL, qu
438439
if (!quiet & count_mod_out > 0) close(pbplot)
439440
# Stop cluster to avoid memory leaks
440441
if (check_parallel_plot()) stopImplicitCluster()
441-
return(invisible(all_plots))
442+
return(invisible(parallelResult[[1]]))
442443
}
443444

444445
allocation_plots <- function(InputCollect, OutputCollect, dt_optimOut, select_model, scenario, export = TRUE, quiet = FALSE) {

R/R/refresh.R

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#' Use \code{robyn_save()} to select and save as .RDS file the initial model.
1010
#'
1111
#' @inheritParams robyn_allocator
12-
#' @return (Invisible) file's name.
12+
#' @return (Invisible) list with filename and summary.
1313
#' @examples
1414
#' \dontrun{
1515
#' # Get model IDs from OutputCollect
@@ -24,7 +24,11 @@
2424
#' )
2525
#' }
2626
#' @export
27-
robyn_save <- function(robyn_object, InputCollect, OutputCollect, select_model = NULL) {
27+
robyn_save <- function(robyn_object,
28+
select_model,
29+
InputCollect,
30+
OutputCollect,
31+
quiet = FALSE) {
2832
check_robyn_object(robyn_object)
2933
if (is.null(select_model)) select_model <- OutputCollect[["selectID"]]
3034
if (!(select_model %in% OutputCollect$resultHypParam$solID)) {
@@ -34,10 +38,23 @@ robyn_save <- function(robyn_object, InputCollect, OutputCollect, select_model =
3438
)))
3539
}
3640

41+
output <- list(
42+
robyn_object = robyn_object,
43+
select_model = select_model,
44+
summary = OutputCollect$xDecompAgg[
45+
solID == select_model & !is.na(mean_spend)
46+
, .(rn, coef,mean_spend, mean_response, roi_mean
47+
, total_spend, total_response = xDecompAgg, roi_total)],
48+
plot = robyn_onepagers(InputCollect, OutputCollect, select_model, quiet = TRUE, export = FALSE))
49+
class(output) <- c("robyn_save", class(output))
50+
3751
if (file.exists(robyn_object)) {
38-
answer <- askYesNo(paste0(robyn_object, " already exists. Are you certain to overwrite it?"))
52+
if (!quiet) {
53+
answer <- askYesNo(paste0(robyn_object, " already exists. Are you certain to overwrite it?"))
54+
} else answer <- TRUE
3955
if (answer == FALSE | is.na(answer)) {
40-
stop("stopped")
56+
message("Stopped export to avoid overwriting")
57+
return(invisible(output))
4158
}
4259
}
4360

@@ -50,11 +67,33 @@ robyn_save <- function(robyn_object, InputCollect, OutputCollect, select_model =
5067
InputCollect$refreshCounter <- 0
5168
listInit <- list(OutputCollect = OutputCollect, InputCollect = InputCollect)
5269
Robyn <- list(listInit = listInit)
70+
5371
saveRDS(Robyn, file = robyn_object)
72+
if (!quiet) message("Exported results: ", robyn_object)
73+
return(invisible(output))
74+
}
5475

55-
return(invisible(robyn_object))
76+
#' @rdname robyn_save
77+
#' @aliases robyn_save
78+
#' @param x \code{robyn_save()} output.
79+
#' @export
80+
print.robyn_save <- function(x, ...) {
81+
print(glued(
82+
"
83+
Exported file: {x$robyn_object}
84+
Exported model: {x$select_model}
85+
86+
Media Summary for Selected Model:
87+
"))
88+
print(x$summary)
5689
}
5790

91+
#' @rdname robyn_save
92+
#' @aliases robyn_save
93+
#' @param x \code{robyn_save()} output.
94+
#' @export
95+
plot.robyn_save <- function(x, ...) plot(x$plot[[1]], ...)
96+
5897

5998
####################################################################
6099
#' Build Refresh Model

R/man/robyn_save.Rd

Lines changed: 19 additions & 7 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)