|
| 1 | + |
| 2 | + |
| 3 | +#' @title Filter Ensemble |
| 4 | +#' |
| 5 | +#' @usage NULL |
| 6 | +#' @name mlr_filters_ensemble |
| 7 | +#' @format [`R6Class`][R6::R6Class] object inheriting from [`Filter`][mlr3filters::Filter]. |
| 8 | +#' |
| 9 | +#' @description |
| 10 | +#' `FilterEnsemble` aggregates several [`Filter`][mlr3filters::Filter]s by averaging their scores |
| 11 | +#' (or ranks) with user-defined weights. Each wrapped filter is evaluated on the supplied task, |
| 12 | +#' and the resulting feature scores are combined feature-wise by a convex combination determined |
| 13 | +#' through the `weights` parameter. This allows leveraging complementary inductive biases of |
| 14 | +#' multiple filters without committing to a single criterion. The concept was introduced by |
| 15 | +#' Binder et al. (2020). This implementation follows the idea but leaves the exact choice of |
| 16 | +#' weights to the user. |
| 17 | +#' |
| 18 | +#' @section Construction: |
| 19 | +#' ``` |
| 20 | +#' FilterEnsemble$new(filters) |
| 21 | +#' ``` |
| 22 | +#' |
| 23 | +#' * `filters` :: `list` of [`Filter`][mlr3filters::Filter]\cr |
| 24 | +#' Filters that are evaluated and aggregated. Each filter must be cloneable and support the |
| 25 | +#' task type and feature types of the ensemble. The ensemble identifier defaults to the wrapped |
| 26 | +#' filter ids concatenated by `"."`. |
| 27 | +#' |
| 28 | +#' @section Parameters: |
| 29 | +#' * `weights` :: `numeric()`\cr |
| 30 | +#' Required non-negative weights, one for each wrapped filter, with at least one strictly positive value. |
| 31 | +#' Values are used as given when calculating the weighted mean. If named, names must match the wrapped filter ids. |
| 32 | +#' * `rank_transform` :: `logical(1)`\cr |
| 33 | +#' If `TRUE`, ranks of individual filter scores are used instead of the raw scores before |
| 34 | +#' averaging. Initialized to `FALSE`. |
| 35 | +#' |
| 36 | +#' Parameters of wrapped filters are available via `$param_set` and can be referenced using |
| 37 | +#' the wrapped filter id followed by `"."`, e.g. `"variance.na.rm"`. |
| 38 | +#' |
| 39 | +#' @section Fields: |
| 40 | +#' * `$wrapped` :: named `list` of [`Filter`][mlr3filters::Filter]\cr |
| 41 | +#' Read-only access to the wrapped filters. |
| 42 | +#' |
| 43 | +#' @section Methods: |
| 44 | +#' * `get_weights_search_space(weights_param_name = "weights", normalize_weights = "uniform", prefix = "w")`\cr |
| 45 | +#' (`character(1)`, `character(1)`, `character(1)`) -> [`ParamSet`][paradox::ParamSet]\cr |
| 46 | +#' Construct a [`ParamSet`][paradox::ParamSet] describing a weight search space. |
| 47 | +#' * `get_weights_tunetoken(normalize_weights = "uniform")`\cr |
| 48 | +#' (`character(1)`) -> [`TuneToken`][paradox::TuneToken]\cr |
| 49 | +#' Shortcut returning a [`TuneToken`][paradox::TuneToken] for tuning the weights. |
| 50 | +#' * `set_weights_to_tune(normalize_weights = "uniform")`\cr |
| 51 | +#' (`character(1)`) -> `self`\cr |
| 52 | +#' Convenience wrapper that stores the `TuneToken` returned by |
| 53 | +#' `get_weights_tunetoken()` in `$param_set$values$weights`. |
| 54 | +#' |
| 55 | +#' @section Internals: |
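| 56 | +#' All wrapped filters are called with `nfeat` equal to the number of features to ensure that |
| 57 | +#' complete score vectors are available for aggregation. Scores are combined per feature by |
| 58 | +#' computing the weighted sum of the (optionally rank-transformed) scores. Weights are used as given, so the result is a weighted mean only if they sum to 1. Missing scores are skipped; a feature is `NA` only if every wrapped filter returns `NA` for it. |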
| 59 | +#' |
| 60 | +#' @section References: |
| 61 | +#' `r format_bib("binder_2020")` |
| 62 | +#' |
| 63 | +#' @examplesIf mlr3misc::require_namespaces("mlr3filters", quietly = TRUE) |
| 64 | +#' library("mlr3") |
| 65 | +#' library("mlr3filters") |
| 66 | +#' |
| 67 | +#' task = tsk("sonar") |
| 68 | +#' |
| 69 | +#' flt = mlr_filters$get("ensemble", |
| 70 | +#' filters = list(FilterVariance$new(), FilterAUC$new())) |
| 71 | +#' flt$param_set$values$weights = c(variance = 0.5, auc = 0.5) |
| 72 | +#' flt$calculate(task) |
| 73 | +#' head(as.data.table(flt)) |
| 74 | +#' @export |
FilterEnsemble = R6Class("FilterEnsemble", inherit = mlr3filters::Filter,
  public = list(
    # Create the ensemble.
    #
    # filters: non-empty list of `Filter` objects with pairwise distinct ids.
    # Every filter is deep-cloned, so the ensemble does not share state with the objects passed in.
    initialize = function(filters) {
      assert_list(filters, types = "Filter", min.len = 1)
      private$.wrapped = lapply(filters, function(x) x$clone(deep = TRUE))
      fnames = map_chr(private$.wrapped, "id")
      # The ids are used as list names, as names of the weight vector and as parameter prefixes;
      # duplicates would make `weights[wnames]` in `.calculate()` pick the same weight twice.
      assert_names(fnames, type = "unique", .var.name = "ids of 'filters'")
      names(private$.wrapped) = fnames

      # Task types: intersection over all wrapped filters that declare any (NA means "no restriction").
      types_list = map(discard(private$.wrapped, function(x) test_scalar_na(x$task_types)), "task_types")
      if (length(types_list)) {
        task_types = Reduce(intersect, types_list)
      } else {
        task_types = NA_character_
      }

      .own_param_set = ps(
        weights = p_uty(custom_check = crate(function(x) {
          # A TuneToken is a placeholder for tuning and is accepted as is.
          if (inherits(x, "TuneToken")) {
            return(TRUE)
          }
          # Missing or infinite weights are rejected: an NA weight would otherwise turn all scores of
          # its filter into NA, which the aggregation then silently drops.
          check_numeric(x, len = length(fnames), lower = 0, finite = TRUE, any.missing = FALSE) %check&&%
            (check_names(names(x), type = "unnamed") %check||%
              check_names(names(x), type = "unique", permutation.of = fnames)) %check&&%
            # Guarded so that non-numeric input or NAs yield a check message instead of an
            # error inside `if ()`.
            (if (is.numeric(x) && any(x > 0, na.rm = TRUE)) TRUE else "At least one weight must be > 0.")
        }, fnames),
        tags = "required"
        ),
        rank_transform = p_lgl(init = FALSE, tags = "required")
      )

      super$initialize(
        id = paste(fnames, collapse = "."),
        task_types = task_types,
        task_properties = unique(unlist(map(private$.wrapped, "task_properties"))),
        param_set = .own_param_set,
        feature_types = Reduce(intersect, map(private$.wrapped, "feature_types")),
        packages = unique(unlist(map(private$.wrapped, "packages"))),
        label = "meta",
        man = "mlr3pipelines::mlr_filters_ensemble"
      )
      private$.own_param_set = .own_param_set
      # Reset the cache so that the `param_set` active binding builds the joint collection lazily.
      private$.param_set = NULL
    },

    # Return a TuneToken wrapping the weight search space for the given normalization scheme.
    get_weights_tunetoken = function(normalize_weights = "uniform") {
      assert_choice(normalize_weights, c("uniform", "naive", "no"))
      paradox::to_tune(self$get_weights_search_space(normalize_weights = normalize_weights))
    },

    # Store the TuneToken from `get_weights_tunetoken()` as the `weights` value; returns `self` invisibly.
    set_weights_to_tune = function(normalize_weights = "uniform") {
      assert_choice(normalize_weights, c("uniform", "naive", "no"))
      self$param_set$set_values(.values = list(weights = self$get_weights_tunetoken(normalize_weights = normalize_weights)))
      invisible(self)
    },

    # Build a ParamSet with one numeric parameter in [0, 1] per wrapped filter, named
    # "<prefix>.<filter id>" (or just "<filter id>" if `prefix` is ""), plus a trafo that collapses
    # these into one named weight vector stored under `weights_param_name`.
    #
    # normalize_weights:
    #   "uniform" - map each value through -log(1 - w), then divide by the sum
    #   "naive"   - divide by the sum
    #   "no"      - use the values unchanged
    get_weights_search_space = function(weights_param_name = "weights", normalize_weights = "uniform", prefix = "w") {
      assert_string(prefix)
      assert_string(weights_param_name)
      assert_choice(normalize_weights, c("uniform", "naive", "no"))
      fnames = names(private$.wrapped)
      innames = if (prefix == "") fnames else paste0(prefix, ".", fnames)
      domains = rep(list(p_dbl(0, 1)), length(fnames))
      names(domains) = innames

      domains$.extra_trafo = crate(function(x) {
        w = unlist(x[innames], use.names = FALSE)
        names(w) = fnames
        x[innames] = NULL

        if (normalize_weights == "uniform") {
          # Cap just below 1 so that -log1p(-w) stays finite.
          w[w > 1 - .Machine$double.eps] = 1 - .Machine$double.eps
          w = -log1p(-w)
          w = w / max(sum(w), .Machine$double.eps)
        } else if (normalize_weights == "naive") {
          w = w / max(sum(w), .Machine$double.eps)
        }
        # All-zero weights are invalid for the filter; fall back to equal weights.
        if (!any(w > 0)) {
          w[] = 1 / length(w)
        }
        x[[weights_param_name]] = w
        x
      }, innames, fnames, normalize_weights, weights_param_name)

      do.call(paradox::ps, domains)
    }
  ),
  private = list(
    .wrapped = NULL,
    .own_param_set = NULL,
    .param_set = NULL,

    # Score all features of `task` with every wrapped filter and return the per-feature weighted sum
    # as a numeric vector named by feature. The incoming `nfeat` is ignored on purpose.
    .calculate = function(task, nfeat) {
      pv = private$.own_param_set$get_values()
      fn = task$feature_names
      nfeat = length(fn)  # need to rank all features in an ensemble
      weights = pv$weights
      wnames = names(private$.wrapped)
      if (!is.null(names(weights))) {
        # Bring named weights into the order of the wrapped filters.
        weights = weights[wnames]
      }
      if (!any(weights > 0)) {
        stop("At least one weight must be > 0.")
      }
      scores = pmap(list(private$.wrapped, weights), function(x, w) {
        x$calculate(task, nfeat)
        s = x$scores[fn]
        if (pv$rank_transform) s = rank(s, na.last = "keep", ties.method = "average")
        s * w
      })
      # One column per filter, one row per feature. `unname()` keeps filter ids from being
      # interpreted as arguments of `cbind()`.
      score_mat = do.call(cbind, unname(scores))
      combined = rowSums(score_mat, na.rm = TRUE)
      # rowSums(na.rm = TRUE) yields 0 for rows without any value; report those as NA instead.
      all_missing = rowSums(!is.na(score_mat)) == 0L
      combined[all_missing] = NA_real_
      structure(combined, names = fn)
    },

    # Deep-clone the wrapped filters and the own parameter set. The cached joint parameter set is
    # dropped and rebuilt on the next access to `$param_set`.
    deep_clone = function(name, value) {
      if (name == ".wrapped") {
        private$.param_set = NULL
        return(map(value, function(x) x$clone(deep = TRUE)))
      }
      if (name == ".own_param_set") {
        private$.param_set = NULL
        return(value$clone(deep = TRUE))
      }
      if (name == ".param_set") {
        return(NULL)
      }
      value
    }
  ),
  active = list(
    # Read-only named list of the wrapped filters.
    wrapped = function(val) {
      if (!missing(val)) {
        stop("$wrapped is read-only.")
      }
      private$.wrapped
    },

    # Joint parameter set: the ensemble's own parameters (unprefixed) followed by the parameter sets
    # of the wrapped filters, keyed by filter id. Built on first access and cached.
    param_set = function(val) {
      if (is.null(private$.param_set)) {
        private$.param_set = ParamSetCollection$new(c(list(private$.own_param_set), map(private$.wrapped, "param_set")))
      }
      if (!missing(val) && !identical(val, private$.param_set)) {
        stop("param_set is read-only.")
      }
      private$.param_set
    }
  )

)
0 commit comments