55
66# ############ Auxiliary non-exported functions #############
77
8- opts_pnd <- c(" positive" , " negative" , " default" )
9- other_hyps <- c(" lambda" , " train_size" )
10- hyps_name <- c(" thetas" , " shapes" , " scales" , " alphas" , " gammas" )
8+ OPTS_PDN <- c(" positive" , " negative" , " default" )
9+ HYPS_NAMES <- c(" thetas" , " shapes" , " scales" , " alphas" , " gammas" )
10+ HYPS_OTHERS <- c(" lambda" , " train_size" )
11+ LEGACY_PARAMS <- c(" cores" , " iterations" , " trials" , " intercept_sign" , " nevergrad_algo" )
1112
1213check_nas <- function (df ) {
1314 name <- deparse(substitute(df ))
@@ -172,8 +173,8 @@ check_prophet <- function(dt_holidays, prophet_country, prophet_vars, prophet_si
172173 if (is.null(prophet_signs )) {
173174 prophet_signs <- rep(" default" , length(prophet_vars ))
174175 }
175- if (! all(prophet_signs %in% opts_pnd )) {
176- stop(" Allowed values for 'prophet_signs' are: " , paste(opts_pnd , collapse = " , " ))
176+ if (! all(prophet_signs %in% OPTS_PDN )) {
177+ stop(" Allowed values for 'prophet_signs' are: " , paste(OPTS_PDN , collapse = " , " ))
177178 }
178179 if (length(prophet_signs ) != length(prophet_vars )) {
179180 stop(" 'prophet_signs' must have same length as 'prophet_vars'" )
@@ -185,8 +186,8 @@ check_prophet <- function(dt_holidays, prophet_country, prophet_vars, prophet_si
185186check_context <- function (dt_input , context_vars , context_signs ) {
186187 if (! is.null(context_vars )) {
187188 if (is.null(context_signs )) context_signs <- rep(" default" , length(context_vars ))
188- if (! all(context_signs %in% opts_pnd )) {
189- stop(" Allowed values for 'context_signs' are: " , paste(opts_pnd , collapse = " , " ))
189+ if (! all(context_signs %in% OPTS_PDN )) {
190+ stop(" Allowed values for 'context_signs' are: " , paste(OPTS_PDN , collapse = " , " ))
190191 }
191192 if (length(context_signs ) != length(context_vars )) {
192193 stop(" Input 'context_signs' must have same length as 'context_vars'" )
@@ -235,8 +236,8 @@ check_paidmedia <- function(dt_input, paid_media_vars, paid_media_signs, paid_me
235236 if (is.null(paid_media_signs )) {
236237 paid_media_signs <- rep(" positive" , mediaVarCount )
237238 }
238- if (! all(paid_media_signs %in% opts_pnd )) {
239- stop(" Allowed values for 'paid_media_signs' are: " , paste(opts_pnd , collapse = " , " ))
239+ if (! all(paid_media_signs %in% OPTS_PDN )) {
240+ stop(" Allowed values for 'paid_media_signs' are: " , paste(OPTS_PDN , collapse = " , " ))
240241 }
241242 if (length(paid_media_signs ) == 1 ) {
242243 paid_media_signs <- rep(paid_media_signs , length(paid_media_vars ))
@@ -281,8 +282,8 @@ check_organicvars <- function(dt_input, organic_vars, organic_signs) {
281282 organic_signs <- rep(" positive" , length(organic_vars ))
282283 # message("'organic_signs' were not provided. Using 'positive'")
283284 }
284- if (! all(organic_signs %in% opts_pnd )) {
285- stop(" Allowed values for 'organic_signs' are: " , paste(opts_pnd , collapse = " , " ))
285+ if (! all(organic_signs %in% OPTS_PDN )) {
286+ stop(" Allowed values for 'organic_signs' are: " , paste(OPTS_PDN , collapse = " , " ))
286287 }
287288 if (length(organic_signs ) != length(organic_vars )) {
288289 stop(" Input 'organic_signs' must have same length as 'organic_vars'" )
@@ -444,10 +445,10 @@ check_hyperparameters <- function(hyperparameters = NULL, adstock = NULL,
444445 ref_hyp_name_spend <- hyper_names(adstock , all_media = paid_media_spends )
445446 ref_hyp_name_expo <- hyper_names(adstock , all_media = exposure_vars )
446447 ref_hyp_name_org <- hyper_names(adstock , all_media = organic_vars )
447- ref_hyp_name_other <- get_hyp_names [get_hyp_names %in% other_hyps ]
448- # Excluding lambda (first other_hyps ) given its range is not customizable
449- ref_all_media <- sort(c(ref_hyp_name_spend , ref_hyp_name_org , other_hyps ))
450- all_ref_names <- c(ref_hyp_name_spend , ref_hyp_name_expo , ref_hyp_name_org , other_hyps )
448+ ref_hyp_name_other <- get_hyp_names [get_hyp_names %in% HYPS_OTHERS ]
449+ # Excluding lambda (first HYPS_OTHERS ) given its range is not customizable
450+ ref_all_media <- sort(c(ref_hyp_name_spend , ref_hyp_name_org , HYPS_OTHERS ))
451+ all_ref_names <- c(ref_hyp_name_spend , ref_hyp_name_expo , ref_hyp_name_org , HYPS_OTHERS )
451452 all_ref_names <- all_ref_names [order(all_ref_names )]
452453 if (! all(get_hyp_names %in% all_ref_names )) {
453454 wrong_hyp_names <- get_hyp_names [which(! (get_hyp_names %in% all_ref_names ))]
@@ -717,7 +718,7 @@ check_hyper_fixed <- function(InputCollect, dt_hyper_fixed, add_penalty_factor)
717718 # Adstock hyper-parameters
718719 hypParamSamName <- hyper_names(adstock = InputCollect $ adstock , all_media = InputCollect $ all_media )
719720 # Add lambda and other hyper-parameters manually
720- hypParamSamName <- c(hypParamSamName , other_hyps )
721+ hypParamSamName <- c(hypParamSamName , HYPS_OTHERS )
721722 # Add penalty factor hyper-parameters names
722723 if (add_penalty_factor ) {
723724 for_penalty <- names(select(InputCollect $ dt_mod , - .data $ ds , - .data $ dep_var ))
@@ -774,8 +775,7 @@ check_class <- function(x, object) {
774775}
775776
776777check_allocator <- function (OutputCollect , select_model , paid_media_spends , scenario ,
777- channel_constr_low , channel_constr_up ,
778- expected_spend , expected_spend_days , constr_mode ) {
778+ channel_constr_low , channel_constr_up , constr_mode ) {
779779 dt_hyppar <- OutputCollect $ resultHypParam [OutputCollect $ resultHypParam $ solID == select_model , ]
780780 if (! (select_model %in% OutputCollect $ allSolutions )) {
781781 stop(
@@ -792,11 +792,10 @@ check_allocator <- function(OutputCollect, select_model, paid_media_spends, scen
792792 if (any(channel_constr_up > 5 )) {
793793 warning(" Inputs 'channel_constr_up' > 5 might cause unrealistic allocation" )
794794 }
795- opts <- c( " max_historical_response" , " max_response_expected_spend" )
795+ opts <- " max_historical_response" # Deprecated: max_response_expected_spend
796796 if (! (scenario %in% opts )) {
797797 stop(" Input 'scenario' must be one of: " , paste(opts , collapse = " , " ))
798798 }
799-
800799 if (length(channel_constr_low ) != 1 && length(channel_constr_low ) != length(paid_media_spends )) {
801800 stop(paste(
802801 " Input 'channel_constr_low' have to contain either only 1" ,
@@ -809,35 +808,144 @@ check_allocator <- function(OutputCollect, select_model, paid_media_spends, scen
809808 " value or have same length as 'InputCollect$paid_media_spends':" , length(paid_media_spends )
810809 ))
811810 }
812-
813- if (" max_response_expected_spend" %in% scenario ) {
814- if (any(is.null(expected_spend ), is.null(expected_spend_days ))) {
815- stop(" When scenario = 'max_response_expected_spend', expected_spend and expected_spend_days must be provided" )
816- }
817- }
818811 opts <- c(" eq" , " ineq" )
819812 if (! (constr_mode %in% opts )) {
820813 stop(" Input 'constr_mode' must be one of: " , paste(opts , collapse = " , " ))
821814 }
822815}
823816
824- check_metric_value <- function (metric_value , media_metric ) {
817+ check_metric_type <- function (metric_name , paid_media_spends , paid_media_vars , exposure_vars , organic_vars ) {
818+ if (metric_name %in% paid_media_spends && length(metric_name ) == 1 ) {
819+ metric_type <- " spend"
820+ } else if (metric_name %in% exposure_vars && length(metric_name ) == 1 ) {
821+ metric_type <- " exposure"
822+ } else if (metric_name %in% organic_vars && length(metric_name ) == 1 ) {
823+ metric_type <- " organic"
824+ } else {
825+ stop(paste(
826+ " Invalid 'metric_name' input. It must be any media variable from" ,
827+ " paid_media_spends (spend), paid_media_vars (exposure)," ,
828+ " or organic_vars (organic); NOT:" , metric_name ,
829+ paste(" \n - paid_media_spends:" , v2t(paid_media_spends , quotes = FALSE )),
830+ paste(" \n - paid_media_vars:" , v2t(paid_media_vars , quotes = FALSE )),
831+ paste(" \n - organic_vars:" , v2t(organic_vars , quotes = FALSE ))
832+ ))
833+ }
834+ return (metric_type )
835+ }
836+
837+ check_metric_dates <- function (date_range = NULL , all_dates , dayInterval = NULL , quiet = FALSE , is_allocator = FALSE , ... ) {
838+ # # default using latest 30 days / 4 weeks / 1 month for spend level
839+ if (is.null(date_range )) {
840+ if (is.null(dayInterval )) stop(" Input 'date_range' or 'dayInterval' must be defined" )
841+ if (! is_allocator ) {
842+ date_range <- " last_1"
843+ } else {
844+ date_range <- paste0(" last_" , dplyr :: case_when(
845+ dayInterval == 1 ~ 30 ,
846+ dayInterval == 7 ~ 4 ,
847+ dayInterval > = 30 & dayInterval < = 31 ~ 1 ,
848+ ))
849+ }
850+ if (! quiet ) message(sprintf(" Automatically picked date_range = '%s'" , date_range ))
851+ }
852+ if (grepl(" last|all" , date_range [1 ])) {
853+ # # Using last_n as date_range range
854+ if (" all" %in% date_range ) date_range <- paste0(" last_" , length(all_dates ))
855+ get_n <- ifelse(grepl(" _" , date_range [1 ]), as.integer(gsub(" last_" , " " , date_range )), 1 )
856+ date_range <- tail(all_dates , get_n )
857+ date_range_loc <- which(all_dates %in% date_range )
858+ date_range_updated <- all_dates [date_range_loc ]
859+ rg <- v2t(range(date_range_updated ), sep = " :" , quotes = FALSE )
860+ } else {
861+ # # Using dates as date_range range
862+ if (all(is.Date(as.Date(date_range , origin = " 1970-01-01" )))) {
863+ date_range <- as.Date(date_range , origin = " 1970-01-01" )
864+ if (length(date_range ) == 1 ) {
865+ # # Using only 1 date
866+ if (all(date_range %in% all_dates )) {
867+ date_range_updated <- date_range
868+ date_range_loc <- which(all_dates == date_range )
869+ if (! quiet ) message(" Using ds '" , date_range_updated , " ' as the response period" )
870+ } else {
871+ date_range_loc <- which.min(abs(date_range - all_dates ))
872+ date_range_updated <- all_dates [date_range_loc ]
873+ if (! quiet ) warning(" Input 'date_range' (" , date_range , " ) has no match. Picking closest date: " , date_range_updated )
874+ }
875+ } else if (length(date_range ) == 2 ) {
876+ # # Using two dates as "from-to" date range
877+ date_range_loc <- unlist(lapply(date_range , function (x ) which.min(abs(x - all_dates ))))
878+ date_range_loc <- date_range_loc [1 ]: date_range_loc [2 ]
879+ date_range_updated <- all_dates [date_range_loc ]
880+ if (! quiet & ! all(date_range %in% date_range_updated )) {
881+ warning(paste(
882+ " At least one date in 'date_range' input do not match any date." ,
883+ " Picking closest dates for range:" , paste(range(date_range_updated ), collapse = " :" )
884+ ))
885+ }
886+ rg <- v2t(range(date_range_updated ), sep = " :" , quotes = FALSE )
887+ get_n <- length(date_range_loc )
888+ } else {
889+ # # Manually inputting each date
890+ date_range_updated <- date_range
891+ if (all(date_range %in% all_dates )) {
892+ date_range_loc <- which(all_dates %in% date_range_updated )
893+ } else {
894+ date_range_loc <- unlist(lapply(date_range_updated , function (x ) which.min(abs(x - all_dates ))))
895+ rg <- v2t(range(date_range_updated ), sep = " :" , quotes = FALSE )
896+ }
897+ if (all(na.omit(date_range_loc - lag(date_range_loc )) == 1 )) {
898+ date_range_updated <- all_dates [date_range_loc ]
899+ if (! quiet ) warning(" At least one date in 'date_range' do not match ds. Picking closest date: " , date_range_updated )
900+ } else {
901+ stop(" Input 'date_range' needs to have sequential dates" )
902+ }
903+ }
904+ } else {
905+ stop(" Input 'date_range' must have date format '2023-01-01' or use 'last_n'" )
906+ }
907+ }
908+ return (list (
909+ date_range_updated = date_range_updated ,
910+ metric_loc = date_range_loc
911+ ))
912+ }
913+
914+ check_metric_value <- function (metric_value , metric_name , all_values , metric_loc ) {
915+ get_n <- length(metric_loc )
916+ if (any(is.nan(metric_value ))) metric_value <- NULL
825917 if (! is.null(metric_value )) {
826918 if (! is.numeric(metric_value )) {
827919 stop(sprintf(
828- " Input 'metric_value' for %s (%s) must be a numerical value\n " , media_metric , toString(metric_value )
920+ " Input 'metric_value' for %s (%s) must be a numerical value\n " , metric_name , toString(metric_value )
829921 ))
830922 }
831- if (sum (metric_value < = 0 ) > 0 ) {
923+ if (any (metric_value < 0 )) {
832924 stop(sprintf(
833- " Input 'metric_value' for %s (%s) must be a positive value \n " , media_metric , metric_value [ metric_value < = 0 ]
925+ " Input 'metric_value' for %s must be positive\n " , metric_name
834926 ))
835927 }
928+ if (get_n > 1 & length(metric_value ) == 1 ) {
929+ metric_value_updated <- rep(metric_value / get_n , get_n )
930+ # message(paste0("'metric_value'", metric_value, " splitting into ", get_n, " periods evenly"))
931+ } else {
932+ if (length(metric_value ) != get_n ) {
933+ stop(" robyn_response metric_value & date_range must have same length\n " )
934+ }
935+ metric_value_updated <- metric_value
936+ }
836937 }
938+ if (is.null(metric_value )) {
939+ metric_value_updated <- all_values [metric_loc ]
940+ }
941+ all_values_updated <- all_values
942+ all_values_updated [metric_loc ] <- metric_value_updated
943+ return (list (
944+ metric_value_updated = metric_value_updated ,
945+ all_values_updated = all_values_updated
946+ ))
837947}
838948
839- LEGACY_PARAMS <- c(" cores" , " iterations" , " trials" , " intercept_sign" , " nevergrad_algo" )
840-
841949check_legacy_input <- function (InputCollect ,
842950 cores = NULL , iterations = NULL , trials = NULL ,
843951 intercept_sign = NULL , nevergrad_algo = NULL ) {
0 commit comments