|
#' @title Acquisition Function Wrapping Multiple Acquisition Functions
#'
#' @include AcqFunction.R
#' @name mlr_acqfunctions_multi
#'
#' @templateVar id multi
#' @template section_dictionary_acqfunctions
#'
#' @description
#' Wrapping multiple [AcqFunction]s resulting in a multi-objective acquisition function composed of the individual ones.
#' Note that the optimization direction of each wrapped acquisition function is corrected for maximization.
#'
#' For each acquisition function, the same [Surrogate] must be used.
#' If acquisition functions passed during construction already have been initialized with a surrogate, it is checked whether
#' the surrogate is the same for all acquisition functions.
#' If acquisition functions have not been initialized with a surrogate, the surrogate passed during construction or lazy initialization
#' will be used for all acquisition functions.
#'
#' For optimization, [AcqOptimizer] can be used as for any other [AcqFunction], however, the [bbotk::Optimizer] wrapped within the [AcqOptimizer]
#' must support multi-objective optimization as indicated via the `multi-crit` property.
#'
#' @family Acquisition Function
#' @export
#' @examples
#' if (requireNamespace("mlr3learners") &&
#'     requireNamespace("DiceKriging") &&
#'     requireNamespace("rgenoud")) {
#'   library(bbotk)
#'   library(paradox)
#'   library(mlr3learners)
#'   library(data.table)
#'
#'   fun = function(xs) {
#'     list(y = xs$x ^ 2)
#'   }
#'   domain = ps(x = p_dbl(lower = -10, upper = 10))
#'   codomain = ps(y = p_dbl(tags = "minimize"))
#'   objective = ObjectiveRFun$new(fun = fun, domain = domain, codomain = codomain)
#'
#'   instance = OptimInstanceBatchSingleCrit$new(
#'     objective = objective,
#'     terminator = trm("evals", n_evals = 5))
#'
#'   instance$eval_batch(data.table(x = c(-6, -5, 3, 9)))
#'
#'   learner = default_gp()
#'
#'   surrogate = srlrn(learner, archive = instance$archive)
#'
#'   acq_function = acqf("multi",
#'     acq_functions = acqfs(c("ei", "pi", "cb")),
#'     surrogate = surrogate
#'   )
#'
#'   acq_function$surrogate$update()
#'   acq_function$update()
#'   acq_function$eval_dt(data.table(x = c(-1, 0, 1)))
#' }
AcqFunctionMulti = R6Class("AcqFunctionMulti",
  inherit = AcqFunction,

  public = list(

    #' @description
    #' Creates a new instance of this [R6][R6::R6Class] class.
    #'
    #' @param acq_functions (list of [AcqFunction]s)\cr
    #'   At least two acquisition functions with unique ids and equal domains.
    #' @param surrogate (`NULL` | [Surrogate])\cr
    #'   Surrogate shared by all wrapped acquisition functions.
    #'   If `NULL`, the surrogate already set on the acquisition functions is used, provided
    #'   it is the same object for all of them.
    #'   If no surrogate is available at all, domain and codomain are left empty until a
    #'   surrogate is assigned via the `$surrogate` field.
    initialize = function(acq_functions, surrogate = NULL) {
      assert_list(acq_functions, "AcqFunction", min.len = 2L)
      acq_function_ids = map_chr(acq_functions, function(acq_function) acq_function$id)
      assert_character(acq_function_ids, unique = TRUE)
      acq_functions = setNames(acq_functions, nm = acq_function_ids)
      acq_function_directions = map_chr(acq_functions, function(acq_function) acq_function$direction)
      private$.acq_functions = acq_functions
      private$.acq_function_ids = acq_function_ids
      private$.acq_function_directions = acq_function_directions

      # composite id: "acq" followed by the wrapped ids with their "acq_" part removed
      id = paste0(c("acq", map_chr(acq_function_ids, function(id) gsub("acq_", replacement = "", x = id))), collapse = "_")
      label = paste0("Multi Acquisition Function of ", paste0(map_chr(acq_functions, function(acq_function) acq_function$label), collapse = ", "))
      constants = ps()

      # all wrapped acquisition functions must be defined on the same domain;
      # all.equal() returns a character vector describing the differences when objects differ,
      # so it is wrapped in isTRUE() to obtain a proper logical for the assertion
      domains = map(acq_functions, function(acq_function) acq_function$domain)
      assert_true(all(map_lgl(domains[-1L], function(domain) isTRUE(all.equal(domains[[1L]]$data, domain$data)))))

      if (is.null(surrogate)) {
        # no surrogate supplied: fall back to the one stored in the acquisition functions,
        # which must be the identical object (same memory address) for all of them
        surrogates = map(acq_functions, function(acq_function) acq_function$surrogate)
        assert_list(surrogates, types = c("Surrogate", "NULL"))
        if (length(unique(map_chr(surrogates, function(surrogate) address(surrogate)))) > 1L) {
          stop("Acquisition functions must rely on the same surrogate model.")
        }
        surrogate = surrogates[[1L]]
      }

      # the wrapper needs "se" predictions as soon as any wrapped function does
      requires_predict_type_se = any(map_lgl(acq_functions, function(acq_function) acq_function$requires_predict_type_se))
      packages = unique(unlist(map(acq_functions, function(acq_function) acq_function$packages)))
      properties = character()
      check_values = FALSE
      man = "mlr3mbo::mlr_acqfunctions_multi"

      private$.requires_predict_type_se = requires_predict_type_se
      private$.packages = packages
      self$direction = "maximize"

      if (is.null(surrogate)) {
        # lazy initialization: empty domain / codomain until a surrogate is assigned
        domain = ParamSet$new()
        codomain = ParamSet$new()
      } else {
        if (requires_predict_type_se && surrogate$predict_type != "se") {
          stopf("Acquisition function '%s' requires the surrogate to have `\"se\"` as `$predict_type`.", sprintf("<%s:%s>", "AcqFunction", id))
        }
        private$.surrogate = surrogate
        private$.archive = assert_archive(surrogate$archive)
        # propagate the shared surrogate to every wrapped acquisition function
        for (acq_function in private$.acq_functions) {
          acq_function$surrogate = surrogate
        }
        codomain = generate_acq_multi_codomain(surrogate, acq_functions = acq_functions)
        self$surrogate_max_to_min = surrogate_mult_max_to_min(surrogate)
        domain = generate_acq_domain(surrogate)
      }

      self$id = assert_string(id)
      self$domain = assert_param_set(domain)
      assert_param_set(codomain)
      # get "codomain" element if present (new paradox) or default to $params (old paradox)
      params = get0("domains", codomain, ifnotfound = codomain$params)
      self$codomain = Codomain$new(params)
      assert_names(self$domain$ids(), disjunct.from = self$codomain$ids())
      assert_names(self$domain$ids(), disjunct.from = c("x_domain", "timestamp", "batch_nr"))
      assert_names(self$codomain$ids(), disjunct.from = c("x_domain", "timestamp", "batch_nr"))
      self$properties = assert_subset(properties, bbotk_reflections$objective_properties)
      self$constants = assert_param_set(constants)
      self$check_values = assert_flag(check_values)
      private$.label = assert_string(label, na.ok = TRUE)
      private$.man = assert_string(man, na.ok = TRUE)
    },

    #' @description
    #' Update each of the wrapped acquisition functions.
    #' Errors if the wrapped acquisition functions do not all point to the same surrogate object.
    update = function() {
      if (length(unique(map_chr(self$acq_functions, function(acq_function) address(acq_function$surrogate)))) > 1L) {
        stop("Acquisition functions must rely on the same surrogate model.")
      }
      for (acq_function in self$acq_functions) {
        acq_function$update()
      }
    }
  ),

  active = list(
    #' @field surrogate ([Surrogate])\cr
    #'  Surrogate.
    #'  Assigning a surrogate also assigns it to every wrapped acquisition function and
    #'  regenerates domain and codomain.
    surrogate = function(rhs) {
      if (missing(rhs)) {
        private$.surrogate
      } else {
        assert_r6(rhs, classes = "Surrogate")
        if (self$requires_predict_type_se && rhs$predict_type != "se") {
          stopf("Acquisition function '%s' requires the surrogate to have `\"se\"` as `$predict_type`.", format(self))
        }
        private$.surrogate = rhs
        private$.archive = assert_archive(rhs$archive)
        for (acq_function in self$acq_functions) {
          acq_function$surrogate = rhs
        }
        codomain = generate_acq_multi_codomain(rhs, acq_functions = self$acq_functions)
        self$surrogate_max_to_min = surrogate_mult_max_to_min(rhs)
        domain = generate_acq_domain(rhs)
        # lazy initialization requires this:
        self$codomain = Codomain$new(get0("domains", codomain, ifnotfound = codomain$params))  # get0 for old paradox
        self$domain = domain
      }
    },

    #' @field acq_functions (list of [AcqFunction])\cr
    #'  Points to the list of the individual acquisition functions.
    #'  Read-only.
    acq_functions = function(rhs) {
      if (!missing(rhs) && !identical(rhs, private$.acq_functions)) {
        stop("$acq_functions is read-only.")
      }
      private$.acq_functions
    },

    #' @field acq_function_ids (character())\cr
    #'  Points to the ids of the individual acquisition functions.
    #'  Read-only.
    acq_function_ids = function(rhs) {
      if (!missing(rhs) && !identical(rhs, private$.acq_function_ids)) {
        stop("$acq_function_ids is read-only.")
      }
      private$.acq_function_ids
    }
  ),

  private = list(
    .acq_functions = NULL,

    .acq_function_ids = NULL,

    .acq_function_directions = NULL,

    # Evaluates every wrapped acquisition function on `xdt`, column-binds the results and
    # negates the columns of functions that are to be minimized so that all columns are maximized.
    #
    # NOTE: this is currently slower than it could be because when each acquisition function is evaluated,
    # the mean and se prediction for each point is computed again using the surrogate of that acquisition function,
    # however, as acquisition functions must share the same surrogate, this is redundant.
    # It might be sensible to have a customized eval function for acquisition functions where directly the mean and se
    # predictions are passed (along xdt) so that one can save computing the mean and se predictions over and over again.
    # This also would, however, depend on learners being fully deterministic.
    .fun = function(xdt) {
      values = map_dtc(self$acq_functions, function(acq_function) acq_function$eval_dt(xdt))
      ids = private$.acq_function_ids
      directions = private$.acq_function_directions
      # "same" is resolved to the tag of the first codomain target of the surrogate's archive
      if (any(directions == "same")) {
        directions[directions == "same"] = self$surrogate$archive$codomain$tags[[1L]]
      }
      change_sign = ids[directions == "minimize"]
      for (j in change_sign) {
        set(values, j = j, value = -values[[j]])
      }
      values
    },

    # `.acq_functions` is a plain named list of R6 objects, not an R6 object itself,
    # so each element has to be deep cloned individually.
    # NOTE(review): the cloned acquisition functions may end up holding separate surrogate copies;
    # re-assigning `$surrogate` on the clone sets one surrogate on all of them again - confirm.
    deep_clone = function(name, value) {
      switch(name,
        .acq_functions = map(value, function(acq_function) acq_function$clone(deep = TRUE)),
        value
      )
    }
  )
)
| 224 | + |
# Register the class in the `mlr_acqfunctions` dictionary under the key "multi".
mlr_acqfunctions$add(key = "multi", value = AcqFunctionMulti)
| 226 | + |
0 commit comments