Skip to content

Commit fdcdda4

Browse files
authored
Merge pull request #796 from mlr-org/filterensemble
filterensemble
2 parents 8dd971f + cbdb747 commit fdcdda4

File tree

10 files changed

+824
-6
lines changed

10 files changed

+824
-6
lines changed

AGENTS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ Straightforwardness: Avoid ideological adherence to other programming principles
9090

9191
- R unit tests in this repo assume helper `expect_man_exists()` is available. If you need to call it in a new test and you are working without mlr3pipelines installed, define a local fallback at the top of that test file before `expect_learner()` is used.
9292
- Revdep helper scripts live in `attic/revdeps/`. `download_revdeps.R` downloads reverse dependency source tarballs; `install_revdep_suggests.R` installs Suggests for those revdeps without pulling the revdeps themselves.
93+
- When writing `paradox::ParamSet` custom checks (e.g. `p_uty(custom_check = ...)`), you do not need to special-case `TuneToken`s. `paradox` skips custom validators for `TuneToken` inputs before evaluating them, so the check only sees concrete values.
9394

9495
</agent_notes>
9596
<your_task>

DESCRIPTION

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ Collate:
120120
'CnfFormula_simplify.R'
121121
'CnfSymbol.R'
122122
'CnfUniverse.R'
123+
'FilterEnsemble.R'
123124
'Graph.R'
124125
'GraphLearner.R'
125126
'mlr_pipeops.R'

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ export(CnfClause)
9191
export(CnfFormula)
9292
export(CnfSymbol)
9393
export(CnfUniverse)
94+
export(FilterEnsemble)
9495
export(Graph)
9596
export(GraphLearner)
9697
export(LearnerClassifAvg)

NEWS.md

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
# mlr3pipelines 0.9.0-9000
22

33
* Pretty-printing some info using the `cli` package now.
4-
* Fix: Added internal workaround for `PipeOpNMF` attaching `Biobase`, `BiocGenerics`, and `generics` to the search path during training, prediction or when printing its `$state`.
4+
* New PipeOp `PipeOpInfo` prints or logs info about objects passing through.
5+
* New Pipeop `PipeOpIsomap` implements isomap embedding from `dimRed::embed`
56
* feat: allow dates in datefeatures pipe op and use data.table for date feature generation.
6-
* Added support for internal validation tasks to `PipeOpFeatureUnion`.
77
* feat: `PipeOpLearnerCV` can reuse the cross-validation models during prediction by averaging their outputs (`resampling.predict_method = "cv_ensemble"`).
8-
* feat: `PipeOpRegrAvg` gets new `se_aggr` and `se_aggr_rho` hyperparameters and now allows various forms of SE aggregation.
9-
* Fix: `PipeOpInfo` now prints a bounded task preview (respecting target/feature ordering and row ids) and collapses logger output to single messages.
10-
* Fix: `PipeOpIsomap` only operates on numeric or integer features and its parameter documentation was corrected.
8+
* feat: `PipeOpRegrAvg` gets new `se_aggr`, `se_aggr_rho`, `prob_aggr`, and `prob_aggr_eps` hyperparameters and now allows different forms of prob / SE aggregation.
9+
* feat: `FilterEnsemble` implements Binder et al. (2020) *Multi-Objective Hyperparameter Tuning and Feature Selection using Filter Ensembles*
1110
* Fix: `PipeOpRemoveConstants` now avoids integer overflow when evaluating relative tolerances for near-`integer.max` data.
11+
* Fix: Added support for internal validation tasks to `PipeOpFeatureUnion`.
12+
* Fix: Added internal workaround for `PipeOpNMF` attaching `Biobase`, `BiocGenerics`, and `generics` to the search path during training, prediction or when printing its `$state`.
1213
* Compatibility with new testthat version 3.3.0
1314

15+
1416
# mlr3pipelines 0.9.0
1517

1618
* Breaking change: Removed initialization of `PipeOpImputeConstant`'s `constant` hyperparameter since it was incompatible with other defaults and would lead to not recommended usage (creating an empty level).

R/FilterEnsemble.R

Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
2+
3+
#' @title Filter Ensemble
4+
#'
5+
#' @usage NULL
6+
#' @name mlr_filters_ensemble
7+
#' @format [`R6Class`][R6::R6Class] object inheriting from [`Filter`][mlr3filters::Filter].
8+
#'
9+
#' @description
10+
#' `FilterEnsemble` aggregates several [`Filter`][mlr3filters::Filter]s by averaging their scores
11+
#' (or ranks) with user-defined weights. Each wrapped filter is evaluated on the supplied task,
12+
#' and the resulting feature scores are combined feature-wise by a convex combination determined
13+
#' through the `weights` parameter. This allows leveraging complementary inductive biases of
14+
#' multiple filters without committing to a single criterion. The concept was introduced by
15+
#' Binder et al. (2020). This implementation follows the idea but leaves the exact choice of
16+
#' weights to the user.
17+
#'
18+
#' @section Construction:
19+
#' ```
20+
#' FilterEnsemble$new(filters)
21+
#' ```
22+
#'
23+
#' * `filters` :: `list` of [`Filter`][mlr3filters::Filter]\cr
24+
#' Filters that are evaluated and aggregated. Each filter must be cloneable and support the
25+
#' task type and feature types of the ensemble. The ensemble identifier defaults to the wrapped
26+
#' filter ids concatenated by `"."`.
27+
#'
28+
#' @section Parameters:
29+
#' * `weights` :: `numeric()`\cr
30+
#' Required non-negative weights, one for each wrapped filter, with at least one strictly positive value.
31+
#' Values are used as given when calculating the weighted mean. If named, names must match the wrapped filter ids.
32+
#' * `rank_transform` :: `logical(1)`\cr
33+
#' If `TRUE`, ranks of individual filter scores are used instead of the raw scores before
34+
#' averaging. Initialized to `FALSE`.
35+
#'
36+
#' Parameters of wrapped filters are available via `$param_set` and can be referenced using
37+
#' the wrapped filter id followed by `"."`, e.g. `"variance.na.rm"`.
38+
#'
39+
#' @section Fields:
40+
#' * `$wrapped` :: named `list` of [`Filter`][mlr3filters::Filter]\cr
41+
#' Read-only access to the wrapped filters.
42+
#'
43+
#' @section Methods:
44+
#' * `get_weights_search_space(weights_param_name = "weights", normalize_weights = "uniform", prefix = "w")`\cr
45+
#' (`character(1)`, `character(1)`, `character(1)`) -> [`ParamSet`][paradox::ParamSet]\cr
46+
#' Construct a [`ParamSet`][paradox::ParamSet] describing a weight search space.
47+
#' * `get_weights_tunetoken(normalize_weights = "uniform")`\cr
48+
#' (`character(1)`) -> [`TuneToken`][paradox::TuneToken]\cr
49+
#' Shortcut returning a [`TuneToken`][paradox::TuneToken] for tuning the weights.
50+
#' * `set_weights_to_tune(normalize_weights = "uniform")`\cr
51+
#' (`character(1)`) -> `self`\cr
52+
#' Convenience wrapper that stores the `TuneToken` returned by
53+
#' `get_weights_tunetoken()` in `$param_set$values$weights`.
54+
#'
55+
#' @section Internals:
56+
#' All wrapped filters are called with `nfeat` equal to the number of features to ensure that
57+
#' complete score vectors are available for aggregation. Scores are combined per feature by
58+
#' computing the weighted (optionally rank-based) mean.
59+
#'
60+
#' @section References:
61+
#' `r format_bib("binder_2020")`
62+
#'
63+
#' @examplesIf mlr3misc::require_namespaces("mlr3filters", quietly = TRUE)
64+
#' library("mlr3")
65+
#' library("mlr3filters")
66+
#'
67+
#' task = tsk("sonar")
68+
#'
69+
#' flt = mlr_filters$get("ensemble",
70+
#' filters = list(FilterVariance$new(), FilterAUC$new()))
71+
#' flt$param_set$values$weights = c(variance = 0.5, auc = 0.5)
72+
#' flt$calculate(task)
73+
#' head(as.data.table(flt))
74+
#' @export
75+
FilterEnsemble = R6Class("FilterEnsemble", inherit = mlr3filters::Filter,
76+
public = list(
77+
initialize = function(filters) {
78+
private$.wrapped = lapply(assert_list(filters, types = "Filter", min.len = 1), function(x) x$clone(deep = TRUE))
79+
fnames = map_chr(private$.wrapped, "id")
80+
names(private$.wrapped) = fnames
81+
types_list = map(discard(private$.wrapped, function(x) test_scalar_na(x$task_types)), "task_types")
82+
if (length(types_list)) {
83+
task_types = Reduce(intersect, types_list)
84+
} else {
85+
task_types = NA_character_
86+
}
87+
.own_param_set = ps(
88+
weights = p_uty(custom_check = crate(function(x) {
89+
if (inherits(x, "TuneToken")) {
90+
return(TRUE)
91+
}
92+
check_numeric(x, len = length(fnames), lower = 0) %check&&%
93+
(check_names(names(x), type = "unnamed") %check||%
94+
check_names(names(x), type = "unique", permutation.of = fnames)) %check&&%
95+
(if (any(x > 0)) TRUE else "At least one weight must be > 0.")
96+
}, fnames),
97+
tags = "required"
98+
),
99+
rank_transform = p_lgl(init = FALSE, tags = "required")
100+
)
101+
102+
super$initialize(
103+
id = paste(fnames, collapse = "."),
104+
task_types = task_types,
105+
task_properties = unique(unlist(map(private$.wrapped, "task_properties"))),
106+
param_set = .own_param_set,
107+
feature_types = Reduce(intersect, map(private$.wrapped, "feature_types")),
108+
packages = unique(unlist(map(private$.wrapped, "packages"))),
109+
label = "meta",
110+
man = "mlr3pipelines::mlr_filters_ensemble"
111+
)
112+
private$.own_param_set = .own_param_set
113+
private$.param_set = NULL
114+
},
115+
get_weights_tunetoken = function(normalize_weights = "uniform") {
116+
assert_choice(normalize_weights, c("uniform", "naive", "no"))
117+
paradox::to_tune(self$get_weights_search_space(normalize_weights = normalize_weights))
118+
},
119+
set_weights_to_tune = function(normalize_weights = "uniform") {
120+
assert_choice(normalize_weights, c("uniform", "naive", "no"))
121+
self$param_set$set_values(.values = list(weights = self$get_weights_tunetoken(normalize_weights = normalize_weights)))
122+
invisible(self)
123+
},
124+
get_weights_search_space = function(weights_param_name = "weights", normalize_weights = "uniform", prefix = "w") {
125+
assert_string(prefix)
126+
assert_string(weights_param_name)
127+
assert_choice(normalize_weights, c("uniform", "naive", "no"))
128+
fnames = names(private$.wrapped)
129+
innames = if (prefix == "") fnames else paste0(prefix, ".", fnames)
130+
domains = rep(list(p_dbl(0, 1)), length(fnames))
131+
names(domains) = innames
132+
133+
domains$.extra_trafo = crate(function(x) {
134+
w = unlist(x[innames], use.names = FALSE)
135+
names(w) = fnames
136+
x[innames] = NULL
137+
138+
if (normalize_weights == "uniform") {
139+
w[w > 1 - .Machine$double.eps] = 1 - .Machine$double.eps
140+
w = -log1p(-w)
141+
w = w / max(sum(w), .Machine$double.eps)
142+
} else if (normalize_weights == "naive") {
143+
w = w / max(sum(w), .Machine$double.eps)
144+
}
145+
if (!any(w > 0)) {
146+
w[] = 1 / length(w)
147+
}
148+
x[[weights_param_name]] = w
149+
x
150+
}, innames, fnames, normalize_weights, weights_param_name)
151+
152+
do.call(paradox::ps, domains)
153+
}
154+
),
155+
private = list(
156+
.wrapped = NULL,
157+
.own_param_set = NULL,
158+
.param_set = NULL,
159+
.calculate = function(task, nfeat) {
160+
pv = private$.own_param_set$get_values()
161+
fn = task$feature_names
162+
nfeat = length(fn) # need to rank all features in an ensemble
163+
weights = pv$weights
164+
wnames = names(private$.wrapped)
165+
if (!is.null(names(weights))) {
166+
weights = weights[wnames]
167+
}
168+
if (!any(weights > 0)) {
169+
stop("At least one weight must be > 0.")
170+
}
171+
scores = pmap(list(private$.wrapped, weights), function(x, w) {
172+
x$calculate(task, nfeat)
173+
s = x$scores[fn]
174+
if (pv$rank_transform) s = rank(s, na.last = "keep", ties.method = "average")
175+
s * w
176+
})
177+
scores_df = as.data.frame(scores)
178+
combined = rowSums(scores_df, na.rm = TRUE)
179+
all_missing = rowSums(!is.na(scores_df)) == 0L
180+
combined[all_missing] = NA_real_
181+
structure(combined, names = fn)
182+
},
183+
deep_clone = function(name, value) {
184+
if (name == ".wrapped") {
185+
private$.param_set = NULL
186+
return(map(value, function(x) x$clone(deep = TRUE)))
187+
}
188+
if (name == ".own_param_set") {
189+
private$.param_set = NULL
190+
return(value$clone(deep = TRUE))
191+
}
192+
if (name == ".param_set") {
193+
return(NULL)
194+
}
195+
value
196+
}
197+
),
198+
active = list(
199+
wrapped = function(val) {
200+
if (!missing(val)) {
201+
stop("$wrapped is read-only.")
202+
}
203+
private$.wrapped
204+
},
205+
param_set = function(val) {
206+
if (is.null(private$.param_set)) {
207+
private$.param_set = ParamSetCollection$new(c(list(private$.own_param_set), map(private$.wrapped, "param_set")))
208+
}
209+
if (!missing(val) && !identical(val, private$.param_set)) {
210+
stop("param_set is read-only.")
211+
}
212+
private$.param_set
213+
}
214+
)
215+
216+
)

R/bibentries.R

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,16 @@ bibentries = c(
6565
journal = "Journal of the American Statistical Association"
6666
),
6767

68+
binder_2020 = bibentry("inproceedings",
69+
doi = "10.1145/3377930.3389815",
70+
year = "2020",
71+
publisher = "Association for Computing Machinery",
72+
pages = "471--479",
73+
author = "Martin Binder and Julia Moosbauer and Janek Thomas and Bernd Bischl",
74+
title = "Multi-objective hyperparameter tuning and feature selection using filter ensembles",
75+
booktitle = "Proceedings of the 2020 Genetic and Evolutionary Computation Conference"
76+
),
77+
6878
zhang2003 = bibentry("inproceedings",
6979
year = "2003",
7080
author = "Zhang, J. and Mani, I.",

R/zzz.R

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,27 @@ register_mlr3 = function() {
1919
x$pipeops$properties = c("validation", "internal_tuning")
2020
}
2121

22+
register_mlr3filters = function() {
23+
if ("mlr3filters" %in% loadedNamespaces()) {
24+
x = utils::getFromNamespace("mlr_filters", ns = "mlr3filters")
25+
x$add("ensemble", FilterEnsemble)
26+
}
27+
}
28+
29+
30+
31+
paradox_info <- list2env(list(is_old = FALSE), parent = emptyenv())
32+
2233
.onLoad = function(libname, pkgname) { # nocov start
2334
register_mlr3()
24-
setHook(packageEvent("mlr3", "onLoad"), function(...) register_mlr3(), action = "append")
35+
register_mlr3filters()
36+
setHook(packageEvent("mlr3", "onLoad"), function(...) {
37+
register_mlr3()
38+
register_mlr3filters()
39+
}, action = "append")
40+
setHook(packageEvent("mlr3filters", "onLoad"), function(...) {
41+
register_mlr3filters()
42+
}, action = "append")
2543
backports::import(pkgname)
2644

2745
assign("lg", lgr::get_logger("mlr3/mlr3pipelines"), envir = parent.env(environment()))

0 commit comments

Comments
 (0)