Skip to content

Commit 61b8b28

Browse files
committed
styler for all scripts
1 parent 8758fc8 commit 61b8b28

File tree

10 files changed

+91
-63
lines changed

10 files changed

+91
-63
lines changed

R/R/allocator.R

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,6 @@ robyn_allocator <- function(robyn_object = NULL,
108108
quiet = FALSE,
109109
ui = FALSE,
110110
...) {
111-
112111
### Use previously exported model using json_file
113112
if (!is.null(json_file)) {
114113
if (is.null(InputCollect)) {
@@ -141,12 +140,12 @@ robyn_allocator <- function(robyn_object = NULL,
141140
# OutputCollect <- imported$OutputCollect
142141
# select_model <- imported$select_model
143142
# } else {
144-
if (is.null(select_model) && length(OutputCollect$allSolutions == 1)) {
145-
select_model <- OutputCollect$allSolutions
146-
}
147-
if (any(is.null(InputCollect), is.null(OutputCollect), is.null(select_model))) {
148-
stop("When 'robyn_object' is not provided, then InputCollect, OutputCollect, select_model must be provided")
149-
}
143+
if (is.null(select_model) && length(OutputCollect$allSolutions == 1)) {
144+
select_model <- OutputCollect$allSolutions
145+
}
146+
if (any(is.null(InputCollect), is.null(OutputCollect), is.null(select_model))) {
147+
stop("When 'robyn_object' is not provided, then InputCollect, OutputCollect, select_model must be provided")
148+
}
150149
# }
151150

152151
if (length(InputCollect$paid_media_spends) <= 1) {
@@ -238,9 +237,12 @@ robyn_allocator <- function(robyn_object = NULL,
238237
simulation_period <- initial_mean_period <- unlist(summarise_all(select(histFiltered, any_of(mediaSpendSorted)), length))
239238
nDates <- lapply(mediaSpendSorted, function(x) histFiltered$ds)
240239
names(nDates) <- mediaSpendSorted
241-
if (!quiet) message(sprintf(
242-
"Date Window: %s:%s (%s %ss)",
243-
date_min, date_max, unique(initial_mean_period), InputCollect$intervalType))
240+
if (!quiet) {
241+
message(sprintf(
242+
"Date Window: %s:%s (%s %ss)",
243+
date_min, date_max, unique(initial_mean_period), InputCollect$intervalType
244+
))
245+
}
244246
zero_spend_channel <- names(histSpendWindow[histSpendWindow == 0])
245247

246248
initSpendUnitTotal <- sum(initSpendUnit)
@@ -359,14 +361,19 @@ robyn_allocator <- function(robyn_object = NULL,
359361
skip_these <- (channel_constr_low == 0 & channel_constr_up == 0)
360362
zero_constraint_channel <- mediaSpendSorted[skip_these]
361363
if (any(skip_these) && !quiet) {
362-
message("Excluded variables (constrained to 0): ",
363-
paste(zero_constraint_channel, collapse = ", "))
364+
message(
365+
"Excluded variables (constrained to 0): ",
366+
paste(zero_constraint_channel, collapse = ", ")
367+
)
364368
}
365369
if (!all(coefSelectorSorted)) {
366370
zero_coef_channel <- setdiff(names(coefSelectorSorted), mediaSpendSorted[coefSelectorSorted])
367-
if (!quiet) message(
368-
"Excluded variables (coefficients are 0): ",
369-
paste(zero_coef_channel, collapse = ", "))
371+
if (!quiet) {
372+
message(
373+
"Excluded variables (coefficients are 0): ",
374+
paste(zero_coef_channel, collapse = ", ")
375+
)
376+
}
370377
} else {
371378
zero_coef_channel <- as.character()
372379
}
@@ -754,7 +761,9 @@ robyn_allocator <- function(robyn_object = NULL,
754761
select_model, scenario, eval_list,
755762
export, plot_folder, quiet
756763
)
757-
} else plots <- NULL
764+
} else {
765+
plots <- NULL
766+
}
758767

759768
output <- list(
760769
dt_optimOut = dt_optimOut,

R/R/checks.R

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,14 @@ check_novar <- function(dt_input, InputCollect = NULL) {
3838
"There are %s column(s) with no-variance: %s. \nPlease, remove the variable(s) to proceed...",
3939
length(novar), v2t(novar)
4040
)
41-
if (!is.null(InputCollect)) msg <- sprintf(
42-
"%s\n>>> Note: there's no variance on these variables because of the modeling window filter (%s:%s)",
43-
msg,
44-
InputCollect$window_start,
45-
InputCollect$window_end
46-
)
41+
if (!is.null(InputCollect)) {
42+
msg <- sprintf(
43+
"%s\n>>> Note: there's no variance on these variables because of the modeling window filter (%s:%s)",
44+
msg,
45+
InputCollect$window_start,
46+
InputCollect$window_end
47+
)
48+
}
4749
stop(msg)
4850
}
4951
}
@@ -164,9 +166,12 @@ check_prophet <- function(dt_holidays, prophet_country, prophet_vars, prophet_si
164166
prophet_vars <- tolower(prophet_vars)
165167
opts <- c("trend", "season", "monthly", "weekday", "holiday")
166168
if (!"holiday" %in% prophet_vars) {
167-
if (!is.null(prophet_country)) warning(paste(
168-
"Input 'prophet_country' is defined as", prophet_country,
169-
"but 'holiday' is not setup within 'prophet_vars' parameter"))
169+
if (!is.null(prophet_country)) {
170+
warning(paste(
171+
"Input 'prophet_country' is defined as", prophet_country,
172+
"but 'holiday' is not setup within 'prophet_vars' parameter"
173+
))
174+
}
170175
prophet_country <- NULL
171176
}
172177
if (!all(prophet_vars %in% opts)) {
@@ -177,7 +182,7 @@ check_prophet <- function(dt_holidays, prophet_country, prophet_vars, prophet_si
177182
}
178183
if ("holiday" %in% prophet_vars && (
179184
is.null(prophet_country) || length(prophet_country) > 1 |
180-
isTRUE(!prophet_country %in% unique(dt_holidays$country)))) {
185+
isTRUE(!prophet_country %in% unique(dt_holidays$country)))) {
181186
stop(paste(
182187
"You must provide 1 country code in 'prophet_country' input.",
183188
length(unique(dt_holidays$country)), "countries are included:",
@@ -651,16 +656,16 @@ check_calibration <- function(dt_input, date_var, calibration_input, dayInterval
651656

652657
check_obj_weight <- function(calibration_input, objective_weights, refresh) {
653658
obj_len <- ifelse(is.null(calibration_input), 2, 3)
654-
if(!is.null(objective_weights)) {
655-
if((length(objective_weights) != obj_len)) {
659+
if (!is.null(objective_weights)) {
660+
if ((length(objective_weights) != obj_len)) {
656661
stop(paste0("objective_weights must have length of ", obj_len))
657662
}
658-
if(any(objective_weights < 0) | any(objective_weights > 10)) {
663+
if (any(objective_weights < 0) | any(objective_weights > 10)) {
659664
stop("objective_weights must be >= 0 & <= 10")
660665
}
661666
}
662-
if(is.null(objective_weights) & refresh) {
663-
if(obj_len == 2) {
667+
if (is.null(objective_weights) & refresh) {
668+
if (obj_len == 2) {
664669
objective_weights <- c(1, 10)
665670
} else {
666671
objective_weights <- c(1, 10, 10)

R/R/clusters.R

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,10 @@ robyn_clusters <- function(input, dep_var_type,
7777
{
7878
suppressMessages(
7979
clusterKmeans(df,
80-
k = NULL, limit = limit_clusters, ignore = ignore,
81-
dim_red = dim_red, quiet = TRUE, seed = seed
82-
))
80+
k = NULL, limit = limit_clusters, ignore = ignore,
81+
dim_red = dim_red, quiet = TRUE, seed = seed
82+
)
83+
)
8384
},
8485
error = function(err) {
8586
message(paste("Couldn't automatically create clusters:", err))
@@ -109,8 +110,10 @@ robyn_clusters <- function(input, dep_var_type,
109110
stopifnot(k %in% min_clusters:30)
110111
suppressMessages(
111112
cls <- clusterKmeans(
112-
df, k = k, limit = limit_clusters, ignore = ignore,
113-
dim_red = dim_red, quiet = TRUE, seed = seed)
113+
df,
114+
k = k, limit = limit_clusters, ignore = ignore,
115+
dim_red = dim_red, quiet = TRUE, seed = seed
116+
)
114117
)
115118

116119
# Select top models by minimum (weighted) distance to zero
@@ -181,8 +184,9 @@ confidence_calcs <- function(
181184
if (length(unique(df_outcome$solID)) < 3) {
182185
warning(paste("Cluster", j, "does not contain enough models to calculate CI"))
183186
} else {
184-
if (cluster_by == "hyperparameters")
187+
if (cluster_by == "hyperparameters") {
185188
all_paid <- unique(gsub(paste(paste0("_", HYPS_NAMES), collapse = "|"), "", all_paid))
189+
}
186190
for (i in all_paid) {
187191
# Bootstrap CI
188192
if (dep_var_type == "conversion") {
@@ -317,7 +321,8 @@ errors_scores <- function(df, balance = rep(1, 3), ts_validation = TRUE, ...) {
317321
if (cluster_by == "hyperparameters") {
318322
outcome <- select(
319323
x, .data$solID, contains(HYPS_NAMES),
320-
contains(c("nrmse", "decomp.rssd", "mape"))) %>%
324+
contains(c("nrmse", "decomp.rssd", "mape"))
325+
) %>%
321326
removenacols(all = FALSE)
322327
}
323328
}

R/R/convergence.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,8 @@ robyn_converge <- function(OutputModels,
148148
moo_cloud_plot <- df %>%
149149
mutate(nrmse = lares::winsorize(.data$nrmse, nrmse_win)) %>%
150150
ggplot(aes(
151-
x = .data$nrmse, y = .data$decomp.rssd, colour = .data$ElapsedAccum
152-
)) +
151+
x = .data$nrmse, y = .data$decomp.rssd, colour = .data$ElapsedAccum
152+
)) +
153153
scale_colour_gradient(low = "skyblue", high = "navyblue") +
154154
labs(
155155
title = ifelse(!calibrated, "Multi-objective evolutionary performance",

R/R/inputs.R

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -361,8 +361,10 @@ robyn_inputs <- function(dt_input = NULL,
361361
# Check for no-variance columns (after filtering modeling window)
362362
dt_mod_model_window <- InputCollect$dt_mod %>%
363363
select(-any_of(InputCollect$unused_vars)) %>%
364-
filter(.data$ds >= InputCollect$window_start,
365-
.data$ds <= InputCollect$window_end)
364+
filter(
365+
.data$ds >= InputCollect$window_start,
366+
.data$ds <= InputCollect$window_end
367+
)
366368
check_novar(dt_mod_model_window, InputCollect)
367369
}
368370

R/R/json.R

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,8 @@ robyn_write <- function(InputCollect,
6464
outputs_time <- sprintf("%s min", attr(OutputCollect, "runTime"))
6565
total_time <- sprintf(
6666
"%s min",
67-
attr(OutputCollect, "runTime") + attr(OutputCollect$OutputModels, "runTime"))
67+
attr(OutputCollect, "runTime") + attr(OutputCollect$OutputModels, "runTime")
68+
)
6869
if (!is.null(OutputCollect)) {
6970
outputs <- list()
7071
outputs$select_model <- select_model
@@ -160,12 +161,12 @@ print.robyn_write <- function(x, ...) {
160161
print(glued("\n\nSummary Values on Selected Model:"))
161162

162163
print(x$ExportedModel$summary %>%
163-
select(-contains("boot"), -contains("ci_")) %>%
164-
dplyr::rename_at("performance", list(~ ifelse(x$InputCollect$dep_var_type == "revenue", "ROI", "CPA"))) %>%
165-
mutate(decompPer = formatNum(100 * .data$decompPer, pos = "%")) %>%
166-
dplyr::mutate_if(is.numeric, function(x) ifelse(!is.infinite(x), x, 0)) %>%
167-
dplyr::mutate_if(is.numeric, function(x) formatNum(x, 4, abbr = TRUE)) %>%
168-
replace(., . == "NA", "-") %>% as.data.frame())
164+
select(-contains("boot"), -contains("ci_")) %>%
165+
dplyr::rename_at("performance", list(~ ifelse(x$InputCollect$dep_var_type == "revenue", "ROI", "CPA"))) %>%
166+
mutate(decompPer = formatNum(100 * .data$decompPer, pos = "%")) %>%
167+
dplyr::mutate_if(is.numeric, function(x) ifelse(!is.infinite(x), x, 0)) %>%
168+
dplyr::mutate_if(is.numeric, function(x) formatNum(x, 4, abbr = TRUE)) %>%
169+
replace(., . == "NA", "-") %>% as.data.frame())
169170

170171
print(glued(
171172
"\n\nHyper-parameters:\n Adstock: {x$InputCollect$adstock}"
@@ -178,8 +179,8 @@ print.robyn_write <- function(x, ...) {
178179
select(-contains("lambda"), -any_of(HYPS_OTHERS)) %>%
179180
tidyr::gather() %>%
180181
tidyr::separate(.data$key,
181-
into = c("channel", "none"),
182-
sep = regex, remove = FALSE
182+
into = c("channel", "none"),
183+
sep = regex, remove = FALSE
183184
) %>%
184185
mutate(hyperparameter = gsub("^.*_", "", .data$key)) %>%
185186
select(.data$channel, .data$hyperparameter, .data$value) %>%

R/R/plots.R

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -589,10 +589,10 @@ robyn_onepagers <- function(
589589
rver <- utils::sessionInfo()$R.version
590590
onepagerTitle <- sprintf("One-pager for Model ID: %s", sid)
591591
onepagerCaption <- sprintf("Robyn v%s [R-%s.%s]", ver, rver$major, rver$minor)
592-
get_height <- length(unique(plotMediaShareLoopLine$rn)) / 5
592+
get_height <- length(unique(plotMediaShareLoopLine$rn)) / 5
593593
pg <- (p2 + p5) / (p1 + p8) / (p3 + p7) / (p4 + p6) +
594594
patchwork::plot_layout(heights = c(get_height, get_height, get_height, 1), guides = "collect") +
595-
#pg <- wrap_plots(p2, p5, p1, p8, p3, p7, p4, p6, ncol = 2) +
595+
# pg <- wrap_plots(p2, p5, p1, p8, p3, p7, p4, p6, ncol = 2) +
596596
plot_annotation(
597597
title = onepagerTitle, subtitle = errors,
598598
theme = theme_lares(background = "white", legend = "none"),
@@ -1339,7 +1339,8 @@ refresh_plots_json <- function(OutputCollectRF, json_file, export = TRUE, ...) {
13391339
ggsave(
13401340
filename = paste0(
13411341
chainData[[length(chainData)]]$ExportedModel$plot_folder,
1342-
"report_decomposition.png"),
1342+
"report_decomposition.png"
1343+
),
13431344
plot = pBarRF,
13441345
dpi = 900, width = 12, height = 8, limitsize = FALSE
13451346
)

R/R/refresh.R

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -450,10 +450,13 @@ robyn_refresh <- function(json_file = NULL,
450450
export = TRUE, quiet = TRUE, ...
451451
)
452452
plots <- refresh_plots_json(
453-
OutputCollectRF, json_file = attr(json_temp, "json_file"), export, ...)
453+
OutputCollectRF,
454+
json_file = attr(json_temp, "json_file"), export, ...
455+
)
454456
} else {
455457
plots <- try(refresh_plots(
456-
InputCollectRF, OutputCollectRF, ReportCollect, export, ...))
458+
InputCollectRF, OutputCollectRF, ReportCollect, export, ...
459+
))
457460
}
458461

459462
if (export) {

R/R/response.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -179,10 +179,10 @@ robyn_response <- function(InputCollect = NULL,
179179
if (usecase == "all_historical_vec") {
180180
ds_list <- check_metric_dates(date_range = "all", all_dates[1:endRW], dayInterval, quiet, ...)
181181
metric_value <- NULL
182-
#val_list <- check_metric_value(metric_value, metric_name, all_values, ds_list$metric_loc)
182+
# val_list <- check_metric_value(metric_value, metric_name, all_values, ds_list$metric_loc)
183183
} else if (usecase == "unit_metric_default_last_n") {
184184
ds_list <- check_metric_dates(date_range = paste0("last_", length(metric_value)), all_dates[1:endRW], dayInterval, quiet, ...)
185-
#val_list <- check_metric_value(metric_value, metric_name, all_values, ds_list$metric_loc)
185+
# val_list <- check_metric_value(metric_value, metric_name, all_values, ds_list$metric_loc)
186186
} else {
187187
ds_list <- check_metric_dates(date_range, all_dates[1:endRW], dayInterval, quiet, ...)
188188
}
@@ -351,7 +351,7 @@ which_usecase <- function(metric_value, date_range) {
351351
TRUE ~ "unit_metric_selected_dates"
352352
)
353353
if (!is.null(date_range)) {
354-
if (length(date_range) ==1 & date_range[1] == "all") {
354+
if (length(date_range) == 1 & date_range[1] == "all") {
355355
usecase <- "all_historical_vec"
356356
}
357357
}

R/R/transformation.R

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,8 @@ adstock_weibull <- function(x, shape, scale, windlen = length(x), type = "cdf")
162162
x_decayed = x_decayed,
163163
thetaVecCum = thetaVecCum,
164164
inflation_total = inflation_total,
165-
x_imme = x_imme)
166-
)
165+
x_imme = x_imme
166+
))
167167
}
168168

169169
#' @rdname adstocks
@@ -272,8 +272,10 @@ plot_adstock <- function(plot = TRUE) {
272272
dt_weibull <- data.frame(
273273
x = 1:100,
274274
decay_accumulated = adstock_weibull(
275-
1:100, shape = shapeVec[v1], scale = scaleVec[v2],
276-
type = tolower(types[t]))$thetaVecCum,
275+
1:100,
276+
shape = shapeVec[v1], scale = scaleVec[v2],
277+
type = tolower(types[t])
278+
)$thetaVecCum,
277279
shape = paste0("shape=", shapeVec[v1]),
278280
scale = as.factor(scaleVec[v2]),
279281
type = types[t]

0 commit comments

Comments
 (0)