diff --git a/R/PipeOpRenameColumns.R b/R/PipeOpRenameColumns.R index b7d175e56..a96cf02c1 100644 --- a/R/PipeOpRenameColumns.R +++ b/R/PipeOpRenameColumns.R @@ -28,17 +28,21 @@ #' #' @section Parameters: #' The parameters are the parameters inherited from [`PipeOpTaskPreproc`], as well as: -#' * `renaming` :: named `character`\cr -#' Named `character` vector. The names of the vector specify the old column names that should be -#' changed to the new column names as given by the elements of the vector. Initialized to the empty -#' character vector. +#' * `renaming` :: named `character` | `function`\cr +#' Takes the form of either a named `character` or a `function`. +#' For a named `character` vector the names of the vector elements specify the +#' old column names and the corresponding element values give the new column names. +#' Initialized to an empty character vector. +#' A `function` specifies how the old column names should be changed to the new column names. +#' The function must return a `character` vector with one entry per input column name so that each selected column receives a new name. +#' To choose columns use the `affect_columns` parameter. No function is initialized. #' * `ignore_missing` :: `logical(1)`\cr #' Ignore if columns named in `renaming` are not found in the input [`Task`][mlr3::Task]. If this is #' `FALSE`, then names found in `renaming` not found in the [`Task`][mlr3::Task] cause an error. #' Initialized to `FALSE`. #' #' @section Internals: -#' Uses the `$rename()` mutator of the [`Task`][mlr3::Task] to set the new column names. +#' Uses the `$rename()` mutator of the [`Task`][mlr3::Task] to set new column names. #' #' @section Fields: #' Only fields inherited from [`PipeOp`]. @@ -56,36 +60,53 @@ #' task = tsk("iris") #' pop = po("renamecolumns", param_vals = list(renaming = c("Petal.Length" = "PL"))) #' pop$train(list(task)) +#' +#' pof = po("renamecolumns", +#' param_vals = list(renaming = function(colnames) {sub("Petal", "P", colnames)})) +#' pof$train(list(task)) +#' + PipeOpRenameColumns = R6Class("PipeOpRenameColumns", inherit = PipeOpTaskPreprocSimple, public = list( initialize = function(id = "renamecolumns", param_vals = list()) { ps = ps( renaming = p_uty( - custom_check = crate(function(x) check_character(x, any.missing = FALSE, names = "strict") %check&&% check_names(x, type = "strict"), - .parent = topenv()), + custom_check = crate(function(x) (check_character(x, any.missing = FALSE, names = "strict") %check&&% check_names(x, type = "strict")) %check||% check_function(x)), tags = c("train", "predict", "required") ), ignore_missing = p_lgl(tags = c("train", "predict", "required")) ) ps$values = list(renaming = character(0), ignore_missing = FALSE) - super$initialize(id, ps, param_vals = param_vals, can_subset_cols = FALSE) + super$initialize(id, ps, param_vals = param_vals, can_subset_cols = TRUE) } ), private = list( + .get_state = function(task) { + if (is.function(self$param_set$values$renaming)) { + new_names = self$param_set$values$renaming(task$feature_names) + assert_character(new_names, any.missing = FALSE, len = length(task$feature_names), .var.name = "the value returned by `renaming` function") + names(new_names) = task$feature_names + list(old_names = task$feature_names, new_names = new_names) + } else { + pv = self$param_set$get_values(tags = "train") + new_names = pv$renaming + innames = names(new_names) + nontargets = task$col_roles + nontargets$target = NULL + takenames = innames %in% unlist(nontargets) + if (!pv$ignore_missing && !all(takenames)) { + # we can't rely on task$rename because it could also change the target name, which we don't want. + stopf("The names %s from `renaming` parameter were not found in the Task.", str_collapse(innames[!takenames])) + } + list(old_names = innames[takenames], new_names = new_names[takenames]) + } + }, .transform = function(task) { - if (!length(self$param_set$values$renaming)) { + if (!length(self$state$new_names)) { return(task) # early exit } - innames = names(self$param_set$values$renaming) - nontargets = task$col_roles - nontargets$target = NULL - takenames = innames %in% unlist(nontargets) - if (!self$param_set$values$ignore_missing && !all(takenames)) { - # we can't rely on task$rename because it could also change the target name, which we don't want. - stopf("The names %s from `renaming` parameter were not found in the Task.", str_collapse(innames[!takenames])) - } - task$rename(old = innames[takenames], new = self$param_set$values$renaming[takenames]) + task$rename(old = self$state$old_names, new = self$state$new_names) } ) ) diff --git a/man/mlr_pipeops_renamecolumns.Rd b/man/mlr_pipeops_renamecolumns.Rd index 131433937..1eebf6f2f 100644 --- a/man/mlr_pipeops_renamecolumns.Rd +++ b/man/mlr_pipeops_renamecolumns.Rd @@ -40,10 +40,14 @@ The \verb{$state} is a named \code{list} with the \verb{$state} elements inherit The parameters are the parameters inherited from \code{\link{PipeOpTaskPreproc}}, as well as: \itemize{ -\item \code{renaming} :: named \code{character}\cr -Named \code{character} vector. The names of the vector specify the old column names that should be -changed to the new column names as given by the elements of the vector. Initialized to the empty -character vector. +\item \code{renaming} :: named \code{character} | \code{function}\cr +Takes the form of either a named \code{character} or a \code{function}. +For a named \code{character} vector the names of the vector elements specify the +old column names and the corresponding element values give the new column names. +Initialized to an empty character vector. +A \code{function} specifies how the old column names should be changed to the new column names. +The function must return a \code{character} vector with one entry per input column name so that each selected column receives a new name. +To choose columns use the \code{affect_columns} parameter. No function is initialized. \item \code{ignore_missing} :: \code{logical(1)}\cr Ignore if columns named in \code{renaming} are not found in the input \code{\link[mlr3:Task]{Task}}. If this is \code{FALSE}, then names found in \code{renaming} not found in the \code{\link[mlr3:Task]{Task}} cause an error. @@ -53,7 +57,7 @@ Initialized to \code{FALSE}. \section{Internals}{ -Uses the \verb{$rename()} mutator of the \code{\link[mlr3:Task]{Task}} to set the new column names. +Uses the \verb{$rename()} mutator of the \code{\link[mlr3:Task]{Task}} to set new column names. } \section{Fields}{ @@ -72,6 +76,11 @@ library("mlr3") task = tsk("iris") pop = po("renamecolumns", param_vals = list(renaming = c("Petal.Length" = "PL"))) pop$train(list(task)) + +pof = po("renamecolumns", + param_vals = list(renaming = function(colnames) {sub("Petal", "P", colnames)})) +pof$train(list(task)) + } \seealso{ https://mlr-org.com/pipeops.html diff --git a/tests/testthat/test_pipeop_renamecolumns.R b/tests/testthat/test_pipeop_renamecolumns.R index 2439a1d11..536f44687 100644 --- a/tests/testthat/test_pipeop_renamecolumns.R +++ b/tests/testthat/test_pipeop_renamecolumns.R @@ -39,3 +39,15 @@ test_that("error handling", { op$param_set$values$ignore_missing = TRUE expect_equal(task$data(), op$train(list(task))[[1]]$data()) }) + +test_that("assert on function works", { + task = mlr_tasks$get("iris") + expect_error(po("renamecolumns", param_vals = list(renaming = 1 + 1))) +}) + +test_that("assert on function works", { + task = mlr_tasks$get("iris") + po = po("renamecolumns", param_vals = list(renaming = function(colnames) sub("Petal", "P", colnames))) + result = po$train(list(task)) + expect_equal(result[[1]]$feature_names, c("P.Length", "P.Width", "Sepal.Length", "Sepal.Width")) +})