Skip to content

Commit 3de4d0c

Browse files
authored
refactor: extract internal tuned values in instance (#164)
* refactor: extract internal tuned values in instance * ... * ...
1 parent b3b8c74 commit 3de4d0c

File tree

5 files changed

+67
-10
lines changed

5 files changed

+67
-10
lines changed

DESCRIPTION

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,13 @@ BugReports: https://github.com/mlr-org/mlr3mbo/issues
4141
Depends:
4242
R (>= 3.1.0)
4343
Imports:
44-
bbotk (>= 1.0.0),
44+
bbotk (>= 1.1.1),
4545
checkmate (>= 2.0.0),
4646
data.table,
4747
lgr (>= 0.3.4),
4848
mlr3 (>= 0.21.0),
4949
mlr3misc (>= 0.11.0),
50-
mlr3tuning (>= 1.0.0),
50+
mlr3tuning (>= 1.0.2),
5151
paradox (>= 1.0.0),
5252
spacefillr,
5353
R6 (>= 2.4.1)

R/ResultAssignerArchive.R

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,15 @@ ResultAssignerArchive = R6Class("ResultAssignerArchive",
2929
#' @param instance ([bbotk::OptimInstanceBatchSingleCrit] | [bbotk::OptimInstanceBatchMultiCrit])\cr
3030
#' The [bbotk::OptimInstance] the final result should be assigned to.
3131
assign_result = function(instance) {
32-
res = instance$archive$best()
33-
xdt = res[, instance$search_space$ids(), with = FALSE]
32+
xydt = instance$archive$best()
33+
xdt = xydt[, instance$search_space$ids(), with = FALSE]
3434
if (inherits(instance, "OptimInstanceBatchMultiCrit")) {
35-
ydt = res[, instance$archive$cols_y, with = FALSE]
36-
instance$assign_result(xdt, ydt)
35+
ydt = xydt[, instance$archive$cols_y, with = FALSE]
36+
instance$assign_result(xdt, ydt, xydt = xydt)
3737
}
3838
else {
39-
y = unlist(res[, instance$archive$cols_y, with = FALSE])
40-
instance$assign_result(xdt, y)
39+
y = unlist(xydt[, instance$archive$cols_y, with = FALSE])
40+
instance$assign_result(xdt, y, xydt = xydt)
4141
}
4242
}
4343
),

R/ResultAssignerSurrogate.R

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,15 +59,16 @@ ResultAssignerSurrogate = R6Class("ResultAssignerSurrogate",
5959
}
6060
archive_tmp = archive$clone(deep = TRUE)
6161
archive_tmp$data[, self$surrogate$cols_y := means]
62-
best = archive_tmp$best()[, archive_tmp$cols_x, with = FALSE]
62+
xydt = archive_tmp$best()
63+
best = xydt[, archive_tmp$cols_x, with = FALSE]
6364

6465
# ys are still the ones originally evaluated
6566
best_y = if (inherits(instance, "OptimInstanceBatchSingleCrit")) {
6667
unlist(archive$data[best, on = archive$cols_x][, archive$cols_y, with = FALSE])
6768
} else if (inherits(instance, "OptimInstanceBatchMultiCrit")) {
6869
archive$data[best, on = archive$cols_x][, archive$cols_y, with = FALSE]
6970
}
70-
instance$assign_result(xdt = best, best_y)
71+
instance$assign_result(xdt = best, best_y, xydt = xydt)
7172
}
7273
),
7374

tests/testthat/test_ResultAssignerArchive.R

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,31 @@ test_that("ResultAssignerArchive works with OptimizerMbo and bayesopt_smsego", {
5454
expect_data_table(instance$result, min.rows = 1L)
5555
})
5656

57+
test_that("ResultAssignerArchive passes internal tuned values", {
58+
result_assigner = ResultAssignerArchive$new()
59+
60+
learner = lrn("classif.debug",
61+
validate = 0.2,
62+
early_stopping = TRUE,
63+
x = to_tune(0.2, 0.3),
64+
iter = to_tune(upper = 1000, internal = TRUE, aggr = function(x) 99))
65+
66+
instance = ti(
67+
task = tsk("pima"),
68+
learner = learner,
69+
resampling = rsmp("cv", folds = 3),
70+
measures = msr("classif.ce"),
71+
terminator = trm("evals", n_evals = 20),
72+
store_benchmark_result = TRUE
73+
)
74+
surrogate = SurrogateLearner$new(REGR_KM_DETERM)
75+
acq_function = AcqFunctionEI$new()
76+
acq_optimizer = AcqOptimizer$new(opt("random_search", batch_size = 2L), terminator = trm("evals", n_evals = 2L))
77+
78+
tuner = tnr("mbo", result_assigner = result_assigner)
79+
expect_data_table(tuner$optimize(instance), nrows = 1)
80+
expect_list(instance$archive$data$internal_tuned_values, len = 20, types = "list")
81+
expect_equal(instance$archive$data$internal_tuned_values[[1]], list(iter = 99))
82+
expect_false(instance$result_learner_param_vals$early_stopping)
83+
expect_equal(instance$result_learner_param_vals$iter, 99)
84+
})

tests/testthat/test_ResultAssignerSurrogate.R

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,3 +85,31 @@ test_that("ResultAssignerSurrogate works with OptimizerMbo and bayesopt_smsego",
8585
expect_data_table(instance$result, min.rows = 1L)
8686
})
8787

88+
test_that("ResultAssignerSurrogate passes internal tuned values", {
89+
result_assigner = ResultAssignerSurrogate$new()
90+
91+
learner = lrn("classif.debug",
92+
validate = 0.2,
93+
early_stopping = TRUE,
94+
x = to_tune(0.2, 0.3),
95+
iter = to_tune(upper = 1000, internal = TRUE, aggr = function(x) 99))
96+
97+
instance = ti(
98+
task = tsk("pima"),
99+
learner = learner,
100+
resampling = rsmp("cv", folds = 3),
101+
measures = msr("classif.ce"),
102+
terminator = trm("evals", n_evals = 20),
103+
store_benchmark_result = TRUE
104+
)
105+
surrogate = SurrogateLearner$new(REGR_KM_DETERM)
106+
acq_function = AcqFunctionEI$new()
107+
acq_optimizer = AcqOptimizer$new(opt("random_search", batch_size = 2L), terminator = trm("evals", n_evals = 2L))
108+
109+
tuner = tnr("mbo", result_assigner = result_assigner)
110+
expect_data_table(tuner$optimize(instance), nrows = 1)
111+
expect_list(instance$archive$data$internal_tuned_values, len = 20, types = "list")
112+
expect_equal(instance$archive$data$internal_tuned_values[[1]], list(iter = 99))
113+
expect_false(instance$result_learner_param_vals$early_stopping)
114+
expect_equal(instance$result_learner_param_vals$iter, 99)
115+
})

0 commit comments

Comments
 (0)