|
#' @title Acquisition Function Expected Improvement on Log Scale
#'
#' @include AcqFunction.R
#' @name mlr_acqfunctions_ei_log
#'
#' @templateVar id ei_log
#' @template section_dictionary_acqfunctions
#'
#' @description
#' Expected Improvement under the assumption that the target variable has been modeled on log scale.
#' This is generally only sensible when the [SurrogateLearner] uses an [OutputTrafoLog] that does not
#' invert the posterior predictive distribution (`invert_posterior = FALSE`).
#' The example below shows such a setup.
#'
#' @section Parameters:
#' * `"epsilon"` (`numeric(1)`)\cr
#'   \eqn{\epsilon} value used to determine the amount of exploration.
#'   Higher values decrease the importance of improvements predicted by the posterior mean
#'   relative to the importance of potential improvements in regions of high predictive uncertainty.
#'   Defaults to `0` (standard Expected Improvement).
#'
#' @family Acquisition Function
#' @export
#' @examples
#' if (requireNamespace("mlr3learners") &&
#'     requireNamespace("DiceKriging") &&
#'     requireNamespace("rgenoud")) {
#'   library(bbotk)
#'   library(paradox)
#'   library(mlr3learners)
#'   library(data.table)
#'
#'   fun = function(xs) {
#'     list(y = xs$x ^ 2)
#'   }
#'   domain = ps(x = p_dbl(lower = -10, upper = 10))
#'   codomain = ps(y = p_dbl(tags = "minimize"))
#'   objective = ObjectiveRFun$new(fun = fun, domain = domain, codomain = codomain)
#'
#'   instance = OptimInstanceBatchSingleCrit$new(
#'     objective = objective,
#'     terminator = trm("evals", n_evals = 5))
#'
#'   instance$eval_batch(data.table(x = c(-6, -5, 3, 9)))
#'
#'   learner = default_gp()
#'
#'   output_trafo = ot("log", invert_posterior = FALSE)
#'
#'   surrogate = srlrn(learner, output_trafo = output_trafo, archive = instance$archive)
#'
#'   acq_function = acqf("ei_log", surrogate = surrogate)
#'
#'   acq_function$surrogate$update()
#'   acq_function$update()
#'   acq_function$eval_dt(data.table(x = c(-1, 0, 1)))
#' }
AcqFunctionEILog = R6Class("AcqFunctionEILog",
  inherit = AcqFunction,

  public = list(

    #' @field y_best (`numeric(1)`)\cr
    #'   Best objective function value observed so far.
    #'   In the case of maximization, this already includes the necessary change of sign.
    y_best = NULL,

    #' @description
    #' Creates a new instance of this [R6][R6::R6Class] class.
    #'
    #' @param surrogate (`NULL` | [SurrogateLearner]).
    #' @param epsilon (`numeric(1)`).
    initialize = function(surrogate = NULL, epsilon = 0) {
      assert_r6(surrogate, "SurrogateLearner", null.ok = TRUE)
      assert_number(epsilon, lower = 0, finite = TRUE)

      # epsilon is exposed as a constant of the acquisition function and
      # initialised with the value passed to the constructor
      constants = ps(epsilon = p_dbl(lower = 0, default = 0))
      constants$values$epsilon = epsilon

      super$initialize(
        "acq_ei_log",
        constants = constants,
        surrogate = surrogate,
        requires_predict_type_se = TRUE,
        direction = "maximize",
        label = "Expected Improvement on Log Scale",
        man = "mlr3mbo::mlr_acqfunctions_ei_log"
      )
    },

    #' @description
    #' Update the acquisition function and set `y_best`.
    update = function() {
      # the surrogate must carry a non-inverting log output transformation
      assert_r6(self$surrogate$output_trafo, "OutputTrafoLog")
      assert_false(self$surrogate$output_trafo$invert_posterior)

      observed = self$archive$data[, self$surrogate$cols_y, with = FALSE]
      if (self$surrogate$output_trafo_must_be_considered) {
        # bring the archived target values onto the scale of the output trafo
        observed = self$surrogate$output_trafo$transform(observed)
      }

      # sign-adjust via surrogate_max_to_min and keep the smallest value
      self$y_best = min(self$surrogate_max_to_min * observed)
    }
  ),

  private = list(
    # Evaluates the acquisition function for the points in `xdt`.
    # Constants (here: `epsilon`) arrive through `...`.
    # Returns a data.table with the single column `acq_ei_log`.
    .fun = function(xdt, ...) {
      if (is.null(self$y_best)) {
        stop("$y_best is not set. Missed to call $update()?")
      }
      assert_r6(self$surrogate$output_trafo, "OutputTrafoLog")
      assert_false(self$surrogate$output_trafo$invert_posterior)

      epsilon = list(...)$epsilon
      prediction = self$surrogate$predict(xdt)
      mu = prediction$mean
      se = prediction$se

      # FIXME: try to unify w.r.t minimization / maximization and the respective transformation
      if (self$surrogate_max_to_min == 1L) {
        # y is to be minimized and the OutputTrafoLog performed the transformation accordingly
        assert_true(self$surrogate$output_trafo$max_to_min == 1L)
        incumbent = self$y_best
        z = ((incumbent - mu) - epsilon) / se
        gain = exp(incumbent) * pnorm(z)
        cost = exp((0.5 * se^2) + mu) * pnorm(z - se)
      } else {
        # y is to be maximized and the OutputTrafoLog performed the transformation accordingly
        incumbent = -self$y_best
        z = ((mu - incumbent) - epsilon) / se
        gain = exp(-incumbent) * pnorm(z)
        cost = exp((0.5 * se^2) - mu) * pnorm(z - se)
      }

      # both directions are multiplied by the same factor: the difference between
      # the `max` and `min` entries of the output trafo state for the target column
      trafo = self$surrogate$output_trafo
      target_state = trafo$state[[trafo$cols_y]]
      multiplicative_factor = target_state$max - target_state$min
      ei_log = multiplicative_factor * (gain - cost)

      # a (numerically) zero standard error or an undefined value yields an acquisition value of 0
      ei_log = ifelse(se < 1e-20 | is.na(ei_log), 0, ei_log)
      data.table(acq_ei_log = ei_log)
    }
  )
)

# make the class available in the acquisition function dictionary under the key "ei_log"
mlr_acqfunctions$add("ei_log", AcqFunctionEILog)
| 132 | + |
0 commit comments