Skip to content

Commit 95091c9

Browse files
author
Gufeng Zhou
committed
refactor: adapted to dynamic convergence rules
1 parent 50f559b commit 95091c9

File tree

4 files changed

+60
-34
lines changed

4 files changed

+60
-34
lines changed

R/R/convergence.R

Lines changed: 40 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,29 @@
1111
#'
1212
#' @param OutputModels List. Output from \code{robyn_run()}
1313
#' @param n_cuts Integer. Default to 20 (5% cuts). Convergence is calculated
14-
#' on using first and last quantile cuts. Criteria 1: last quantile's sd
15-
#' < threshold_sd. Criteria 2: last quantile's median < first quantile's
16-
#' median - 2 * sd. Both have to happen to consider convergence.
17-
#' @param threshold_sd Numeric. Default to 0.025 that is empirically derived.
14+
#' using the first and last quantile cuts. By default, criteria 1: last
15+
#' quantile's sd < first 3 quantiles' mean sd. Criteria 2: last quantile's
16+
#' median < first quantile's median - 3 * first 3 quantiles' mean sd. Both
17+
#' have to be satisfied to consider convergence.
18+
#' @param sd_qtref Integer. Reference quantile of the error convergence rule
19+
#' for standard deviation. Defaults to 3. Error convergence rule for sd is
20+
#' defined by default as: last quantile's sd < first 3 quantiles' mean sd.
21+
#' @param med_lowb Integer. Lower bound distance of the error convergence rule
22+
#' for median. Defaults to 3. Error convergence rule for median is defined
23+
#' by default as: last quantile's median < first quantile's median - 3 * first 3
24+
#' quantiles' mean sd.
1825
#' @param ... Additional parameters
1926
#' @examples
2027
#' \dontrun{
2128
#' OutputModels <- robyn_converge(
2229
#' OutputModels = OutputModels,
23-
#' n_cuts = 10,
24-
#' threshold_sd = 0.025
30+
#' n_cuts = 20,
31+
#' sd_qtref = 3,
32+
#' med_lowb = 3
2533
#' )
2634
#' }
2735
#' @export
28-
robyn_converge <- function(OutputModels, n_cuts = 20, threshold_sd = 0.025, ...) {
36+
robyn_converge <- function(OutputModels, n_cuts = 20, sd_qtref = 3, med_lowb = 3, ...) {
2937

3038
# Gather all trials
3139
get_lists <- as.logical(grepl("trial", names(OutputModels)) * sapply(OutputModels, is.list))
@@ -54,8 +62,8 @@ robyn_converge <- function(OutputModels, n_cuts = 20, threshold_sd = 0.025, ...)
5462
))
5563

5664
# Calculate sd and median on each cut to alert user on:
57-
# 1) last quantile's sd < threshold_sd
58-
# 2) last quantile's median < first quantile's median - 2 * sd
65+
# 1) last quantile's sd < mean sd of default first 3 qt
66+
# 2) last quantile's median < median of first qt - default 3 * mean sd of default first 3 qt
5967
errors <- dt_objfunc_cvg %>%
6068
group_by(.data$error_type, .data$cuts) %>%
6169
summarise(
@@ -66,29 +74,37 @@ robyn_converge <- function(OutputModels, n_cuts = 20, threshold_sd = 0.025, ...)
6674
) %>%
6775
group_by(.data$error_type) %>%
6876
mutate(
69-
med_var_P = abs(round(100 * (.data$median - lag(.data$median)) / .data$median, 2)),
70-
flag_sd = .data$std > threshold_sd
77+
med_var_P = abs(round(100 * (.data$median - lag(.data$median)) / .data$median, 2))
7178
) %>%
7279
group_by(.data$error_type) %>%
73-
mutate(flag_med = dplyr::last(.data$median[1]) < dplyr::first(.data$median[2]) - 2 * dplyr::first(.data$std))
80+
mutate(first_med = dplyr::first(.data$median),
81+
first_med_avg = mean(.data$median[1:sd_qtref]),
82+
last_med = dplyr::last(.data$median),
83+
first_sd = dplyr::first(.data$std),
84+
first_sd_avg = mean(.data$std[1:sd_qtref]),
85+
last_sd = dplyr::last(.data$std)) %>%
86+
mutate(med_thres = .data$first_med - med_lowb * .data$first_sd_avg,
87+
flag_med = .data$median < .data$first_med - med_lowb * .data$first_sd_avg,
88+
flag_sd = .data$std < .data$first_sd_avg)
7489

7590
conv_msg <- NULL
7691
for (obj_fun in unique(errors$error_type)) {
7792
temp.df <- filter(errors, .data$error_type == obj_fun) %>%
7893
mutate(median = signif(median, 2))
7994
last.qt <- tail(temp.df, 1)
8095
temp <- glued(paste(
81-
"{error_type} {did}converged: sd {sd} @qt.{quantile} {symb_sd} {sd_threh} &",
82-
"med {qtn_median} @qt.{quantile} {symb_med} {med_threh} med@qt.1-2*sd"),
96+
"{error_type} {did}converged: sd@qt.{quantile} {sd} {symb_sd} {sd_threh} &",
97+
"med@qt.{quantile} {qtn_median} {symb_med} {med_threh} med@qt.1-{med_lowb}*sd"),
8398
error_type = last.qt$error_type,
84-
did = ifelse(last.qt$flag_sd | last.qt$flag_med, "NOT ", ""),
85-
sd = signif(last.qt$std, 1),
86-
symb_sd = ifelse(last.qt$flag_sd, ">", "<="),
87-
sd_threh = threshold_sd,
88-
quantile = round(100/n_cuts),
89-
qtn_median = temp.df$median[n_cuts],
90-
symb_med = ifelse(last.qt$flag_med, ">", "<="),
91-
med_threh = signif(temp.df$median[1] - 2 * temp.df$std[1], 2)
99+
did = ifelse(last.qt$flag_sd & last.qt$flag_med, "", "NOT "),
100+
sd = signif(last.qt$last_sd, 2),
101+
symb_sd = ifelse(last.qt$flag_sd, "<", ">="),
102+
sd_threh = signif(last.qt$first_sd_avg, 2),
103+
quantile = n_cuts,
104+
qtn_median = signif(last.qt$last_med, 2),
105+
symb_med = ifelse(last.qt$flag_med, "<", ">="),
106+
med_threh = signif(last.qt$med_thres, 2),
107+
med_lowb = med_lowb
92108
)
93109
conv_msg <- c(conv_msg, temp)
94110
}
@@ -162,7 +178,8 @@ robyn_converge <- function(OutputModels, n_cuts = 20, threshold_sd = 0.025, ...)
162178
errors = errors,
163179
conv_msg = conv_msg
164180
)
165-
attr(cvg_out, "threshold_sd") <- threshold_sd
181+
attr(cvg_out, "sd_qtref") <- sd_qtref
182+
attr(cvg_out, "med_lowb") <- med_lowb
166183

167184
return(invisible(cvg_out))
168185
}

R/R/model.R

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,6 @@ robyn_run <- function(InputCollect,
143143
#' @export
144144
print.robyn_models <- function(x, ...) {
145145
is_fixed <- all(lapply(x$hyper_updated, length) == 1)
146-
threshold_sd <- attr(x$convergence, "threshold_sd")
147146
print(glued(
148147
"
149148
Total trials: {x$trials}

R/man/robyn_converge.Rd

Lines changed: 16 additions & 7 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

demo/demo.R

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ hyperparameters <- list(
181181

182182
,ooh_S_alphas = c(0.5, 3)
183183
,ooh_S_gammas = c(0.3, 1)
184-
,ooh_S_thetas = c(0.3) # (0.1, 0.4)
184+
,ooh_S_thetas = c(0.1, 0.4)
185185

186186
,newsletter_alphas = c(0.5, 3)
187187
,newsletter_gammas = c(0.3, 1)
@@ -276,9 +276,10 @@ OutputModels <- robyn_run(
276276
)
277277
print(OutputModels)
278278

279-
## Check MOO (multi-objective optimisation) convergence
279+
## Check MOO (multi-objective optimisation) convergence plots
280280
OutputModels$convergence$moo_distrb_plot
281281
OutputModels$convergence$moo_cloud_plot
282+
# Check convergence rules. See ?robyn_converge
282283

283284
## Calculate Pareto optimality, cluster and export results and plots. See ?robyn_outputs
284285
OutputCollect <- robyn_outputs(
@@ -306,7 +307,7 @@ print(OutputCollect)
306307
# , plot_pareto = TRUE
307308
# , plot_folder = robyn_object
308309
# )
309-
# convergence <- robyn_converge(OutputModels, n_cuts = 20, threshold_sd = 0.025)
310+
# convergence <- robyn_converge(OutputModels)
310311
# convergence$moo_distrb_plot
311312
# convergence$moo_cloud_plot
312313
# print(OutputCollect)

0 commit comments

Comments
 (0)