Skip to content
Open
57 changes: 39 additions & 18 deletions R/PipeOpRenameColumns.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,21 @@
#'
#' @section Parameters:
#' The parameters are the parameters inherited from [`PipeOpTaskPreproc`], as well as:
#' * `renaming` :: named `character`\cr
#' Named `character` vector. The names of the vector specify the old column names that should be
#' changed to the new column names as given by the elements of the vector. Initialized to the empty
#' character vector.
#' * `renaming` :: named `character` | `function`\cr
#' Either a named `character` vector or a `function`.
#' For a named `character` vector, the names of the vector elements specify the
#' old column names and the corresponding element values give the new column names.
#' A `function` is called with the feature names of the input [`Task`][mlr3::Task] as a `character` vector
#' and must return a `character` vector of the same length without missing values, giving the new name of each column in order.
#' To choose which columns are renamed, use the `affect_columns` parameter.
#' Initialized to the empty `character` vector.
#' * `ignore_missing` :: `logical(1)`\cr
#' Ignore if columns named in `renaming` are not found in the input [`Task`][mlr3::Task]. If this is
#' `FALSE`, then names found in `renaming` not found in the [`Task`][mlr3::Task] cause an error.
#' Initialized to `FALSE`.
#'
#' @section Internals:
#' Uses the `$rename()` mutator of the [`Task`][mlr3::Task] to set the new column names.
#' Uses the `$rename()` mutator of the [`Task`][mlr3::Task] to set new column names.
#'
#' @section Fields:
#' Only fields inherited from [`PipeOp`].
Expand All @@ -56,36 +60,53 @@
#' task = tsk("iris")
#' pop = po("renamecolumns", param_vals = list(renaming = c("Petal.Length" = "PL")))
#' pop$train(list(task))
#'
#' pof = po("renamecolumns",
#' param_vals = list(renaming = function(colnames) {sub("Petal", "P", colnames)}))
#' pof$train(list(task))
#'

PipeOpRenameColumns = R6Class("PipeOpRenameColumns",
  inherit = PipeOpTaskPreprocSimple,
  public = list(
    # Create the PipeOp with the `renaming` and `ignore_missing` hyperparameters.
    #   id         : identifier of the resulting object, defaults to "renamecolumns".
    #   param_vals : list of hyperparameter values that override the initial values set here.
    initialize = function(id = "renamecolumns", param_vals = list()) {
      ps = ps(
        renaming = p_uty(
          # Accept either a strictly named character vector (old name -> new name) or a function.
          # `.parent = topenv()` is passed explicitly, as the earlier version of this check did,
          # so that the crated function looks up the `%check&&%` / `%check||%` operators and the
          # `check_*()` helpers starting from this file's top-level environment.
          custom_check = crate(function(x) {
            (check_character(x, any.missing = FALSE, names = "strict") %check&&% check_names(x, type = "strict")) %check||% check_function(x)
          }, .parent = topenv()),
          tags = c("train", "predict", "required")
        ),
        ignore_missing = p_lgl(tags = c("train", "predict", "required"))
      )
      ps$values = list(renaming = character(0), ignore_missing = FALSE)
      # can_subset_cols = TRUE so that `affect_columns` can be used to choose the columns to rename.
      super$initialize(id, ps, param_vals = param_vals, can_subset_cols = TRUE)
    }
  ),
  private = list(
    # Work out the old -> new name mapping for `task`.
    # Returns a list with `old_names` (columns to rename) and `new_names` (their replacements,
    # named by the old names); `.transform()` applies exactly this mapping.
    .get_state = function(task) {
      pv = self$param_set$get_values(tags = "train")

      if (is.function(pv$renaming)) {
        # Function form: the function maps all feature names of the task to new names.
        old_names = task$feature_names
        new_names = pv$renaming(old_names)
        # One new name per feature is required, otherwise old and new names could not be paired up.
        assert_character(new_names, any.missing = FALSE, len = length(old_names), .var.name = "the value returned by `renaming` function")
        names(new_names) = old_names
        return(list(old_names = old_names, new_names = new_names))
      }

      # Character form: names of `renaming` are the old column names, values are the new ones.
      new_names = pv$renaming
      innames = names(new_names)
      # Only columns that have a non-target role may be renamed.
      nontargets = task$col_roles
      nontargets$target = NULL
      takenames = innames %in% unlist(nontargets)
      if (!pv$ignore_missing && !all(takenames)) {
        # we can't rely on task$rename because it could also change the target name, which we don't want.
        stopf("The names %s from `renaming` parameter were not found in the Task.", str_collapse(innames[!takenames]))
      }
      list(old_names = innames[takenames], new_names = new_names[takenames])
    },

    # Apply the mapping stored in `self$state` to `task` through the Task's `$rename()` mutator.
    # The missing-name check is done once in `.get_state()` and is therefore not repeated here.
    .transform = function(task) {
      if (!length(self$state$new_names)) {
        return(task) # early exit: nothing to rename
      }
      task$rename(old = self$state$old_names, new = self$state$new_names)
    }
  )
)
Expand Down
19 changes: 14 additions & 5 deletions man/mlr_pipeops_renamecolumns.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 12 additions & 0 deletions tests/testthat/test_pipeop_renamecolumns.R
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,15 @@ test_that("error handling", {
op$param_set$values$ignore_missing = TRUE
expect_equal(task$data(), op$train(list(task))[[1]]$data())
})

test_that("renaming rejects values that are neither character nor function", {
  # A bare numeric is neither a named character vector nor a function,
  # so the parameter check has to fail when the PipeOp is constructed.
  expect_error(po("renamecolumns", param_vals = list(renaming = 1 + 1)))
})

test_that("renaming by function works", {
  task = mlr_tasks$get("iris")
  # Named `op` rather than `po` so the `po()` constructor is not shadowed inside the test.
  op = po("renamecolumns", param_vals = list(renaming = function(colnames) sub("Petal", "P", colnames)))
  expected = c("P.Length", "P.Width", "Sepal.Length", "Sepal.Width")

  result = op$train(list(task))
  expect_equal(result[[1]]$feature_names, expected)

  # Prediction reuses the mapping stored during training.
  predicted = op$predict(list(task))
  expect_equal(predicted[[1]]$feature_names, expected)
})