Skip to content

Commit e3383c7

Browse files
authored
fix: AcqFunction domain construction respects Surrogate cols_x now (#129)
* docs: infill --> acquisition function * fix: AcqFunction domain construction respects Surrogate cols_x now
1 parent 9054a2d commit e3383c7

File tree

9 files changed

+87
-34
lines changed

9 files changed

+87
-34
lines changed

R/AcqFunction.R

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,9 @@ AcqFunction = R6Class("AcqFunction",
6060
}
6161
private$.surrogate = surrogate
6262
private$.archive = assert_r6(surrogate$archive, classes = "Archive")
63-
codomain = generate_acq_codomain(surrogate$archive$codomain, id = id, direction = direction)
64-
self$surrogate_max_to_min = surrogate_mult_max_to_min(surrogate$archive$codomain, cols_y = surrogate$cols_y)
65-
domain = surrogate$archive$search_space$clone(deep = TRUE)
66-
domain$trafo = NULL
63+
codomain = generate_acq_codomain(surrogate, id = id, direction = direction)
64+
self$surrogate_max_to_min = surrogate_mult_max_to_min(surrogate)
65+
domain = generate_acq_domain(surrogate)
6766
}
6867
super$initialize(id = id, domain = domain, codomain = codomain, constants = constants)
6968
},
@@ -160,7 +159,7 @@ AcqFunction = R6Class("AcqFunction",
160159
},
161160

162161
#' @field fun (`function`)\cr
163-
#' Pointing to the private acquisition function to be implemented by subclasses.
162+
#' Points to the private acquisition function to be implemented by subclasses.
164163
fun = function(lhs) {
165164
if (!missing(lhs) && !identical(lhs, private$.fun)) stop("$fun is read-only.")
166165
private$.fun
@@ -178,10 +177,9 @@ AcqFunction = R6Class("AcqFunction",
178177
}
179178
private$.surrogate = rhs
180179
private$.archive = assert_r6(rhs$archive, classes = "Archive")
181-
codomain = generate_acq_codomain(rhs$archive$codomain, id = self$id, direction = self$direction)
182-
self$surrogate_max_to_min = surrogate_mult_max_to_min(rhs$archive$codomain, cols_y = rhs$cols_y)
183-
domain = rhs$archive$search_space$clone(deep = TRUE)
184-
domain$trafo = NULL
180+
codomain = generate_acq_codomain(rhs, id = self$id, direction = self$direction)
181+
self$surrogate_max_to_min = surrogate_mult_max_to_min(rhs)
182+
domain = generate_acq_domain(rhs)
185183
self$codomain = Codomain$new(codomain$params) # lazy initialization requires this
186184
self$domain = domain
187185
}

R/SurrogateLearner.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
#' }
2424
#' \item{`catch_errors`}{`logical(1)`\cr
2525
#' Should errors during updating the surrogate be caught and propagated to the `loop_function` which can then handle
26-
#' the failed infill optimization (as a result of the failed surrogate) appropriately by, e.g., proposing a randomly sampled point for evaluation?
26+
#' the failed acquisition function optimization (as a result of the failed surrogate) appropriately by, e.g., proposing a randomly sampled point for evaluation?
2727
#' Default is `TRUE`.
2828
#' }
2929
#' }

R/SurrogateLearnerCollection.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
#' }
2626
#' \item{`catch_errors`}{`logical(1)`\cr
2727
#' Should errors during updating the surrogate be caught and propagated to the `loop_function` which can then handle
28-
#' the failed infill optimization (as a result of the failed surrogate) appropriately by, e.g., proposing a randomly sampled point for evaluation?
28+
#' the failed acquisition function optimization (as a result of the failed surrogate) appropriately by, e.g., proposing a randomly sampled point for evaluation?
2929
#' Default is `TRUE`.
3030
#' }
3131
#' }

R/helper.R

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,27 @@
1-
generate_acq_codomain = function(codomain, id, direction = "same") {
2-
assert_choice(direction, c("same", "minimize", "maximize"))
1+
generate_acq_codomain = function(surrogate, id, direction = "same") {
2+
assert_r6(surrogate$archive, classes = "Archive")
3+
assert_string(id)
4+
assert_choice(direction, choices = c("same", "minimize", "maximize"))
35
if (direction == "same") {
4-
if (codomain$length > 1L) {
6+
if (surrogate$archive$codomain$length > 1L) {
57
stop("Not supported yet.") # FIXME: But should be?
68
}
7-
tags = codomain$params[[1L]]$tags
9+
tags = surrogate$archive$codomain$params[[1L]]$tags
810
tags = tags[tags %in% c("minimize", "maximize")] # only filter out the relevant one
911
} else {
1012
tags = direction
1113
}
1214
codomain = ParamSet$new(list(
1315
ParamDbl$new(id, tags = tags)
1416
))
15-
return(codomain)
17+
codomain
18+
}
19+
20+
generate_acq_domain = function(surrogate) {
21+
assert_r6(surrogate$archive, classes = "Archive")
22+
domain = surrogate$archive$search_space$clone(deep = TRUE)$subset(surrogate$cols_x)
23+
domain$trafo = NULL
24+
domain
1625
}
1726

1827
archive_xy = function(archive) {
@@ -46,10 +55,16 @@ calculate_parego_weights = function(s, k) {
4655
matrix(unlist(fun(s, k)), ncol = k, byrow = TRUE) / s
4756
}
4857

49-
surrogate_mult_max_to_min = function(codomain, cols_y) {
58+
surrogate_mult_max_to_min = function(surrogate) {
59+
codomain = surrogate$archive$codomain
60+
cols_y = surrogate$cols_y
5061
mult = map_int(cols_y, function(col_y) {
51-
mult = if (col_y %in% codomain$ids()) {
52-
if(has_element(codomain$tags[[col_y]], "maximize")) -1L else 1L
62+
mult = if (col_y %in% surrogate$archive$codomain$ids()) {
63+
if (has_element(surrogate$archive$codomain$tags[[col_y]], "maximize")) {
64+
-1L
65+
} else {
66+
1L
67+
}
5368
} else {
5469
1L
5570
}
@@ -58,7 +73,7 @@ surrogate_mult_max_to_min = function(codomain, cols_y) {
5873
}
5974

6075
mult_max_to_min = function(codomain) {
61-
ifelse(map_lgl(codomain$tags, has_element, "minimize"), 1, -1)
76+
ifelse(map_lgl(codomain$tags, has_element, "minimize"), yes = 1L, no = -1L)
6277
}
6378

6479
# used in AcqOptimizer
@@ -86,7 +101,9 @@ catn = function(..., file = "") {
86101
}
87102

88103
set_collapse = function(x) {
89-
if (length(x) == 0L) return("{}")
104+
if (length(x) == 0L) {
105+
return("{}")
106+
}
90107
sprintf("{'%s'}", paste0(unique(x), collapse = "','"))
91108
}
92109

@@ -95,14 +112,14 @@ check_attributes = function(x, attribute_names) {
95112
if (any(attribute_names %nin% names(attributes(x)))) {
96113
return(sprintf("Attributes must include '%s' but is '%s'", set_collapse(attribute_names), set_collapse(names(attributes(x)))))
97114
}
98-
return(TRUE)
115+
TRUE
99116
}
100117

101118
check_instance_attribute = function(x) {
102119
if (length(intersect(c("single-crit", "multi-crit"), attr(x, "instance"))) == 0L) {
103120
return(sprintf("'instance' attribute must be a subset of '%s' but is '%s'", set_collapse(c("single-crit", "multi-crit")), set_collapse(attr(x, "instance"))))
104121
}
105-
return(TRUE)
122+
TRUE
106123
}
107124

108125
check_learner_surrogate = function(learner) {
@@ -118,7 +135,9 @@ check_learner_surrogate = function(learner) {
118135
}
119136

120137
assert_loop_function = function(x, .var.name = vname(x)) {
121-
if (is.null(x)) return(x)
138+
if (is.null(x)) {
139+
return(x)
140+
}
122141
# NOTE: this is buggy in checkmate; assert should always return x invisible not TRUE as is the case here
123142
assert(check_class(x, classes = "loop_function"),
124143
check_function(x, args = c("instance", "surrogate", "acq_function", "acq_optimizer")),

man/AcqFunction.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/SurrogateLearner.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/SurrogateLearnerCollection.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/helper.R

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,29 +49,32 @@ FUN_1D_2_MIXED = function(xs) {
4949
}
5050
OBJ_1D_2_MIXED = ObjectiveRFun$new(fun = FUN_1D_2_MIXED, domain = PS_1D_MIXED, codomain = FUN_1D_2_CODOMAIN, properties = "multi-crit")
5151

52-
# Simple 2D Function
53-
PS_2D_domain = ParamSet$new(list(
52+
# Simple 2D Functions
53+
PS_2D = ParamSet$new(list(
5454
ParamDbl$new("x1", lower = -1, upper = 1),
55-
ParamDbl$new("x2", lower = -1, upper = 1),
56-
ParamUty$new("foo") # the domain of the function should not matter
55+
ParamDbl$new("x2", lower = -1, upper = 1)
5756
))
58-
PS_2D = ParamSet$new(list(
57+
PS_2D_trafo = ParamSet$new(list(
5958
ParamDbl$new("x1", lower = -1, upper = 1),
6059
ParamDbl$new("x2", lower = -1, upper = 1)
6160
))
61+
PS_2D_trafo$trafo = function(x, param_set) {
62+
x$x2 = x$x2 ^ 2
63+
x
64+
}
6265
FUN_2D = function(xs) {
6366
y = sum(as.numeric(xs)^2)
6467
list(y = y)
6568
}
6669
FUN_2D_CODOMAIN = ParamSet$new(list(ParamDbl$new("y", tags = c("minimize", "random_tag"))))
67-
OBJ_2D = ObjectiveRFun$new(fun = FUN_2D, domain = PS_2D_domain, properties = "single-crit")
70+
OBJ_2D = ObjectiveRFun$new(fun = FUN_2D, domain = PS_2D, properties = "single-crit")
6871

6972
# Simple 2D Function with noise
7073
FUN_2D_NOISY = function(xs) {
7174
y = sum(as.numeric(xs)^2) + rnorm(1, sd = 0.5)
7275
list(y = y)
7376
}
74-
OBJ_2D_NOISY = ObjectiveRFun$new(fun = FUN_2D_NOISY, domain = PS_2D_domain, properties = c("single-crit", "noisy"))
77+
OBJ_2D_NOISY = ObjectiveRFun$new(fun = FUN_2D_NOISY, domain = PS_2D, properties = c("single-crit", "noisy"))
7578

7679
# Instance helper
7780
MAKE_INST = function(objective = OBJ_2D, search_space = PS_2D, terminator = trm("evals", n_evals = 10L)) {

tests/testthat/test_AcqFunction.R

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,36 @@ test_that("AcqFunction packages works", {
3737
expect_equal(acqf$packages, "mlr3mbo")
3838
})
3939

40+
test_that("AcqFunction generate_acq_codomain works", {
41+
inst = MAKE_INST(OBJ_2D, search_space = PS_2D)
42+
surrogate = SurrogateLearner$new(REGR_FEATURELESS, archive = inst$archive)
43+
codomain = generate_acq_codomain(surrogate, "acqf")
44+
expect_r6(codomain, "ParamSet")
45+
expect_equal(codomain$tags[["acqf"]], "minimize")
46+
})
47+
48+
test_that("AcqFunction generate_acq_domain works", {
49+
inst = MAKE_INST(OBJ_2D, search_space = PS_2D)
50+
surrogate = SurrogateLearner$new(REGR_FEATURELESS, archive = inst$archive)
51+
domain = generate_acq_domain(surrogate)
52+
expect_equal(domain, OBJ_2D$domain)
53+
expect_equal(domain, inst$search_space)
54+
55+
inst = MAKE_INST(OBJ_2D, search_space = PS_2D_trafo)
56+
expect_true(inst$search_space$has_trafo)
57+
surrogate = SurrogateLearner$new(REGR_FEATURELESS, archive = inst$archive)
58+
domain = generate_acq_domain(surrogate)
59+
expect_equal(domain, OBJ_2D$domain)
60+
expect_false(domain$has_trafo)
61+
62+
surrogate$cols_x = "x2"
63+
domain = generate_acq_domain(surrogate)
64+
surrogate$cols_x = "x2"
65+
domain = generate_acq_domain(surrogate)
66+
expect_equal(domain, OBJ_2D$domain$clone(deep = TRUE)$subset("x2"))
67+
expect_false(domain$has_trafo)
68+
69+
surrogate = SurrogateLearner$new(REGR_FEATURELESS)
70+
expect_error(generate_acq_domain(surrogate), "Must be an R6 class, not 'NULL'")
71+
})
72+

0 commit comments

Comments
 (0)