@@ -590,7 +590,12 @@ robyn_engineering <- function(InputCollect, ...) {
590590 # ### Obtain prophet trend, seasonality and change-points
591591
592592 if (! is.null(InputCollect $ prophet_vars ) && length(InputCollect $ prophet_vars ) > 0 ) {
593- args <- list (... )
593+ custom_params <- list (... )
594+ if (length(InputCollect [[" custom_params" ]]) > 0 ) {
595+ custom_params <- InputCollect [[" custom_params" ]]
596+ }
597+ if (length(custom_params ) > 0 )
598+ message(paste(" Using custom prophet parameters:" , paste(names(custom_params ), collapse = " , " )))
594599 dt_transform <- prophet_decomp(
595600 dt_transform ,
596601 dt_holidays = InputCollect $ dt_holidays ,
@@ -601,8 +606,7 @@ robyn_engineering <- function(InputCollect, ...) {
601606 context_vars = InputCollect $ context_vars ,
602607 paid_media_vars = paid_media_vars ,
603608 intervalType = InputCollect $ intervalType ,
604- custom_params = if (length(args ) == 0 ) InputCollect [[" custom_params" ]] else args ,
605- ...
609+ custom_params = custom_params
606610 )
607611 }
608612
@@ -638,12 +642,11 @@ robyn_engineering <- function(InputCollect, ...) {
638642# ' @param paid_media_vars As in \code{robyn_inputs()}
639643# ' @param intervalType As included in \code{InputCollect}
640644# ' @param custom_params List. Custom parameters passed to \code{prophet()}
641- # ' @param ... Additional parameters
642645# ' @return A list containing all prophet decomposition output.
643646prophet_decomp <- function (dt_transform , dt_holidays ,
644647 prophet_country , prophet_vars , prophet_signs ,
645648 factor_vars , context_vars , paid_media_vars ,
646- intervalType , custom_params , ... ) {
649+ intervalType , custom_params ) {
647650 check_prophet(dt_holidays , prophet_country , prophet_vars , prophet_signs )
648651 recurrence <- subset(dt_transform , select = c(" ds" , " dep_var" ))
649652 colnames(recurrence )[2 ] <- " y"
@@ -655,17 +658,19 @@ prophet_decomp <- function(dt_transform, dt_holidays,
655658 use_weekday <- any(c(str_detect(" weekday" , prophet_vars ), " weekly.seasonality" %in% names(custom_params )))
656659
657660 dt_regressors <- cbind(recurrence , subset(dt_transform , select = c(context_vars , paid_media_vars )))
658- modelRecurrence <- prophet(
661+
662+ prophet_params <- list (
659663 holidays = if (use_holiday ) holidays [country == prophet_country ] else NULL ,
660664 yearly.seasonality = ifelse(" yearly.seasonality" %in% names(custom_params ),
661665 custom_params [[" yearly.seasonality" ]],
662666 use_season ),
663667 weekly.seasonality = ifelse(" weekly.seasonality" %in% names(custom_params ),
664668 custom_params [[" weekly.seasonality" ]],
665669 use_weekday ),
666- daily.seasonality = FALSE , # No hourly models allowed
667- ...
670+ daily.seasonality = FALSE # No hourly models allowed
668671 )
672+ prophet_params <- append(prophet_params , custom_params )
673+ modelRecurrence <- do.call(prophet , as.list(prophet_params ))
669674
670675 if (! is.null(factor_vars ) && length(factor_vars ) > 0 ) {
671676 dt_ohe <- as.data.table(model.matrix(y ~ . , dt_regressors [, c(" y" , factor_vars ), with = FALSE ]))[, - 1 ]
0 commit comments