Skip to content

Commit 865659f

Browse files
mb706sumny
andauthored
check with new paradox (#136)
refactor: compatibility with upcoming paradox upgrade --------- Co-authored-by: Lennart Schneider <[email protected]>
1 parent 42d8b11 commit 865659f

18 files changed

+107
-87
lines changed

.github/workflows/dev-cmd-check.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ jobs:
2727
- {os: ubuntu-latest, r: 'release', dev-package: 'mlr-org/bbotk'}
2828
- {os: ubuntu-latest, r: 'release', dev-package: 'mlr-org/mlr3'}
2929
- {os: ubuntu-latest, r: 'release', dev-package: 'mlr-org/mlr3tuning'}
30+
- {os: ubuntu-latest, r: 'release', dev-package: "mlr-org/mlr3tuning', 'mlr-org/mlr3learners', 'mlr-org/mlr3pipelines', 'mlr-org/bbotk', 'mlr-org/paradox"}
3031

3132
steps:
3233
- uses: actions/checkout@v3
@@ -43,7 +44,7 @@ jobs:
4344
needs: check
4445

4546
- name: Install dev versions
46-
run: pak::pkg_install('${{ matrix.config.dev-package }}')
47+
run: pak::pkg_install(c('${{ matrix.config.dev-package }}'))
4748
shell: Rscript {0}
4849

4950
- uses: r-lib/actions/check-r-package@v2

NEWS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# mlr3mbo 0.2.1.9000
22

3+
* refactor: compatibility with upcoming paradox upgrade.
34
* feat: `OptimizerMbo` and `TunerMbo` now update the `Surrogate` a final time after the optimization process finished to
45
ensure that the `Surrogate` correctly reflects the state of being trained on all data seen during optimization.
56
* fix: `AcqFunction` domain construction now respects `Surrogate` cols_x field.

R/AcqFunction.R

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,8 @@ AcqFunction = R6Class("AcqFunction",
180180
codomain = generate_acq_codomain(rhs, id = self$id, direction = self$direction)
181181
self$surrogate_max_to_min = surrogate_mult_max_to_min(rhs)
182182
domain = generate_acq_domain(rhs)
183-
self$codomain = Codomain$new(codomain$params) # lazy initialization requires this
183+
# lazy initialization requires this:
184+
self$codomain = Codomain$new(get0("domains", codomain, ifnotfound = codomain$params)) # get0 for old paradox
184185
self$domain = domain
185186
}
186187
},

R/AcqFunctionEHVIGH.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,10 @@ AcqFunctionEHVIGH = R6Class("AcqFunctionEHVIGH",
9393
assert_r6(surrogate, "SurrogateLearnerCollection", null.ok = TRUE)
9494
assert_int(k, lower = 2L)
9595

96-
constants = ParamSet$new(list(
97-
ParamInt$new("k", lower = 2L, default = 15L),
98-
ParamDbl$new("r", lower = 0, upper = 1, default = 0.2)
99-
))
96+
constants = ps(
97+
k = p_int(lower = 2L, default = 15L),
98+
r = p_dbl(lower = 0, upper = 1, default = 0.2)
99+
)
100100
constants$values$k = k
101101
constants$values$r = r
102102

R/AcqFunctionSmsEgo.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,10 @@ AcqFunctionSmsEgo = R6Class("AcqFunctionSmsEgo",
9292
assert_number(lambda, lower = 1, finite = TRUE)
9393
assert_number(epsilon, lower = 0, finite = TRUE, null.ok = TRUE)
9494

95-
constants = ParamSet$new(list(
96-
ParamDbl$new("lambda", lower = 0, default = 1),
97-
ParamDbl$new("epsilon", lower = 0, default = NULL, special_vals = list(NULL)) # for NULL, it will be calculated dynamically
98-
))
95+
constants = ps(
96+
lambda = p_dbl(lower = 0, default = 1),
97+
epsilon = p_dbl(lower = 0, default = NULL, special_vals = list(NULL)) # for NULL, it will be calculated dynamically
98+
)
9999
constants$values$lambda = lambda
100100
constants$values$epsilon = epsilon
101101

R/AcqOptimizer.R

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -103,13 +103,13 @@ AcqOptimizer = R6Class("AcqOptimizer",
103103
self$optimizer = assert_r6(optimizer, "Optimizer")
104104
self$terminator = assert_r6(terminator, "Terminator")
105105
self$acq_function = assert_r6(acq_function, "AcqFunction", null.ok = TRUE)
106-
ps = ParamSet$new(list(
107-
ParamInt$new("n_candidates", lower = 1, default = 1L),
108-
ParamFct$new("logging_level", levels = c("fatal", "error", "warn", "info", "debug", "trace"), default = "warn"),
109-
ParamLgl$new("warmstart", default = FALSE),
110-
ParamInt$new("warmstart_size", lower = 1L, special_vals = list("all")),
111-
ParamLgl$new("skip_already_evaluated", default = TRUE),
112-
ParamLgl$new("catch_errors", default = TRUE))
106+
ps = ps(
107+
n_candidates = p_int(lower = 1, default = 1L),
108+
logging_level = p_fct(levels = c("fatal", "error", "warn", "info", "debug", "trace"), default = "warn"),
109+
warmstart = p_lgl(default = FALSE),
110+
warmstart_size = p_int(lower = 1L, special_vals = list("all")),
111+
skip_already_evaluated = p_lgl(default = TRUE),
112+
catch_errors = p_lgl(default = TRUE)
113113
)
114114
ps$values = list(n_candidates = 1, logging_level = "warn", warmstart = FALSE, skip_already_evaluated = TRUE, catch_errors = TRUE)
115115
ps$add_dep("warmstart_size", on = "warmstart", cond = CondEqual$new(TRUE))

R/Surrogate.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ Surrogate = R6Class("Surrogate",
3030
private$.cols_x = assert_character(cols_x, min.len = 1L, null.ok = TRUE)
3131
private$.cols_y = cols_y = assert_character(cols_y, min.len = 1L, null.ok = TRUE)
3232
assert_r6(param_set, classes = "ParamSet")
33-
assert_r6(param_set$params$catch_errors, classes = "ParamLgl")
33+
stopifnot(param_set$class[["catch_errors"]] == "ParamLgl")
3434
private$.param_set = param_set
3535
},
3636

R/SurrogateLearner.R

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,11 @@ SurrogateLearner = R6Class("SurrogateLearner",
8383
assert_character(cols_x, min.len = 1L, null.ok = TRUE)
8484
assert_string(col_y, null.ok = TRUE)
8585

86-
ps = ParamSet$new(list(
87-
ParamLgl$new("assert_insample_perf"),
88-
ParamUty$new("perf_measure", custom_check = function(x) check_r6(x, classes = "MeasureRegr")), # FIXME: actually want check_measure
89-
ParamDbl$new("perf_threshold", lower = -Inf, upper = Inf),
90-
ParamLgl$new("catch_errors"))
86+
ps = ps(
87+
assert_insample_perf = p_lgl(),
88+
perf_measure = p_uty(custom_check = function(x) check_r6(x, classes = "MeasureRegr")), # FIXME: actually want check_measure
89+
perf_threshold = p_dbl(lower = -Inf, upper = Inf),
90+
catch_errors = p_lgl()
9191
)
9292
ps$values = list(assert_insample_perf = FALSE, catch_errors = TRUE)
9393
ps$add_dep("perf_measure", on = "assert_insample_perf", cond = CondEqual$new(TRUE))

R/SurrogateLearnerCollection.R

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -96,11 +96,11 @@ SurrogateLearnerCollection = R6Class("SurrogateLearnerCollection",
9696
assert_character(cols_x, min.len = 1L, null.ok = TRUE)
9797
assert_character(cols_y, len = length(learners), null.ok = TRUE)
9898

99-
ps = ParamSet$new(list(
100-
ParamLgl$new("assert_insample_perf"),
101-
ParamUty$new("perf_measures", custom_check = function(x) check_list(x, types = "MeasureRegr", any.missing = FALSE, len = length(learners))), # FIXME: actually want check_measures
102-
ParamUty$new("perf_thresholds", custom_check = function(x) check_double(x, lower = -Inf, upper = Inf, any.missing = FALSE, len = length(learners))),
103-
ParamLgl$new("catch_errors"))
99+
ps = ps(
100+
assert_insample_perf = p_lgl(),
101+
perf_measures = p_uty(custom_check = function(x) check_list(x, types = "MeasureRegr", any.missing = FALSE, len = length(learners))), # FIXME: actually want check_measures
102+
perf_thresholds = p_uty(custom_check = function(x) check_double(x, lower = -Inf, upper = Inf, any.missing = FALSE, len = length(learners))),
103+
catch_errors = p_lgl()
104104
)
105105
ps$values = list(assert_insample_perf = FALSE, catch_errors = TRUE)
106106
ps$add_dep("perf_measures", on = "assert_insample_perf", cond = CondEqual$new(TRUE))

R/helper.R

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,28 @@ generate_acq_codomain = function(surrogate, id, direction = "same") {
66
if (surrogate$archive$codomain$length > 1L) {
77
stop("Not supported yet.") # FIXME: But should be?
88
}
9-
tags = surrogate$archive$codomain$params[[1L]]$tags
9+
tags = surrogate$archive$codomain$tags[[1L]]
1010
tags = tags[tags %in% c("minimize", "maximize")] # only filter out the relevant one
1111
} else {
1212
tags = direction
1313
}
14-
codomain = ParamSet$new(list(
15-
ParamDbl$new(id, tags = tags)
16-
))
17-
codomain
14+
do.call(ps, structure(list(p_dbl(tags = tags)), names = id))
1815
}
1916

2017
generate_acq_domain = function(surrogate) {
2118
assert_r6(surrogate$archive, classes = "Archive")
22-
domain = surrogate$archive$search_space$clone(deep = TRUE)$subset(surrogate$cols_x)
23-
domain$trafo = NULL
19+
if ("set_id" %in% names(ps())) {
20+
# old paradox
21+
domain = surrogate$archive$search_space$clone(deep = TRUE)$subset(surrogate$cols_x)
22+
domain$trafo = NULL
23+
} else {
24+
# get "domain" objects, set their .trafo-entry to NULL individually
25+
dms = lapply(surrogate$archive$search_space$domains[surrogate$cols_x], function(x) {
26+
x$.trafo[1] = list(NULL)
27+
x
28+
})
29+
domain = do.call(ps, dms)
30+
}
2431
domain
2532
}
2633

@@ -130,7 +137,7 @@ check_learner_surrogate = function(learner) {
130137
return(TRUE)
131138
}
132139
}
133-
140+
134141
"Must inherit from class 'Learner' or be a list of elements inheriting from class 'Learner'"
135142
}
136143

0 commit comments

Comments
 (0)