Skip to content

Commit fdfff59

Browse files
authored
feat: Update Surrogate a final time after optimization process finished (#132)
1 parent e3383c7 commit fdfff59

File tree

3 files changed

+46
-1
lines changed

3 files changed

+46
-1
lines changed

R/OptimizerMbo.R

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,13 @@
1515
#'
1616
#' Termination is handled via a [bbotk::Terminator] part of the [bbotk::OptimInstance] to be optimized.
1717
#'
18+
#' Note that in general the [Surrogate] is updated one final time on all available data after the optimization process has terminated.
19+
#' However, in certain scenarios this is not always possible or meaningful, e.g., when using [bayesopt_parego()] for multi-objective optimization
20+
#' which uses a surrogate that relies on a scalarization of the objectives.
21+
#' It is therefore recommended to manually inspect the [Surrogate] after optimization if it is to be used, e.g., for visualization purposes to make
22+
#' sure that it has been properly updated on all available data.
23+
#' If this final update of the [Surrogate] could not be performed successfully, a warning will be logged.
24+
#'
1825
#' @section Archive:
1926
#' The [bbotk::Archive] holds the following additional columns that are specific to MBO algorithms:
2027
#' * `[acq_function$id]` (`numeric(1)`)\cr
@@ -274,6 +281,7 @@ OptimizerMbo = R6Class("OptimizerMbo",
274281

275282
.optimize = function(inst) {
276283
# FIXME: this needs more checks for edge cases like eips or loop_function bayesopt_parego then default_surrogate should use one learner
284+
277285
if (is.null(self$loop_function)) {
278286
self$loop_function = default_loop_function(inst)
279287
}
@@ -303,6 +311,17 @@ OptimizerMbo = R6Class("OptimizerMbo",
303311
check_packages_installed(self$packages, msg = sprintf("Package '%%s' required but not installed for Optimizer '%s'", format(self)))
304312

305313
invoke(self$loop_function, instance = inst, surrogate = self$surrogate, acq_function = self$acq_function, acq_optimizer = self$acq_optimizer, .args = self$args)
314+
315+
on.exit({
316+
tryCatch(
317+
{
318+
self$surrogate$update()
319+
}, surrogate_update_error = function(error_condition) {
320+
logger = lgr::get_logger("bbotk")
321+
logger$warn("Could not update the surrogate a final time after the optimization process has terminated.")
322+
}
323+
)
324+
})
306325
},
307326

308327
.assign_result = function(inst) {

man/mlr_optimizers_mbo.Rd

Lines changed: 7 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_OptimizerMbo.R

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ test_that("OptimizerMbo works", {
33
skip_if_not_installed("mlr3learners")
44
skip_if_not_installed("DiceKriging")
55
skip_if_not_installed("rgenoud")
6-
6+
77
optimizer = OptimizerMbo$new()
88
expect_r6(optimizer, classes = "OptimizerMbo")
99

@@ -199,3 +199,22 @@ test_that("OptimizerMbo reset", {
199199
expect_r6(optimizer$acq_optimizer, "AcqOptimizer")
200200
})
201201

202+
test_that("OptimizerMbo up to date surrogate after optimization", {
203+
skip_if_not_installed("mlr3learners")
204+
skip_if_not_installed("DiceKriging")
205+
206+
surrogate = srlrn(lrn("regr.km", covtype = "matern5_2", optim.method = "gen", control = list(trace = FALSE), nugget.stability = 10^-8))
207+
acq_optimizer = acqo(opt("random_search", batch_size = 2L), terminator = trm("evals", n_evals = 2L))
208+
optimizer = opt("mbo", surrogate = surrogate, acq_optimizer = acq_optimizer)
209+
instance = MAKE_INST_1D(terminator = trm("evals", n_evals = 5L))
210+
optimizer$optimize(instance)
211+
212+
expect_equal(surrogate, optimizer$surrogate)
213+
214+
expect_true(surrogate$learner$state$train_task$nrow == 5L)
215+
216+
predictions = surrogate$predict(instance$archive$data[, instance$archive$cols_x, with = FALSE])
217+
expect_true(all(sqrt((predictions$mean - instance$archive$data$y) ^ 2) < 1e-4))
218+
expect_true(all(predictions$se < 1e-4))
219+
})
220+

0 commit comments

Comments
 (0)