Skip to content

Commit 1af5dc8

Browse files
committed
more tests
1 parent b0dc285 commit 1af5dc8

File tree

2 files changed

+106
-1
lines changed

2 files changed

+106
-1
lines changed

R/zzz.R

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ register_mlr3 = function() {
2222
register_mlr3filters = function() {
2323
if ("mlr3filters" %in% loadedNamespaces()) {
2424
x = utils::getFromNamespace("mlr_filters", ns = "mlr3filters")
25-
mlr_filters$add("ensemble", FilterEnsemble)
25+
x$add("ensemble", FilterEnsemble)
2626
}
2727
}
2828

@@ -37,6 +37,9 @@ paradox_info <- list2env(list(is_old = FALSE), parent = emptyenv())
3737
register_mlr3()
3838
register_mlr3filters()
3939
}, action = "append")
40+
setHook(packageEvent("mlr3filters", "onLoad"), function(...) {
41+
register_mlr3filters()
42+
}, action = "append")
4043
backports::import(pkgname)
4144

4245
assign("lg", lgr::get_logger("mlr3/mlr3pipelines"), envir = parent.env(environment()))

tests/testthat/test_filter_ensemble.R

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,57 @@ test_that("FilterEnsemble combines wrapped filter scores", {
3333
expect_equal(combined_rank_scores[task$feature_names], expected_rank)
3434
})
3535

36+
test_that("FilterEnsemble with a single filter is a passthrough", {
37+
skip_if_not_installed("mlr3filters")
38+
39+
task = mlr_tasks$get("sonar")
40+
variance_ref = mlr3filters::FilterVariance$new()
41+
variance_ref$calculate(task)
42+
43+
ensemble = FilterEnsemble$new(list(mlr3filters::FilterVariance$new()))
44+
ensemble$param_set$values$weights = c(variance = 1)
45+
ensemble$calculate(task)
46+
47+
expect_equal(ensemble$scores[task$feature_names], variance_ref$scores[task$feature_names])
48+
})
49+
50+
test_that("FilterEnsemble with one-hot weights selects the active filter", {
51+
skip_if_not_installed("mlr3filters")
52+
53+
task = mlr_tasks$get("sonar")
54+
variance_filter = mlr3filters::FilterVariance$new()
55+
auc_filter = mlr3filters::FilterAUC$new()
56+
57+
variance_filter$calculate(task)
58+
auc_filter$calculate(task)
59+
60+
weights = c(variance = 1, auc = 0)
61+
ensemble = FilterEnsemble$new(list(
62+
variance_filter$clone(deep = TRUE),
63+
auc_filter$clone(deep = TRUE)
64+
))
65+
ensemble$param_set$values$weights = weights
66+
ensemble$calculate(task)
67+
68+
expect_equal(ensemble$scores[task$feature_names], variance_filter$scores[task$feature_names])
69+
})
70+
71+
test_that("FilterEnsemble is registered in mlr_filters", {
72+
skip_if_not_installed("mlr3filters")
73+
expect_true("ensemble" %in% mlr3filters::mlr_filters$keys())
74+
})
75+
76+
test_that("FilterEnsemble identifier concatenates wrapped filter ids", {
77+
skip_if_not_installed("mlr3filters")
78+
79+
ensemble = FilterEnsemble$new(list(
80+
mlr3filters::FilterVariance$new(),
81+
mlr3filters::FilterAUC$new()
82+
))
83+
84+
expect_identical(ensemble$id, "variance.auc")
85+
})
86+
3687
test_that("FilterEnsemble reorders named weights correctly", {
3788
skip_if_not_installed("mlr3filters")
3889

@@ -110,6 +161,57 @@ test_that("FilterEnsemble requires at least one wrapped filter", {
110161
expect_error(FilterEnsemble$new(list()), "length >= 1")
111162
})
112163

164+
test_that("FilterEnsemble task types are intersected", {
165+
skip_if_not_installed("mlr3filters")
166+
167+
ensemble = FilterEnsemble$new(list(
168+
mlr3filters::FilterAnova$new(), # classif
169+
mlr3filters::FilterCorrelation$new(), # regr
170+
mlr3filters::FilterVariance$new() # any
171+
))
172+
173+
expect_identical(ensemble$task_types, character())
174+
})
175+
176+
test_that("FilterEnsemble feature types intersect across filters", {
177+
skip_if_not_installed("mlr3filters")
178+
skip_if_not_installed("mlr3learners")
179+
skip_if_not_installed("rpart")
180+
181+
importance_filter = mlr3filters::FilterImportance$new(mlr3::lrn("classif.rpart"))
182+
variance_filter = mlr3filters::FilterVariance$new()
183+
184+
ensemble = FilterEnsemble$new(list(
185+
variance_filter,
186+
importance_filter
187+
))
188+
189+
expect_setequal(ensemble$feature_types, c("integer", "numeric"))
190+
})
191+
192+
test_that("FilterEnsemble aggregates packages and task properties", {
193+
skip_if_not_installed("mlr3filters")
194+
195+
ensemble = FilterEnsemble$new(list(
196+
mlr3filters::FilterVariance$new(), # stats
197+
mlr3filters::FilterAUC$new() # mlr3measures, twoclass property
198+
))
199+
200+
expect_true(all(c("stats", "mlr3measures") %in% ensemble$packages))
201+
expect_identical(ensemble$task_properties, "twoclass")
202+
})
203+
204+
test_that("FilterEnsemble clones wrapped filters on construction", {
205+
skip_if_not_installed("mlr3filters")
206+
207+
original_variance = mlr3filters::FilterVariance$new()
208+
ensemble = FilterEnsemble$new(list(original_variance))
209+
210+
variance_in_ensemble = ensemble$wrapped$variance
211+
variance_in_ensemble$param_set$values$na.rm = FALSE
212+
expect_true(original_variance$param_set$values$na.rm)
213+
})
214+
113215
test_that("FilterEnsemble works inside PipeOpFilter", {
114216
skip_if_not_installed("mlr3filters")
115217

0 commit comments

Comments
 (0)