Skip to content

Commit 2d99077

Browse files
authored
feat: allow EI to be adjusted by epsilon to strengthen exploration (#154)
1 parent b45f868 commit 2d99077

File tree

8 files changed

+48
-8
lines changed

8 files changed

+48
-8
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ Config/testthat/edition: 3
7272
Config/testthat/parallel: false
7373
NeedsCompilation: yes
7474
Roxygen: list(markdown = TRUE, r6 = TRUE)
75-
RoxygenNote: 7.3.1
75+
RoxygenNote: 7.3.2
7676
Collate:
7777
'mlr_acqfunctions.R'
7878
'AcqFunction.R'

R/AcqFunctionCB.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@ AcqFunctionCB = R6Class("AcqFunctionCB",
7676
constants = list(...)
7777
lambda = constants$lambda
7878
p = self$surrogate$predict(xdt)
79-
res = p$mean - self$surrogate_max_to_min * lambda * p$se
80-
data.table(acq_cb = res)
79+
cb = p$mean - self$surrogate_max_to_min * lambda * p$se
80+
data.table(acq_cb = cb)
8181
}
8282
)
8383
)

R/AcqFunctionEI.R

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,13 @@
99
#' @description
1010
#' Expected Improvement.
1111
#'
12+
#' @section Parameters:
13+
#' * `"epsilon"` (`numeric(1)`)\cr
14+
#' \eqn{\epsilon} value used to determine the amount of exploration.
15+
#' Higher values result in the importance of improvements predicted by the posterior mean
16+
#' decreasing relative to the importance of potential improvements in regions of high predictive uncertainty.
17+
#' Defaults to `0` (standard Expected Improvement).
18+
#'
1219
#' @references
1320
#' * `r format_bib("jones_1998")`
1421
#'
@@ -60,9 +67,15 @@ AcqFunctionEI = R6Class("AcqFunctionEI",
6067
#' Creates a new instance of this [R6][R6::R6Class] class.
6168
#'
6269
#' @param surrogate (`NULL` | [SurrogateLearner]).
63-
initialize = function(surrogate = NULL) {
70+
#' @param epsilon (`numeric(1)`).
71+
initialize = function(surrogate = NULL, epsilon = 0) {
6472
assert_r6(surrogate, "SurrogateLearner", null.ok = TRUE)
65-
super$initialize("acq_ei", surrogate = surrogate, requires_predict_type_se = TRUE, direction = "maximize", label = "Expected Improvement", man = "mlr3mbo::mlr_acqfunctions_ei")
73+
assert_number(epsilon, lower = 0, finite = TRUE)
74+
75+
constants = ps(epsilon = p_dbl(lower = 0, default = 0))
76+
constants$values$epsilon = epsilon
77+
78+
super$initialize("acq_ei", constants = constants, surrogate = surrogate, requires_predict_type_se = TRUE, direction = "maximize", label = "Expected Improvement", man = "mlr3mbo::mlr_acqfunctions_ei")
6679
},
6780

6881
#' @description
@@ -73,14 +86,16 @@ AcqFunctionEI = R6Class("AcqFunctionEI",
7386
),
7487

7588
private = list(
76-
.fun = function(xdt) {
89+
.fun = function(xdt, ...) {
7790
if (is.null(self$y_best)) {
7891
stop("$y_best is not set. Missed to call $update()?")
7992
}
93+
constants = list(...)
94+
epsilon = constants$epsilon
8095
p = self$surrogate$predict(xdt)
8196
mu = p$mean
8297
se = p$se
83-
d = self$y_best - self$surrogate_max_to_min * mu
98+
d = (self$y_best - self$surrogate_max_to_min * mu) - epsilon
8499
d_norm = d / se
85100
ei = d * pnorm(d_norm) + se * dnorm(d_norm)
86101
ei = ifelse(se < 1e-20, 0, ei)

man/mlr_acqfunctions_ei.Rd

Lines changed: 14 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_AcqFunctionCB.R

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ test_that("AcqFunctionCB works", {
1212
expect_learner(acqf$surrogate$learner)
1313
expect_true(acqf$requires_predict_type_se)
1414

15+
expect_r6(acqf$constants, "ParamSet")
16+
expect_equal(acqf$constants$ids(), "lambda")
17+
1518
design = MAKE_DESIGN(inst)
1619
inst$eval_batch(design)
1720

tests/testthat/test_AcqFunctionEHVIGH.R

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ test_that("AcqFunctionEHVIGH works", {
1515
expect_true(acqf$requires_predict_type_se)
1616
expect_setequal(acqf$packages, c("emoa", "fastGHQuad"))
1717

18+
expect_r6(acqf$constants, "ParamSet")
19+
expect_equal(acqf$constants$ids(), c("k", "r"))
20+
1821
design = MAKE_DESIGN(inst)
1922
inst$eval_batch(design)
2023

tests/testthat/test_AcqFunctionEI.R

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ test_that("AcqFunctionEI works", {
1313
expect_learner(acqf$surrogate$learner)
1414
expect_true(acqf$requires_predict_type_se)
1515

16+
expect_r6(acqf$constants, "ParamSet")
17+
expect_equal(acqf$constants$ids(), "epsilon")
18+
1619
design = MAKE_DESIGN(inst)
1720
inst$eval_batch(design)
1821

tests/testthat/test_AcqFunctionSmsEgo.R

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ test_that("AcqFunctionSmsEgo works", {
1212
expect_list(acqf$surrogate$learner, types = "Learner")
1313
expect_true(acqf$requires_predict_type_se)
1414

15+
expect_r6(acqf$constants, "ParamSet")
16+
expect_equal(acqf$constants$ids(), c("lambda", "epsilon"))
17+
1518
design = MAKE_DESIGN(inst)
1619
inst$eval_batch(design)
1720

0 commit comments

Comments
 (0)