Skip to content

Commit d49b117

Browse files
authored
feat: support callbacks in AcqOptimizer (#153)
* feat: support callbacks in AcqOptimizer
1 parent 2d99077 commit d49b117

File tree

6 files changed

+51
-7
lines changed

6 files changed

+51
-7
lines changed

DESCRIPTION

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ Suggests:
6666
rpart,
6767
stringi,
6868
testthat (>= 3.0.0)
69+
Remotes: mlr-org/bbotk
6970
ByteCompile: no
7071
Encoding: UTF-8
7172
Config/testthat/edition: 3

R/AcqOptimizer.R

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,16 +93,21 @@ AcqOptimizer = R6Class("AcqOptimizer",
9393
#' @field acq_function ([AcqFunction]).
9494
acq_function = NULL,
9595

96+
#' @field callbacks (`NULL` | list of [mlr3misc::Callback]).
97+
callbacks = NULL,
98+
9699
#' @description
97100
#' Creates a new instance of this [R6][R6::R6Class] class.
98101
#'
99102
#' @param optimizer ([bbotk::Optimizer]).
100103
#' @param terminator ([bbotk::Terminator]).
101104
#' @param acq_function (`NULL` | [AcqFunction]).
102-
initialize = function(optimizer, terminator, acq_function = NULL) {
105+
#' @param callbacks (`NULL` | list of [mlr3misc::Callback])
106+
initialize = function(optimizer, terminator, acq_function = NULL, callbacks = NULL) {
103107
self$optimizer = assert_r6(optimizer, "Optimizer")
104108
self$terminator = assert_r6(terminator, "Terminator")
105109
self$acq_function = assert_r6(acq_function, "AcqFunction", null.ok = TRUE)
110+
self$callbacks = assert_callbacks(as_callbacks(callbacks))
106111
ps = ps(
107112
n_candidates = p_int(lower = 1, default = 1L),
108113
logging_level = p_fct(levels = c("fatal", "error", "warn", "info", "debug", "trace"), default = "warn"),
@@ -146,7 +151,7 @@ AcqOptimizer = R6Class("AcqOptimizer",
146151
logger$set_threshold(self$param_set$values$logging_level)
147152
on.exit(logger$set_threshold(old_threshold))
148153

149-
instance = OptimInstanceBatchSingleCrit$new(objective = self$acq_function, search_space = self$acq_function$domain, terminator = self$terminator, check_values = FALSE)
154+
instance = OptimInstanceBatchSingleCrit$new(objective = self$acq_function, search_space = self$acq_function$domain, terminator = self$terminator, check_values = FALSE, callbacks = self$callbacks)
150155

151156
# warmstart
152157
if (self$param_set$values$warmstart) {

R/sugar.R

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
#' @param cols_y (`NULL` | `character()`)\cr
2121
#' Column id(s) in the [bbotk::Archive] that should be used as a target.
2222
#' If a list of [mlr3::LearnerRegr] is provided as the `learner` argument and `cols_y` is
23-
#' specified as well, as many column names as learners must be provided.
23+
#' specified as well, as many column names as learners must be provided.
2424
#' Can also be `NULL` in which case this is automatically inferred based on the archive.
2525
#' @param ... (named `list()`)\cr
2626
#' Named arguments passed to the constructor, to be set as parameters in the
@@ -90,6 +90,8 @@ acqf = function(.key, ...) {
9090
#' @param acq_function (`NULL` | [AcqFunction])\cr
9191
#' [AcqFunction] that is to be used.
9292
#' Can also be `NULL`.
93+
#' @param callbacks (`NULL` | list of [mlr3misc::Callback])
94+
#' Callbacks used during acquisition function optimization.
9395
#' @param ... (named `list()`)\cr
9496
#' Named arguments passed to the constructor, to be set as parameters in the
9597
#' [paradox::ParamSet].
@@ -101,9 +103,9 @@ acqf = function(.key, ...) {
101103
#' library(bbotk)
102104
#' acqo(opt("random_search"), trm("evals"), catch_errors = FALSE)
103105
#' @export
104-
acqo = function(optimizer, terminator, acq_function = NULL, ...) {
106+
acqo = function(optimizer, terminator, acq_function = NULL, callbacks = NULL, ...) {
105107
dots = list(...)
106-
acqopt = AcqOptimizer$new(optimizer = optimizer, terminator = terminator, acq_function = acq_function)
108+
acqopt = AcqOptimizer$new(optimizer = optimizer, terminator = terminator, acq_function = acq_function, callbacks = callbacks)
107109
acqopt$param_set$values = insert_named(acqopt$param_set$values, dots)
108110
acqopt
109111
}

man/AcqOptimizer.Rd

Lines changed: 5 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/acqo.Rd

Lines changed: 4 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_AcqOptimizer.R

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,3 +120,32 @@ test_that("AcqOptimizer deep clone", {
120120
expect_true(address(acqopt1$terminator) != address(acqopt2$terminator))
121121
})
122122

123+
test_that("AcqOptimizer callbacks", {
124+
domain = ps(x = p_dbl(lower = 10, upper = 20, trafo = function(x) x - 15))
125+
objective = ObjectiveRFunDt$new(
126+
fun = function(xdt) data.table(y = xdt$x ^ 2),
127+
domain = domain,
128+
codomain = ps(y = p_dbl(tags = "minimize")),
129+
check_values = FALSE
130+
)
131+
instance = MAKE_INST(objective = objective, search_space = domain, terminator = trm("evals", n_evals = 5L))
132+
design = MAKE_DESIGN(instance)
133+
instance$eval_batch(design)
134+
callback = callback_batch("mlr3mbo.acqopt_time",
135+
on_optimization_begin = function(callback, context) {
136+
callback$state$begin = Sys.time()
137+
},
138+
on_optimization_end = function(callback, context) {
139+
callback$state$end = Sys.time()
140+
attr(callback$state$outer_instance, "acq_opt_runtime") = as.numeric(callback$state$end - callback$state$begin)
141+
}
142+
)
143+
callback$state$outer_instance = instance
144+
acqfun = AcqFunctionEI$new(SurrogateLearner$new(REGR_FEATURELESS, archive = instance$archive))
145+
acqopt = AcqOptimizer$new(opt("random_search", batch_size = 10L), trm("evals", n_evals = 10L), acq_function = acqfun, callbacks = callback)
146+
acqfun$surrogate$update()
147+
acqfun$update()
148+
res = acqopt$optimize()
149+
expect_number(attr(instance, "acq_opt_runtime"))
150+
})
151+

0 commit comments

Comments
 (0)