Skip to content

Commit e9392a4

Browse files
committed
test: refactor and speed up, skip some on cran
1 parent b1f932c commit e9392a4

19 files changed

+65
-80
lines changed

tests/testthat/helper.R

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,6 @@
11
lapply(list.files(system.file("testthat", package = "mlr3"),
22
pattern = "^helper.*\\.[rR]", full.names = TRUE), source)
33

4-
with_seed = function(seed, expr) {
5-
old_seed = get0(".Random.seed", globalenv(), mode = "integer", inherits = FALSE)
6-
if (is.null(old_seed)) {
7-
runif(1L)
8-
old_seed = get0(".Random.seed", globalenv(), mode = "integer", inherits = FALSE)
9-
}
10-
11-
on.exit(assign(".Random.seed", old_seed, globalenv()), add = TRUE)
12-
set.seed(seed)
13-
force(expr)
14-
}
15-
164
# Simple 1D Functions
175
PS_1D = ParamSet$new(list(
186
ParamDbl$new("x", lower = -1, upper = 1)
@@ -108,9 +96,9 @@ MAKE_DESIGN = function(instance, n = 4L) {
10896

10997
if (requireNamespace("mlr3learners") && requireNamespace("DiceKriging") && requireNamespace("rgenoud")) {
11098
library(mlr3learners)
111-
REGR_KM_NOISY = lrn("regr.km", covtype = "matern3_2", optim.method = "gen", control = list(trace = FALSE), nugget.estim = TRUE, jitter = 1e-12)
99+
REGR_KM_NOISY = lrn("regr.km", covtype = "matern3_2", optim.method = "gen", control = list(trace = FALSE, max.generations = 2), nugget.estim = TRUE, jitter = 1e-12)
112100
REGR_KM_NOISY$encapsulate = c(train = "callr", predict = "callr")
113-
REGR_KM_DETERM = lrn("regr.km", covtype = "matern3_2", optim.method = "gen", control = list(trace = FALSE), nugget.stability = 10^-8)
101+
REGR_KM_DETERM = lrn("regr.km", covtype = "matern3_2", optim.method = "gen", control = list(trace = FALSE, max.generations = 2), nugget.stability = 10^-8)
114102
REGR_KM_DETERM$encapsulate = c(train = "callr", predict = "callr")
115103
}
116104
REGR_FEATURELESS = lrn("regr.featureless")
@@ -207,3 +195,4 @@ expect_acqfunction = function(acqf) {
207195
expect_string(acqf$label)
208196
expect_man_exists(acqf$man)
209197
}
198+

tests/testthat/teardown.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
lg$set_threshold(old_threshold_bbotk)
22
lg_mlr3$set_threshold(old_threshold_mlr3)
3+

tests/testthat/test_AcqFunction.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@ test_that("AcqFunction API works", {
1010
expect_equal(acqf$direction, "same")
1111
expect_equal(acqf$domain, inst$search_space)
1212
})
13+

tests/testthat/test_AcqFunctionCB.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ test_that("AcqFunctionCB works", {
55
expect_acqfunction(acqf)
66

77
expect_r6(acqf$codomain, "ParamSet")
8-
expect_equal(acqf$codomain$ids(), "acq_cb")
8+
expect_equal(acqf$codomain$ids(), acqf$id)
99
expect_equal(acqf$surrogate_max_to_min, c(y = 1))
1010
expect_equal(acqf$direction, "same")
1111
expect_equal(acqf$domain, inst$search_space)
@@ -18,6 +18,6 @@ test_that("AcqFunctionCB works", {
1818
xdt = data.table(x = seq(-1, 1, length.out = 5L))
1919
res = acqf$eval_dt(xdt)
2020
expect_data_table(res, ncols = 1L, nrows = 5L, any.missing = FALSE)
21-
expect_named(res, "acq_cb")
21+
expect_named(res, acqf$id)
2222
})
2323

tests/testthat/test_AcqFunctionEI.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ test_that("AcqFunctionEI works", {
66
expect_acqfunction(acqf)
77

88
expect_r6(acqf$codomain, "ParamSet")
9-
expect_equal(acqf$codomain$ids(), "acq_ei")
9+
expect_equal(acqf$codomain$ids(), acqf$id)
1010
expect_equal(acqf$surrogate_max_to_min, c(y = 1))
1111
expect_equal(acqf$direction, "maximize")
1212
expect_equal(acqf$domain, inst$search_space)
@@ -21,7 +21,7 @@ test_that("AcqFunctionEI works", {
2121
acqf$update()
2222
res = acqf$eval_dt(xdt)
2323
expect_data_table(res, ncols = 1L, nrows = 5L, any.missing = FALSE)
24-
expect_named(res, "acq_ei")
24+
expect_named(res, acqf$id)
2525
})
2626

2727
test_that("AcqFunctionEI trafo", {
@@ -47,7 +47,7 @@ test_that("AcqFunctionEI trafo", {
4747
acqf$update()
4848
res = acqf$eval_dt(xdt)
4949
expect_data_table(res, ncols = 1L, nrows = 11L, any.missing = FALSE)
50-
expect_named(res, "acq_ei")
50+
expect_named(res, acqf$id)
5151
expect_true(max(res$acq_ei) == res$acq_ei[6]) # at x = 15
5252
})
5353

tests/testthat/test_AcqFunctionEIPS.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ test_that("AcqFunctionEIPS works", {
1010
expect_acqfunction(acqf)
1111

1212
expect_r6(acqf$codomain, "ParamSet")
13-
expect_equal(acqf$codomain$ids(), "acq_eips")
13+
expect_equal(acqf$codomain$ids(), acqf$id)
1414
expect_equal(acqf$surrogate_max_to_min, c(y = 1, time = 1)) # FIXME: check this
1515
expect_equal(acqf$direction, "maximize")
1616
expect_equal(acqf$domain, inst$search_space)
@@ -25,6 +25,6 @@ test_that("AcqFunctionEIPS works", {
2525
acqf$update()
2626
res = acqf$eval_dt(xdt)
2727
expect_data_table(res, ncols = 1L, nrows = 5L, any.missing = FALSE)
28-
expect_named(res, "acq_eips")
28+
expect_named(res, acqf$id)
2929
})
3030

tests/testthat/test_AcqFunctionMean.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ test_that("AcqFunctionMean works", {
55
expect_acqfunction(acqf)
66

77
expect_r6(acqf$codomain, "ParamSet")
8-
expect_equal(acqf$codomain$ids(), "acq_mean")
8+
expect_equal(acqf$codomain$ids(), acqf$id)
99
expect_equal(acqf$surrogate_max_to_min, c(y = 1))
1010
expect_equal(acqf$direction, "same")
1111
expect_equal(acqf$domain, inst$search_space)
@@ -18,6 +18,6 @@ test_that("AcqFunctionMean works", {
1818
xdt = data.table(x = seq(-1, 1, length.out = 5L))
1919
res = acqf$eval_dt(xdt)
2020
expect_data_table(res, ncols = 1L, nrows = 5L, any.missing = FALSE)
21-
expect_named(res, "acq_mean")
21+
expect_named(res, acqf$id)
2222
})
2323

tests/testthat/test_AcqFunctionPI.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ test_that("AcqFunctionPI works", {
66

77

88
expect_r6(acqf$codomain, "ParamSet")
9-
expect_equal(acqf$codomain$ids(), "acq_pi")
9+
expect_equal(acqf$codomain$ids(), acqf$id)
1010
expect_equal(acqf$surrogate_max_to_min, c(y = 1))
1111
expect_equal(acqf$direction, "maximize")
1212
expect_equal(acqf$domain, inst$search_space)
@@ -21,6 +21,6 @@ test_that("AcqFunctionPI works", {
2121
acqf$update()
2222
res = acqf$eval_dt(xdt)
2323
expect_data_table(res, ncols = 1L, nrows = 5L, any.missing = FALSE)
24-
expect_named(res, "acq_pi")
24+
expect_named(res, acqf$id)
2525
})
2626

tests/testthat/test_AcqFunctionSmsEgo.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ test_that("AcqFunctionSmsEgo works", {
55
expect_acqfunction(acqf)
66

77
expect_r6(acqf$codomain, "ParamSet")
8-
expect_equal(acqf$codomain$ids(), "acq_smsego")
8+
expect_equal(acqf$codomain$ids(), acqf$id)
99
expect_equal(acqf$surrogate_max_to_min, c(y1 = 1, y2 = 1))
1010
expect_equal(acqf$direction, "minimize")
1111
expect_equal(acqf$domain, inst$search_space)
@@ -22,6 +22,6 @@ test_that("AcqFunctionSmsEgo works", {
2222
acqf$update()
2323
res = acqf$eval_dt(xdt)
2424
expect_data_table(res, ncols = 1L, nrows = 5L, any.missing = FALSE)
25-
expect_named(res, "acq_smsego")
25+
expect_named(res, acqf$id)
2626
})
2727

tests/testthat/test_AcqOptimizer.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ test_that("AcqOptimizer API works", {
99
instance$eval_batch(design)
1010
acqfun = AcqFunctionEI$new(SurrogateLearner$new(REGR_KM_DETERM, archive = instance$archive))
1111
acqopt = AcqOptimizer$new(opt("random_search", batch_size = 2L), trm("evals", n_evals = 2L), acq_function = acqfun)
12-
with_seed(24, {acqfun$surrogate$update()})
12+
acqfun$surrogate$update()
1313
acqfun$update()
1414
expect_data_table(acqopt$optimize(), nrows = 1L)
1515

0 commit comments

Comments
 (0)