Skip to content

Commit 86202d5

Browse files
committed
fix: actually evaluating prophet custom parameters #271
1 parent 74dddc4 commit 86202d5

File tree

2 files changed

+14
-12
lines changed

2 files changed

+14
-12
lines changed

R/R/inputs.R

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -590,7 +590,12 @@ robyn_engineering <- function(InputCollect, ...) {
590590
#### Obtain prophet trend, seasonality and change-points
591591

592592
if (!is.null(InputCollect$prophet_vars) && length(InputCollect$prophet_vars) > 0) {
593-
args <- list(...)
593+
custom_params <- list(...)
594+
if (length(InputCollect[["custom_params"]]) > 0) {
595+
custom_params <- InputCollect[["custom_params"]]
596+
}
597+
if (length(custom_params) > 0)
598+
message(paste("Using custom prophet parameters:", paste(names(custom_params), collapse = ", ")))
594599
dt_transform <- prophet_decomp(
595600
dt_transform,
596601
dt_holidays = InputCollect$dt_holidays,
@@ -601,8 +606,7 @@ robyn_engineering <- function(InputCollect, ...) {
601606
context_vars = InputCollect$context_vars,
602607
paid_media_vars = paid_media_vars,
603608
intervalType = InputCollect$intervalType,
604-
custom_params = if (length(args) == 0) InputCollect[["custom_params"]] else args,
605-
...
609+
custom_params = custom_params
606610
)
607611
}
608612

@@ -638,12 +642,11 @@ robyn_engineering <- function(InputCollect, ...) {
638642
#' @param paid_media_vars As in \code{robyn_inputs()}
639643
#' @param intervalType As included in \code{InputCollect}
640644
#' @param custom_params List. Custom parameters passed to \code{prophet()}
641-
#' @param ... Additional parameters
642645
#' @return A list containing all prophet decomposition output.
643646
prophet_decomp <- function(dt_transform, dt_holidays,
644647
prophet_country, prophet_vars, prophet_signs,
645648
factor_vars, context_vars, paid_media_vars,
646-
intervalType, custom_params, ...) {
649+
intervalType, custom_params) {
647650
check_prophet(dt_holidays, prophet_country, prophet_vars, prophet_signs)
648651
recurrence <- subset(dt_transform, select = c("ds", "dep_var"))
649652
colnames(recurrence)[2] <- "y"
@@ -655,17 +658,19 @@ prophet_decomp <- function(dt_transform, dt_holidays,
655658
use_weekday <- any(c(str_detect("weekday", prophet_vars), "weekly.seasonality" %in% names(custom_params)))
656659

657660
dt_regressors <- cbind(recurrence, subset(dt_transform, select = c(context_vars, paid_media_vars)))
658-
modelRecurrence <- prophet(
661+
662+
prophet_params <- list(
659663
holidays = if (use_holiday) holidays[country == prophet_country] else NULL,
660664
yearly.seasonality = ifelse("yearly.seasonality" %in% names(custom_params),
661665
custom_params[["yearly.seasonality"]],
662666
use_season),
663667
weekly.seasonality = ifelse("weekly.seasonality" %in% names(custom_params),
664668
custom_params[["weekly.seasonality"]],
665669
use_weekday),
666-
daily.seasonality = FALSE, # No hourly models allowed
667-
...
670+
daily.seasonality = FALSE # No hourly models allowed
668671
)
672+
prophet_params <- append(prophet_params, custom_params)
673+
modelRecurrence <- do.call(prophet, as.list(prophet_params))
669674

670675
if (!is.null(factor_vars) && length(factor_vars) > 0) {
671676
dt_ohe <- as.data.table(model.matrix(y ~ ., dt_regressors[, c("y", factor_vars), with = FALSE]))[, -1]

R/man/prophet_decomp.Rd

Lines changed: 1 addition & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)