|
6 | 6 | ############# Auxiliary non-exported functions ############# |
7 | 7 |
|
8 | 8 | opts_pnd <- c("positive", "negative", "default") |
| 9 | +other_hyps <- c("lambda", "train_size") |
| 10 | +hyps_name <- c("thetas", "shapes", "scales", "alphas", "gammas") |
9 | 11 |
|
10 | 12 | check_nas <- function(df) { |
11 | 13 | name <- deparse(substitute(df)) |
@@ -151,7 +153,7 @@ check_prophet <- function(dt_holidays, prophet_country, prophet_vars, prophet_si |
151 | 153 | if (is.null(dt_holidays) || is.null(prophet_vars)) { |
152 | 154 | return(invisible(NULL)) |
153 | 155 | } else { |
154 | | - opts <- c("trend", "season", "weekday", "holiday") |
| 156 | + opts <- c("trend", "season", "monthly", "weekday", "holiday") |
155 | 157 | if (!all(prophet_vars %in% opts)) { |
156 | 158 | stop("Allowed values for 'prophet_vars' are: ", paste(opts, collapse = ", ")) |
157 | 159 | } |
@@ -426,7 +428,10 @@ check_hyperparameters <- function(hyperparameters = NULL, adstock = NULL, |
426 | 428 | "robyn_inputs(InputCollect = InputCollect, hyperparameters = ...)" |
427 | 429 | )) |
428 | 430 | } else { |
429 | | - hyperparameters <- hyperparameters[which(!names(hyperparameters) %in% "lambda")] |
| 431 | + # Non-adstock hyperparameters check |
| 432 | + check_train_size(hyperparameters) |
| 433 | + # Adstock hyperparameters check |
| 434 | + hyperparameters <- hyperparameters[which(!names(hyperparameters) %in% other_hyps)] |
430 | 435 | hyperparameters_ordered <- hyperparameters[order(names(hyperparameters))] |
431 | 436 | get_hyp_names <- names(hyperparameters_ordered) |
432 | 437 | ref_hyp_name_spend <- hyper_names(adstock, all_media = paid_media_spends) |
@@ -470,6 +475,17 @@ check_hyperparameters <- function(hyperparameters = NULL, adstock = NULL, |
470 | 475 | } |
471 | 476 | } |
472 | 477 |
|
| 478 | +check_train_size <- function(hyps) { |
| 479 | + if ("train_size" %in% names(hyps)) { |
| 480 | + if (!length(hyps$train_size) %in% 1:2) { |
| 481 | + stop("Hyperparameter 'train_size' must be length 1 (fixed) or 2 (range)") |
| 482 | + } |
| 483 | + if (any(hyps$train_size <= 0.1) || any(hyps$train_size > 1)) { |
| 484 | + stop("Hyperparameter 'train_size' values must be defined between 0.1 and 1") |
| 485 | + } |
| 486 | + } |
| 487 | +} |
| 488 | + |
473 | 489 | check_hyper_limits <- function(hyperparameters, hyper) { |
474 | 490 | hyper_which <- which(endsWith(names(hyperparameters), hyper)) |
475 | 491 | if (length(hyper_which) == 0) { |
@@ -692,8 +708,8 @@ check_hyper_fixed <- function(InputCollect, dt_hyper_fixed, add_penalty_factor) |
692 | 708 | hyper_fixed <- !is.null(dt_hyper_fixed) |
693 | 709 | # Adstock hyper-parameters |
694 | 710 | hypParamSamName <- hyper_names(adstock = InputCollect$adstock, all_media = InputCollect$all_media) |
695 | | - # Add lambda hyper-parameter |
696 | | - hypParamSamName <- c(hypParamSamName, "lambda") |
| 711 | + # Add lambda and other hyper-parameters manually |
| 712 | + hypParamSamName <- c(hypParamSamName, other_hyps) |
697 | 713 | # Add penalty factor hyper-parameters names |
698 | 714 | if (add_penalty_factor) { |
699 | 715 | for_penalty <- names(select(InputCollect$dt_mod, -.data$ds, -.data$dep_var)) |
@@ -734,10 +750,14 @@ check_init_msg <- function(InputCollect, cores) { |
734 | 750 | "Using", InputCollect$adstock, "adstocking with", |
735 | 751 | length(InputCollect$hyper_updated), "hyperparameters", det |
736 | 752 | ) |
737 | | - if (check_parallel()) { |
738 | | - message(paste(base, "on", cores, "cores")) |
| 753 | + if (cores == 1) { |
| 754 | + message(paste(base, "with no parallel computation")) |
739 | 755 | } else { |
740 | | - message(paste(base, "on 1 core (Windows fallback)")) |
| 756 | + if (check_parallel()) { |
| 757 | + message(paste(base, "on", cores, "cores")) |
| 758 | + } else { |
| 759 | + message(paste(base, "on 1 core (Windows fallback)")) |
| 760 | + } |
741 | 761 | } |
742 | 762 | } |
743 | 763 |
|
|
0 commit comments