Skip to content

Commit 1b2f6fd

Browse files
authored
feat: add AcqFunctionMulti that can wrap multiple acquisition functions resulting in a multi-objective acquisition function problem (#157)
* feat: alllow SurrogateLearnerCollection to be used with an OptimInstanceSingleCrit and one y column * feat: AcqFunctionMulti that can wrap multiple acquisition functions resulting in a multi-objective acquisition function problem * feat: adjusted AcqOptimizer to be more robust (get_best) functionality but also handle AcqFunctionMulti * docs: improve documention of AcqFunctionMulti and AcqOptimizer
1 parent 5db34a1 commit 1b2f6fd

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+788
-94
lines changed

DESCRIPTION

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ Collate:
8484
'AcqFunctionEI.R'
8585
'AcqFunctionEIPS.R'
8686
'AcqFunctionMean.R'
87+
'AcqFunctionMulti.R'
8788
'AcqFunctionPI.R'
8889
'AcqFunctionSD.R'
8990
'AcqFunctionSmsEgo.R'

NAMESPACE

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ export(AcqFunctionEHVIGH)
1212
export(AcqFunctionEI)
1313
export(AcqFunctionEIPS)
1414
export(AcqFunctionMean)
15+
export(AcqFunctionMulti)
1516
export(AcqFunctionPI)
1617
export(AcqFunctionSD)
1718
export(AcqFunctionSmsEgo)
@@ -25,6 +26,7 @@ export(SurrogateLearner)
2526
export(SurrogateLearnerCollection)
2627
export(TunerMbo)
2728
export(acqf)
29+
export(acqfs)
2830
export(acqo)
2931
export(bayesopt_ego)
3032
export(bayesopt_emo)

R/AcqFunctionAEI.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ AcqFunctionAEI = R6Class("AcqFunctionAEI",
9292
},
9393

9494
#' @description
95-
#' Updates acquisition function and sets `y_effective_best` and `noise_var`.
95+
#' Update the acquisition function and set `y_effective_best` and `noise_var`.
9696
update = function() {
9797
xdt = self$archive$data[, self$archive$cols_x, with = FALSE]
9898
p = self$surrogate$predict(xdt)

R/AcqFunctionEHVI.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ AcqFunctionEHVI = R6Class("AcqFunctionEHVI",
7676
},
7777

7878
#' @description
79-
#' Updates acquisition function and sets `ys_front`, `ref_point`.
79+
#' Update the acquisition function and set `ys_front` and `ref_point`.
8080
update = function() {
8181
n_obj = length(self$archive$cols_y)
8282
if (n_obj > 2L) {

R/AcqFunctionEHVIGH.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ AcqFunctionEHVIGH = R6Class("AcqFunctionEHVIGH",
104104
},
105105

106106
#' @description
107-
#' Updates acquisition function and sets `ys_front`, `ref_point`, `hypervolume`, `gh_data`.
107+
#' Update the acquisition function and set `ys_front`, `ref_point`, `hypervolume` and `gh_data`.
108108
update = function() {
109109
n_obj = length(self$archive$cols_y)
110110
ys = self$archive$data[, self$archive$cols_y, with = FALSE]

R/AcqFunctionEI.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ AcqFunctionEI = R6Class("AcqFunctionEI",
7979
},
8080

8181
#' @description
82-
#' Updates acquisition function and sets `y_best`.
82+
#' Update the acquisition function and set `y_best`.
8383
update = function() {
8484
self$y_best = min(self$surrogate_max_to_min * self$archive$data[[self$surrogate$cols_y]])
8585
}

R/AcqFunctionEIPS.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ AcqFunctionEIPS = R6Class("AcqFunctionEIPS",
7676
},
7777

7878
#' @description
79-
#' Updates acquisition function and sets `y_best`.
79+
#' Update the acquisition function and set `y_best`.
8080
update = function() {
8181
self$y_best = min(self$surrogate_max_to_min[[self$col_y]] * self$archive$data[[self$col_y]])
8282
}

R/AcqFunctionMulti.R

Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
#' @title Acquisition Function Wrapping Multiple Acquisition Functions
2+
#'
3+
#' @include AcqFunction.R
4+
#' @name mlr_acqfunctions_multi
5+
#'
6+
#' @templateVar id multi
7+
#' @template section_dictionary_acqfunctions
8+
#'
9+
#' @description
10+
#' Wrapping multiple [AcqFunction]s resulting in a multi-objective acquisition function composed of the individual ones.
11+
#' Note that the optimization direction of each wrapped acquisition function is corrected for maximization.
12+
#'
13+
#' For each acquisition function, the same [Surrogate] must be used.
14+
#' If acquisition functions passed during construction already have been initialized with a surrogate, it is checked whether
15+
#' the surrogate is the same for all acquisition functions.
16+
#' If acquisition functions have not been initialized with a surrogate, the surrogate passed during construction or lazy initialization
17+
#' will be used for all acquisition functions.
18+
#'
19+
#' For optimization, [AcqOptimizer] can be used as for any other [AcqFunction], however, the [bbotk::Optimizer] wrapped within the [AcqOptimizer]
20+
#' must support multi-objective optimization as indicated via the `multi-crit` property.
21+
#'
22+
#' @family Acquisition Function
23+
#' @export
24+
#' @examples
25+
#' if (requireNamespace("mlr3learners") &
26+
#' requireNamespace("DiceKriging") &
27+
#' requireNamespace("rgenoud")) {
28+
#' library(bbotk)
29+
#' library(paradox)
30+
#' library(mlr3learners)
31+
#' library(data.table)
32+
#'
33+
#' fun = function(xs) {
34+
#' list(y = xs$x ^ 2)
35+
#' }
36+
#' domain = ps(x = p_dbl(lower = -10, upper = 10))
37+
#' codomain = ps(y = p_dbl(tags = "minimize"))
38+
#' objective = ObjectiveRFun$new(fun = fun, domain = domain, codomain = codomain)
39+
#'
40+
#' instance = OptimInstanceBatchSingleCrit$new(
41+
#' objective = objective,
42+
#' terminator = trm("evals", n_evals = 5))
43+
#'
44+
#' instance$eval_batch(data.table(x = c(-6, -5, 3, 9)))
45+
#'
46+
#' learner = default_gp()
47+
#'
48+
#' surrogate = srlrn(learner, archive = instance$archive)
49+
#'
50+
#' acq_function = acqf("multi",
51+
#' acq_functions = acqfs(c("ei", "pi", "cb")),
52+
#' surrogate = surrogate
53+
#' )
54+
#'
55+
#' acq_function$surrogate$update()
56+
#' acq_function$update()
57+
#' acq_function$eval_dt(data.table(x = c(-1, 0, 1)))
58+
#' }
59+
AcqFunctionMulti = R6Class("AcqFunctionMulti",
60+
inherit = AcqFunction,
61+
62+
public = list(
63+
64+
#' @description
65+
#' Creates a new instance of this [R6][R6::R6Class] class.
66+
#'
67+
#' @param acq_functions (list of [AcqFunction]s).
68+
#' @param surrogate (`NULL` | [Surrogate]).
69+
initialize = function(acq_functions, surrogate = NULL) {
70+
assert_list(acq_functions, "AcqFunction", min.len = 2L)
71+
acq_function_ids = map_chr(acq_functions, function(acq_function) acq_function$id)
72+
assert_character(acq_function_ids, unique = TRUE)
73+
acq_functions = setNames(acq_functions, nm = acq_function_ids)
74+
acq_function_directions = map_chr(acq_functions, function(acq_function) acq_function$direction)
75+
private$.acq_functions = acq_functions
76+
private$.acq_function_ids = acq_function_ids
77+
private$.acq_function_directions = acq_function_directions
78+
id = paste0(c("acq", map_chr(acq_function_ids, function(id) gsub("acq_", replacement = "", x = id))), collapse = "_")
79+
label = paste0("Multi Acquisition Function of ", paste0(map_chr(acq_functions, function(acq_function) acq_function$label), collapse = ", "))
80+
constants = ps()
81+
domains = map(acq_functions, function(acq_function) acq_function$domain)
82+
assert_true(all(map_lgl(domains[-1L], function(domain) all.equal(domains[[1L]]$data, domain$data))))
83+
if (is.null(surrogate)) {
84+
surrogates = map(acq_functions, function(acq_function) acq_function$surrogate)
85+
assert_list(surrogates, types = c("Surrogate", "NULL"))
86+
if (length(unique(map_chr(surrogates, function(surrogate) address(surrogate)))) > 1L) {
87+
stop("Acquisition functions must rely on the same surrogate model.")
88+
}
89+
surrogate = surrogates[[1L]]
90+
}
91+
requires_predict_type_se = any(map_lgl(acq_functions, function(acq_function) acq_function$requires_predict_type_se))
92+
packages = unique(unlist(map(acq_functions, function(acq_function) acq_function$packages)))
93+
properties = character()
94+
check_values = FALSE
95+
man = "mlr3mbo::mlr_acqfunctions_multi"
96+
97+
private$.requires_predict_type_se = requires_predict_type_se
98+
private$.packages = packages
99+
self$direction = "maximize"
100+
if (is.null(surrogate)) {
101+
domain = ParamSet$new()
102+
codomain = ParamSet$new()
103+
} else {
104+
if (requires_predict_type_se && surrogate$predict_type != "se") {
105+
stopf("Acquisition function '%s' requires the surrogate to have `\"se\"` as `$predict_type`.", sprintf("<%s:%s>", "AcqFunction", id))
106+
}
107+
private$.surrogate = surrogate
108+
private$.archive = assert_archive(surrogate$archive)
109+
for (acq_function in private$.acq_functions) {
110+
acq_function$surrogate = surrogate
111+
}
112+
codomain = generate_acq_multi_codomain(surrogate, acq_functions = acq_functions)
113+
self$surrogate_max_to_min = surrogate_mult_max_to_min(surrogate)
114+
domain = generate_acq_domain(surrogate)
115+
}
116+
117+
self$id = assert_string(id)
118+
self$domain = assert_param_set(domain)
119+
assert_param_set(codomain)
120+
# get "codomain" element if present (new paradox) or default to $params (old paradox)
121+
params = get0("domains", codomain, ifnotfound = codomain$params)
122+
self$codomain = Codomain$new(params)
123+
assert_names(self$domain$ids(), disjunct.from = self$codomain$ids())
124+
assert_names(self$domain$ids(), disjunct.from = c("x_domain", "timestamp", "batch_nr"))
125+
assert_names(self$codomain$ids(), disjunct.from = c("x_domain", "timestamp", "batch_nr"))
126+
self$properties = assert_subset(properties, bbotk_reflections$objective_properties)
127+
self$constants = assert_param_set(constants)
128+
self$check_values = assert_flag(check_values)
129+
private$.label = assert_string(label, na.ok = TRUE)
130+
private$.man = assert_string(man, na.ok = TRUE)
131+
},
132+
133+
#' @description
134+
#' Update each of the wrapped acquisition functions.
135+
update = function() {
136+
if (length(unique(map_chr(self$acq_functions, function(acq_function) address(acq_function$surrogate)))) > 1L) {
137+
stop("Acquisition functions must rely on the same surrogate model.")
138+
}
139+
for (acq_function in self$acq_functions) {
140+
acq_function$update()
141+
}
142+
}
143+
),
144+
145+
active = list(
146+
#' @field surrogate ([Surrogate])\cr
147+
#' Surrogate.
148+
surrogate = function(rhs) {
149+
if (missing(rhs)) {
150+
private$.surrogate
151+
} else {
152+
assert_r6(rhs, classes = "Surrogate")
153+
if (self$requires_predict_type_se && rhs$predict_type != "se") {
154+
stopf("Acquisition function '%s' requires the surrogate to have `\"se\"` as `$predict_type`.", format(self))
155+
}
156+
private$.surrogate = rhs
157+
private$.archive = assert_archive(rhs$archive)
158+
for (acq_function in self$acq_functions) {
159+
acq_function$surrogate = rhs
160+
}
161+
codomain = generate_acq_multi_codomain(rhs, acq_functions = self$acq_functions)
162+
self$surrogate_max_to_min = surrogate_mult_max_to_min(rhs)
163+
domain = generate_acq_domain(rhs)
164+
# lazy initialization requires this:
165+
self$codomain = Codomain$new(get0("domains", codomain, ifnotfound = codomain$params)) # get0 for old paradox
166+
self$domain = domain
167+
}
168+
},
169+
170+
#' @field acq_functions (list of [AcqFunction])\cr
171+
#' Points to the list of the individual acquisition functions.
172+
acq_functions = function(rhs) {
173+
if (!missing(rhs) && !identical(rhs, private$.acq_functions)) {
174+
stop("$acq_functions is read-only.")
175+
}
176+
private$.acq_functions
177+
},
178+
179+
#' @field acq_function_ids (character())\cr
180+
#' Points to the ids of the individual acquisition functions.
181+
acq_function_ids = function(rhs) {
182+
if (!missing(rhs) && !identical(rhs, private$.acq_function_ids)) {
183+
stop("$acq_function_ids is read-only.")
184+
}
185+
private$.acq_function_ids
186+
}
187+
),
188+
189+
private = list(
190+
.acq_functions = NULL,
191+
192+
.acq_function_ids = NULL,
193+
194+
.acq_function_directions = NULL,
195+
196+
# NOTE: this is currently slower than it could be because when each acquisition functions is evaluated,
197+
# the mean and se prediction for each point is computed again using the surrogate of that acquisition function,
198+
# however, as acquisition functions must share the same surrogate, this is redundant.
199+
# It might be sensible to have a customized eval function for acquisition functions where directly the mean and se
200+
# predictions are passed (along xdt) so that one can save computing the mean and se predictions over and over again.
201+
# This also would, however, depend on learners being fully deterministic.
202+
.fun = function(xdt) {
203+
values = map_dtc(self$acq_functions, function(acq_function) acq_function$eval_dt(xdt))
204+
ids = private$.acq_function_ids
205+
directions = private$.acq_function_directions
206+
if (any(directions == "same")) {
207+
directions[directions == "same"] = self$surrogate$archive$codomain$tags[[1L]]
208+
}
209+
change_sign = ids[directions == "minimize"]
210+
for (j in change_sign) {
211+
set(values, j = j, value = - values[[j]])
212+
}
213+
values
214+
},
215+
216+
deep_clone = function(name, value) {
217+
switch(name,
218+
.acq_functions = value$clone(deep = TRUE),
219+
value
220+
)
221+
}
222+
)
223+
)
224+
225+
mlr_acqfunctions$add("multi", AcqFunctionMulti)
226+

R/AcqFunctionPI.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ AcqFunctionPI = R6Class("AcqFunctionPI",
6666
},
6767

6868
#' @description
69-
#' Updates acquisition function and sets `y_best`.
69+
#' Update the acquisition function and set `y_best`.
7070
update = function() {
7171
self$y_best = min(self$surrogate_max_to_min * self$archive$data[[self$surrogate$cols_y]])
7272
}

R/AcqFunctionSmsEgo.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ AcqFunctionSmsEgo = R6Class("AcqFunctionSmsEgo",
103103
},
104104

105105
#' @description
106-
#' Updates acquisition function and sets `ys_front`, `ref_point`, `epsilon`.
106+
#' Update the acquisition function and set `ys_front`, `ref_point` and `epsilon`.
107107
update = function() {
108108
if (is.null(self$progress)) {
109109
stop("$progress is not set.") # needs self$progress here! Originally self$instance$terminator$param_set$values$n_evals - archive$n_evals

0 commit comments

Comments
 (0)