Skip to content

Commit 5db34a1

Browse files
committed
Merge branch 'main' of github.com:mlr-org/mlr3mbo
2 parents 4d2dd5d + d49b117 commit 5db34a1

16 files changed

+112
-18
lines changed

DESCRIPTION

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
Type: Package
22
Package: mlr3mbo
33
Title: Flexible Bayesian Optimization
4-
Version: 0.2.3.9000
4+
Version: 0.2.4.9000
55
Authors@R: c(
66
person("Lennart", "Schneider", , "[email protected]", role = c("cre", "aut"),
77
comment = c(ORCID = "0000-0003-4152-5308")),
@@ -66,13 +66,14 @@ Suggests:
6666
rpart,
6767
stringi,
6868
testthat (>= 3.0.0)
69+
Remotes: mlr-org/bbotk
6970
ByteCompile: no
7071
Encoding: UTF-8
7172
Config/testthat/edition: 3
7273
Config/testthat/parallel: false
7374
NeedsCompilation: yes
7475
Roxygen: list(markdown = TRUE, r6 = TRUE)
75-
RoxygenNote: 7.3.1
76+
RoxygenNote: 7.3.2
7677
Collate:
7778
'mlr_acqfunctions.R'
7879
'AcqFunction.R'

NEWS.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# mlr3mbo (development version)
22

3+
# mlr3mbo 0.2.4
4+
5+
* fix: Improve runtime of `AcqOptimizer` by setting `check_values = FALSE`.
6+
37
# mlr3mbo 0.2.3
48

59
* compatibility: Work with new bbotk and mlr3tuning version 1.0.0.

R/AcqFunctionCB.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@ AcqFunctionCB = R6Class("AcqFunctionCB",
7676
constants = list(...)
7777
lambda = constants$lambda
7878
p = self$surrogate$predict(xdt)
79-
res = p$mean - self$surrogate_max_to_min * lambda * p$se
80-
data.table(acq_cb = res)
79+
cb = p$mean - self$surrogate_max_to_min * lambda * p$se
80+
data.table(acq_cb = cb)
8181
}
8282
)
8383
)

R/AcqFunctionEI.R

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,13 @@
99
#' @description
1010
#' Expected Improvement.
1111
#'
12+
#' @section Parameters:
13+
#' * `"epsilon"` (`numeric(1)`)\cr
14+
#' \eqn{\epsilon} value used to determine the amount of exploration.
15+
#' Higher values result in the importance of improvements predicted by the posterior mean
16+
#' decreasing relative to the importance of potential improvements in regions of high predictive uncertainty.
17+
#' Defaults to `0` (standard Expected Improvement).
18+
#'
1219
#' @references
1320
#' * `r format_bib("jones_1998")`
1421
#'
@@ -60,9 +67,15 @@ AcqFunctionEI = R6Class("AcqFunctionEI",
6067
#' Creates a new instance of this [R6][R6::R6Class] class.
6168
#'
6269
#' @param surrogate (`NULL` | [SurrogateLearner]).
63-
initialize = function(surrogate = NULL) {
70+
#' @param epsilon (`numeric(1)`).
71+
initialize = function(surrogate = NULL, epsilon = 0) {
6472
assert_r6(surrogate, "SurrogateLearner", null.ok = TRUE)
65-
super$initialize("acq_ei", surrogate = surrogate, requires_predict_type_se = TRUE, direction = "maximize", label = "Expected Improvement", man = "mlr3mbo::mlr_acqfunctions_ei")
73+
assert_number(epsilon, lower = 0, finite = TRUE)
74+
75+
constants = ps(epsilon = p_dbl(lower = 0, default = 0))
76+
constants$values$epsilon = epsilon
77+
78+
super$initialize("acq_ei", constants = constants, surrogate = surrogate, requires_predict_type_se = TRUE, direction = "maximize", label = "Expected Improvement", man = "mlr3mbo::mlr_acqfunctions_ei")
6679
},
6780

6881
#' @description
@@ -73,14 +86,16 @@ AcqFunctionEI = R6Class("AcqFunctionEI",
7386
),
7487

7588
private = list(
76-
.fun = function(xdt) {
89+
.fun = function(xdt, ...) {
7790
if (is.null(self$y_best)) {
7891
stop("$y_best is not set. Missed to call $update()?")
7992
}
93+
constants = list(...)
94+
epsilon = constants$epsilon
8095
p = self$surrogate$predict(xdt)
8196
mu = p$mean
8297
se = p$se
83-
d = self$y_best - self$surrogate_max_to_min * mu
98+
d = (self$y_best - self$surrogate_max_to_min * mu) - epsilon
8499
d_norm = d / se
85100
ei = d * pnorm(d_norm) + se * dnorm(d_norm)
86101
ei = ifelse(se < 1e-20, 0, ei)

R/AcqOptimizer.R

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,16 +93,21 @@ AcqOptimizer = R6Class("AcqOptimizer",
9393
#' @field acq_function ([AcqFunction]).
9494
acq_function = NULL,
9595

96+
#' @field callbacks (`NULL` | list of [mlr3misc::Callback]).
97+
callbacks = NULL,
98+
9699
#' @description
97100
#' Creates a new instance of this [R6][R6::R6Class] class.
98101
#'
99102
#' @param optimizer ([bbotk::Optimizer]).
100103
#' @param terminator ([bbotk::Terminator]).
101104
#' @param acq_function (`NULL` | [AcqFunction]).
102-
initialize = function(optimizer, terminator, acq_function = NULL) {
105+
#' @param callbacks (`NULL` | list of [mlr3misc::Callback])
106+
initialize = function(optimizer, terminator, acq_function = NULL, callbacks = NULL) {
103107
self$optimizer = assert_r6(optimizer, "Optimizer")
104108
self$terminator = assert_r6(terminator, "Terminator")
105109
self$acq_function = assert_r6(acq_function, "AcqFunction", null.ok = TRUE)
110+
self$callbacks = assert_callbacks(as_callbacks(callbacks))
106111
ps = ps(
107112
n_candidates = p_int(lower = 1, default = 1L),
108113
logging_level = p_fct(levels = c("fatal", "error", "warn", "info", "debug", "trace"), default = "warn"),
@@ -146,7 +151,7 @@ AcqOptimizer = R6Class("AcqOptimizer",
146151
logger$set_threshold(self$param_set$values$logging_level)
147152
on.exit(logger$set_threshold(old_threshold))
148153

149-
instance = OptimInstanceBatchSingleCrit$new(objective = self$acq_function, search_space = self$acq_function$domain, terminator = self$terminator, check_values = FALSE)
154+
instance = OptimInstanceBatchSingleCrit$new(objective = self$acq_function, search_space = self$acq_function$domain, terminator = self$terminator, check_values = FALSE, callbacks = self$callbacks)
150155

151156
# warmstart
152157
if (self$param_set$values$warmstart) {

R/sugar.R

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
#' @param cols_y (`NULL` | `character()`)\cr
2121
#' Column id(s) in the [bbotk::Archive] that should be used as a target.
2222
#' If a list of [mlr3::LearnerRegr] is provided as the `learner` argument and `cols_y` is
23-
#' specified as well, as many column names as learners must be provided.
23+
#' specified as well, as many column names as learners must be provided.
2424
#' Can also be `NULL` in which case this is automatically inferred based on the archive.
2525
#' @param ... (named `list()`)\cr
2626
#' Named arguments passed to the constructor, to be set as parameters in the
@@ -90,6 +90,8 @@ acqf = function(.key, ...) {
9090
#' @param acq_function (`NULL` | [AcqFunction])\cr
9191
#' [AcqFunction] that is to be used.
9292
#' Can also be `NULL`.
93+
#' @param callbacks (`NULL` | list of [mlr3misc::Callback])
94+
#' Callbacks used during acquisition function optimization.
9395
#' @param ... (named `list()`)\cr
9496
#' Named arguments passed to the constructor, to be set as parameters in the
9597
#' [paradox::ParamSet].
@@ -101,9 +103,9 @@ acqf = function(.key, ...) {
101103
#' library(bbotk)
102104
#' acqo(opt("random_search"), trm("evals"), catch_errors = FALSE)
103105
#' @export
104-
acqo = function(optimizer, terminator, acq_function = NULL, ...) {
106+
acqo = function(optimizer, terminator, acq_function = NULL, callbacks = NULL, ...) {
105107
dots = list(...)
106-
acqopt = AcqOptimizer$new(optimizer = optimizer, terminator = terminator, acq_function = acq_function)
108+
acqopt = AcqOptimizer$new(optimizer = optimizer, terminator = terminator, acq_function = acq_function, callbacks = callbacks)
107109
acqopt$param_set$values = insert_named(acqopt$param_set$values, dots)
108110
acqopt
109111
}

man/AcqOptimizer.Rd

Lines changed: 5 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/acqo.Rd

Lines changed: 4 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/mlr_acqfunctions_ei.Rd

Lines changed: 14 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pkgdown/_pkgdown.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ toc:
1616

1717
navbar:
1818
structure:
19-
left: [reference, news, book]
20-
right: [github, mattermost, stackoverflow, rss, lightswitch]
19+
left: [reference, intro, news, book]
20+
right: [search, github, mattermost, stackoverflow, rss, lightswitch]
2121
components:
2222
home: ~
2323
reference:

0 commit comments

Comments
 (0)