Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -284,14 +284,9 @@ import(palmerpenguins)
import(paradox)
importFrom(R6,R6Class)
importFrom(R6,is.R6)
importFrom(data.table,as.data.table)
importFrom(data.table,data.table)
importFrom(future,nbrOfWorkers)
importFrom(future,plan)
importFrom(graphics,plot)
importFrom(mlr3misc,clbk)
importFrom(mlr3misc,clbks)
importFrom(mlr3misc,mlr_callbacks)
importFrom(parallelly,availableCores)
importFrom(stats,contr.treatment)
importFrom(stats,model.frame)
Expand Down
14 changes: 14 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,19 @@
# mlr3 (development version)

## New Features:

* `Task` got method `$materialize_view()` which can save memory after subsetting a task.
* Better input validation for:
* `Learner` fields.
* Various improvements to the documentation and logging output, including
examples for methods.
* Measure "oob_error" now works even without storing models during resampling.

## Deprecations:

* Assigning to some fields of `Task`, `Learner`, and `Resampling` now throws a deprecation warning.
This will become an error in the future.

# mlr3 1.1.0

* feat: Add new measure `MeasureRegrRQR` for quantile regression.
Expand Down
110 changes: 73 additions & 37 deletions R/Learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -177,46 +177,16 @@ Learner = R6Class("Learner",
#' This is an internal data structure which may change in the future.
state = NULL,

#' @template field_task_type
task_type = NULL,

#' @field feature_types (`character()`)\cr
#' Stores the feature types the learner can handle, e.g. `"logical"`, `"numeric"`, or `"factor"`.
#' A complete list of candidate feature types, grouped by task type, is stored in [`mlr_reflections$task_feature_types`][mlr_reflections].
feature_types = NULL,

#' @field properties (`character()`)\cr
#' Stores a set of properties/capabilities the learner has.
#' A complete list of candidate properties, grouped by task type, is stored in [`mlr_reflections$learner_properties`][mlr_reflections].
properties = NULL,


#' @template field_packages
packages = NULL,

#' @template field_predict_sets
predict_sets = "test",

#' @field parallel_predict (`logical(1)`)\cr
#' If set to `TRUE`, use \CRANpkg{future} to calculate predictions in parallel (default: `FALSE`).
#' The row ids of the `task` will be split into [future::nbrOfWorkers()] chunks,
#' and predictions are evaluated according to the active [future::plan()].
#' This currently only works for methods `Learner$predict()` and `Learner$predict_newdata()`,
#' and has no effect during [resample()] or [benchmark()] where you have other means
#' to parallelize.
#'
#' Note that the recorded time required for prediction reports the time required to predict
#' is not properly defined and depends on the parallelization backend.
parallel_predict = FALSE,

#' @field timeout (named `numeric(2)`)\cr
#' Timeout for the learner's train and predict steps, in seconds.
#' This works differently for different encapsulation methods, see
#' [mlr3misc::encapsulate()].
#' Default is `c(train = Inf, predict = Inf)`.
#' Also see the section on error handling the mlr3book:
#' \url{https://mlr3book.mlr-org.com/chapters/chapter10/advanced_technical_aspects_of_mlr3.html#sec-error-handling}
timeout = c(train = Inf, predict = Inf),

#' @template field_man
man = NULL,

Expand All @@ -229,12 +199,12 @@ Learner = R6Class("Learner",

self$id = assert_string(id, min.chars = 1L)
self$label = assert_string(label, na.ok = TRUE)
self$task_type = assert_choice(task_type, mlr_reflections$task_types$type)
private$.task_type = assert_choice(task_type, mlr_reflections$task_types$type)
self$feature_types = assert_ordered_set(feature_types, mlr_reflections$task_feature_types, .var.name = "feature_types")
private$.predict_types = assert_ordered_set(predict_types, names(mlr_reflections$learner_predict_types[[task_type]]),
empty.ok = FALSE, .var.name = "predict_types")
private$.predict_type = predict_types[1L]
self$properties = sort(assert_subset(properties, mlr_reflections$learner_properties[[task_type]]))
private$.properties = sort(assert_subset(properties, mlr_reflections$learner_properties[[task_type]]))
if (!missing(data_formats)) warn_deprecated("Learner$initialize argument 'data_formats'")
self$packages = union("mlr3", assert_character(packages, any.missing = FALSE, min.chars = 1L))
self$man = assert_string(man, na.ok = TRUE)
Expand Down Expand Up @@ -492,10 +462,10 @@ Learner = R6Class("Learner",
}

prevci = task$col_info
task$backend = newdata
task$col_info = col_info(task$backend)
task$col_info[, c("label", "fix_factor_levels")] = prevci[list(task$col_info$id), on = "id", c("label", "fix_factor_levels")]
task$col_info$fix_factor_levels[is.na(task$col_info$fix_factor_levels)] = FALSE
task$.__enclos_env__$private$.backend = newdata
task$.__enclos_env__$private$.col_info = col_info(task$backend)
task$.__enclos_env__$private$.col_info[, c("label", "fix_factor_levels")] = prevci[list(task$col_info$id), on = "id", c("label", "fix_factor_levels")]
task$.__enclos_env__$private$.col_info$fix_factor_levels[is.na(task$.__enclos_env__$private$.col_info$fix_factor_levels)] = FALSE
task$row_roles$use = task$backend$rownames
task_col_roles = task$col_roles
update_col_roles = FALSE
Expand Down Expand Up @@ -676,6 +646,67 @@ Learner = R6Class("Learner",
),

active = list(
#' @template field_task_type
task_type = function(rhs) {
if (!missing(rhs)) {
warn_deprecated("task_type will soon be read-only.")
private$.properties = rhs
}
private$.task_type
},

#' @field properties (`character()`)\cr
#' Stores a set of properties/capabilities the learner has.
#' A complete list of candidate properties, grouped by task type, is stored in [`mlr_reflections$learner_properties`][mlr_reflections].
properties = function(rhs) {
if (!missing(rhs)) {
warn_deprecated("properties will soon be read-only.")
private$.properties = rhs
}
private$.properties
},

#' @template field_predict_sets
predict_sets = function(rhs) {
if (missing(rhs)) {
return(private$.predict_sets)
}
assert_subset(rhs, mlr_reflections$predict_sets)
private$.predict_sets = rhs
},

#' @field parallel_predict (`logical(1)`)\cr
#' If set to `TRUE`, use \CRANpkg{future} to calculate predictions in parallel (default: `FALSE`).
#' The row ids of the `task` will be split into [future::nbrOfWorkers()] chunks,
#' and predictions are evaluated according to the active [future::plan()].
#' This currently only works for methods `Learner$predict()` and `Learner$predict_newdata()`,
#' and has no effect during [resample()] or [benchmark()] where you have other means
#' to parallelize.
#'
#' Note that the recorded time required for prediction reports the time required to predict
#' is not properly defined and depends on the parallelization backend.
parallel_predict = function(rhs) {
if (missing(rhs)) {
return(private$.parallel_predict)
}
private$.parallel_predict = assert_flag(rhs)
},

#' @field timeout (named `numeric(2)`)\cr
#' Timeout for the learner's train and predict steps, in seconds.
#' This works differently for different encapsulation methods, see
#' [mlr3misc::encapsulate()].
#' Default is `c(train = Inf, predict = Inf)`.
#' Also see the section on error handling the mlr3book:
#' \url{https://mlr3book.mlr-org.com/chapters/chapter10/advanced_technical_aspects_of_mlr3.html#sec-error-handling}
timeout = function(rhs) {
if (missing(rhs)) {
return(private$.timeout)
}
assert_permutation(names(rhs), c("train", "predict"))
private$.timeout = assert_numeric(rhs, lower = 0, any.missing = FALSE, len = 2L)
},

#' @field use_weights (`character(1)`)\cr
#' How weights should be handled.
#' Settings are `"use"` `"ignore"`, and `"error"`.
Expand Down Expand Up @@ -841,6 +872,11 @@ Learner = R6Class("Learner",
),

private = list(
.predict_sets = "test",
.task_type = NULL,
.properties = NULL,
.parallel_predict = FALSE,
.timeout = c(train = Inf, predict = Inf),
.use_weights = NULL,
.encapsulation = c(train = "none", predict = "none"),
.fallback = NULL,
Expand Down
71 changes: 50 additions & 21 deletions R/Resampling.R
Original file line number Diff line number Diff line change
Expand Up @@ -106,24 +106,7 @@ Resampling = R6Class("Resampling",
#' `$train_set()` and `$test_set()`.
instance = NULL,

#' @field task_hash (`character(1)`)\cr
#' The hash of the [Task] which was passed to `r$instantiate()`.
task_hash = NA_character_,

#' @field task_row_hash (`character(1)`)\cr
#' The hash of the row ids of the [Task] which was passed to `r$instantiate()`.
task_row_hash = NA_character_,

#' @field task_nrow (`integer(1)`)\cr
#' The number of observations of the [Task] which was passed to `r$instantiate()`.
#'
task_nrow = NA_integer_,

#' @field duplicated_ids (`logical(1)`)\cr
#' If `TRUE`, duplicated rows can occur within a single training set or within a single test set.
#' E.g., this is `TRUE` for Bootstrap, and `FALSE` for cross-validation.
#' Only used internally.
duplicated_ids = NULL,

#' @template field_man
man = NULL,
Expand All @@ -139,7 +122,7 @@ Resampling = R6Class("Resampling",
private$.id = assert_string(id, min.chars = 1L)
self$label = assert_string(label, na.ok = TRUE)
self$param_set = assert_param_set(param_set)
self$duplicated_ids = assert_flag(duplicated_ids)
private$.duplicated_ids = assert_flag(duplicated_ids)
self$man = assert_string(man, na.ok = TRUE)
},

Expand Down Expand Up @@ -188,9 +171,9 @@ Resampling = R6Class("Resampling",
task = assert_task(as_task(task))
private$.hash = NULL
self$instance = private$.get_instance(task)
self$task_hash = task$hash
self$task_row_hash = task$row_hash
self$task_nrow = task$nrow
private$.task_hash = task$hash
private$.task_row_hash = task$row_hash
private$.task_nrow = task$nrow
invisible(self)
},

Expand Down Expand Up @@ -256,6 +239,48 @@ Resampling = R6Class("Resampling",
}

private$.hash
},

#' @field task_hash (`character(1)`)\cr
#' The hash of the [Task] which was passed to `r$instantiate()`.
task_hash = function(rhs) {
if (!missing(rhs)) {
warn_deprecated("task_hash will soon be read-only.")
private$.task_hash = rhs
}
private$.task_hash
},

#' @field task_row_hash (`character(1)`)\cr
#' The hash of the row ids of the [Task] which was passed to `r$instantiate()`.
task_row_hash = function(rhs) {
if (!missing(rhs)) {
warn_deprecated("task_row_hash will soon be read-only.")
private$.task_row_hash = rhs
}
private$.task_row_hash
},

#' @field task_nrow (`integer(1)`)\cr
#' The number of observations of the [Task] which was passed to `r$instantiate()`.
task_nrow = function(rhs) {
if (!missing(rhs)) {
warn_deprecated("task_nrow will soon be read-only.")
private$.task_nrow = rhs
}
private$.task_nrow
},

#' @field duplicated_ids (`logical(1)`)\cr
#' If `TRUE`, duplicated rows can occur within a single training set or within a single test set.
#' E.g., this is `TRUE` for Bootstrap, and `FALSE` for cross-validation.
#' Only used internally.
duplicated_ids = function(rhs) {
if (!missing(rhs)) {
warn_deprecated("duplicated_ids will soon be read-only.")
private$.duplicated_ids = rhs
}
private$.duplicated_ids
}
),

Expand All @@ -264,6 +289,10 @@ Resampling = R6Class("Resampling",
.id = NULL,
.hash = NULL,
.groups = NULL,
.task_hash = NA_character_,
.task_row_hash = NA_character_,
.task_nrow = NA_integer_,
.duplicated_ids = NULL,

.get_instance = function(task) {
strata = task$strata
Expand Down
6 changes: 3 additions & 3 deletions R/ResamplingCustom.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ ResamplingCustom = R6Class("ResamplingCustom", inherit = Resampling,
assert_subset(unlist(train_sets, use.names = FALSE), task$row_ids)
assert_subset(unlist(test_sets, use.names = FALSE), task$row_ids)
self$instance = list(train = train_sets, test = test_sets)
self$task_hash = task$hash
self$task_nrow = task$nrow
self$task_row_hash = task$row_hash
private$.task_hash = task$hash
private$.task_nrow = task$nrow
private$.task_row_hash = task$row_hash
invisible(self)
}
),
Expand Down
6 changes: 3 additions & 3 deletions R/ResamplingCustomCV.R
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ ResamplingCustomCV = R6Class("ResamplingCustomCV", inherit = Resampling,
}

self$instance = split(task$row_ids, f, drop = TRUE)
self$task_hash = task$hash
self$task_nrow = task$nrow
self$task_row_hash = task$row_hash
private$.task_hash = task$hash
private$.task_nrow = task$nrow
private$.task_row_hash = task$row_hash
invisible(self)
}
),
Expand Down
Loading
Loading