Skip to content

Commit 6e40335

Browse files
author
Gufeng Zhou
committed
fixed loading old model & refresh bug
1 parent d12dfec commit 6e40335

File tree

3 files changed

+33
-15
lines changed

3 files changed

+33
-15
lines changed

R/R/model.R

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,12 @@ robyn_run <- function(InputCollect,
185185
message(">>> Collecting results...")
186186

187187
## collect hyperparameter results
188-
names(model_output_collect) <- paste0("trial", 1:InputCollect$trials)
188+
if (hyper_fixed) {
189+
names(model_output_collect) <- "trial1"
190+
} else {
191+
names(model_output_collect) <- paste0("trial", 1:InputCollect$trials)
192+
}
193+
189194
resultHypParam <- rbindlist(lapply(model_output_collect, function(x) x$resultCollect$resultHypParam[, trial := x$trial]))
190195
resultHypParam[, iterations := (iterNG - 1) * InputCollect$cores + iterPar]
191196
xDecompAgg <- rbindlist(lapply(model_output_collect, function(x) x$resultCollect$xDecompAgg[, trial := x$trial]))
@@ -1063,14 +1068,14 @@ robyn_mmm <- function(hyper_collect,
10631068
# assign("InputCollect", InputCollect, envir = .GlobalEnv) # adding this to enable InputCollect reading during parallel
10641069
# opts <- list(progress = function(n) setTxtProgressBar(pb, n))
10651070
sysTimeDopar <- system.time({
1066-
for (lng in 1:iterNG) {
1071+
for (lng in 1:iterNG) { # lng = 1
10671072
nevergrad_hp <- list()
10681073
nevergrad_hp_val <- list()
10691074
hypParamSamList <- list()
10701075
hypParamSamNG <- c()
10711076

10721077
if (hyper_fixed == FALSE) {
1073-
for (co in 1:iterPar) {
1078+
for (co in 1:iterPar) { # co = 1
10741079

10751080
## get hyperparameter sample with ask
10761081
nevergrad_hp[[co]] <- optimizer$ask()
@@ -1119,7 +1124,7 @@ robyn_mmm <- function(hyper_collect,
11191124

11201125
getDoParWorkers()
11211126
doparCollect <- suppressPackageStartupMessages(
1122-
foreach(i = 1:iterPar) %dorng% {
1127+
foreach(i = 1:iterPar) %dorng% { # i = 1
11231128
t1 <- Sys.time()
11241129

11251130
#####################################

R/R/refresh.R

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ robyn_refresh <- function(robyn_object,
300300
OutputCollectRF$selectID <- selectID
301301
message(
302302
"Selected model ID: ", selectID, " for refresh model nr.",
303-
refreshCounter, " based on the smallest combined error of NRMSE & DECOMP.RSSD"
303+
refreshCounter, " based on the smallest combined error of NRMSE & DECOMP.RSSD\n"
304304
)
305305

306306
OutputCollectRF$resultHypParam[, bestModRF := solID == selectID]
@@ -334,37 +334,39 @@ robyn_refresh <- function(robyn_object,
334334
OutputCollectRF$mediaVecCollect[
335335
bestModRF == TRUE & ds >= InputCollectRF$refreshAddedStart &
336336
ds <= refreshEnd
337-
][, refreshStatus := refreshCounter]
337+
][, ':='(refreshStatus = refreshCounter, ds = as.IDate(ds))]
338338
)
339339
mediaVecReport <- mediaVecReport[order(type, ds, refreshStatus)]
340340
xDecompVecReport <- rbind(
341341
listOutputPrev$xDecompVecCollect[bestModRF == TRUE],
342342
OutputCollectRF$xDecompVecCollect[
343343
bestModRF == TRUE & ds >= InputCollectRF$refreshAddedStart &
344344
ds <= refreshEnd
345-
][, refreshStatus := refreshCounter]
345+
][, ':='(refreshStatus = refreshCounter, ds = as.IDate(ds))]
346346
)
347347
} else {
348-
resultHypParamReport <- rbind(listReportPrev$resultHypParamReport, OutputCollectRF$resultHypParam[
349-
bestModRF == TRUE
350-
][, refreshStatus := refreshCounter])
351-
xDecompAggReport <- rbind(listReportPrev$xDecompAggReport, OutputCollectRF$xDecompAgg[
352-
bestModRF == TRUE
353-
][, refreshStatus := refreshCounter])
348+
resultHypParamReport <- rbind(
349+
listReportPrev$resultHypParamReport,
350+
OutputCollectRF$resultHypParam[bestModRF == TRUE][
351+
, refreshStatus := refreshCounter])
352+
xDecompAggReport <- rbind(
353+
listReportPrev$xDecompAggReport,
354+
OutputCollectRF$xDecompAgg[bestModRF == TRUE][
355+
, refreshStatus := refreshCounter])
354356
mediaVecReport <- rbind(
355357
listReportPrev$mediaVecReport,
356358
OutputCollectRF$mediaVecCollect[
357359
bestModRF == TRUE & ds >= InputCollectRF$refreshAddedStart &
358360
ds <= refreshEnd
359-
][, refreshStatus := refreshCounter]
361+
][, ':='(refreshStatus = refreshCounter, ds = as.IDate(ds))]
360362
)
361363
mediaVecReport <- mediaVecReport[order(type, ds, refreshStatus)]
362364
xDecompVecReport <- rbind(
363365
listReportPrev$xDecompVecReport,
364366
OutputCollectRF$xDecompVecCollect[
365367
bestModRF == TRUE & ds >= InputCollectRF$refreshAddedStart &
366368
ds <= refreshEnd
367-
][, refreshStatus := refreshCounter]
369+
][, ':='(refreshStatus = refreshCounter, ds = as.IDate(ds))]
368370
)
369371
}
370372

demo/debug.R

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ lambda.n = 100
88
lambda_control = 1
99
lambda_fixed = NULL
1010
refresh = FALSE
11+
seed = 123L
1112
# go into robyn_mmm() line by line
1213

1314
## debug robyn_run
@@ -24,6 +25,16 @@ csv_out = "pareto"
2425
seed = 123
2526
# go into robyn_run() line by line
2627

28+
## debug robyn_refresh
29+
# robyn_object
30+
dt_input = dt_input
31+
dt_holidays = dt_holidays
32+
refresh_steps = 14
33+
refresh_mode = "auto" # "auto", "manual"
34+
refresh_iters = 100
35+
refresh_trials = 2
36+
plot_pareto = TRUE
37+
2738
## debug robyn_allocator
2839
# prep input para
2940

0 commit comments

Comments
 (0)