Skip to content

Commit 4c06919

Browse files
committed
wip
1 parent 0170eac commit 4c06919

File tree

7 files changed

+133
-43
lines changed

7 files changed

+133
-43
lines changed

R/GraphLearner.R

Lines changed: 31 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
#' * `validate` :: `numeric(1)`, `"predefined"`, `"test"` or `NULL`\cr
5757
#' How to construct the validation data. This also has to be configured in the individual learners wrapped by
5858
#' `PipeOpLearner`, see [`set_validate.GraphLearner`] on how to configure this.
59+
#' For more details on the possible values, see [`mlr3::Learner`].
5960
#' * `marshaled` :: `logical(1)`\cr
6061
#' Whether the learner is marshaled.
6162
#'
@@ -119,8 +120,8 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
119120
}
120121
assert_subset(task_type, mlr_reflections$task_types$type)
121122

122-
private$.can_validate = some(learner_wrapping_pipeops(graph), function(po) "validation" %in% po$learner$properties)
123-
private$.can_internal_tuning = some(learner_wrapping_pipeops(graph), function(po) "internal_tuning" %in% po$learner$properties)
123+
# Cache which optional capabilities the graph offers: a flag is TRUE when at least one
# PipeOp advertises the corresponding property. Both flags are read right below to decide
# which of "validation" / "internal_tuning" are dropped from the learner properties.
private$.can_validate = some(graph$pipeops, function(po) "validation" %in% po$properties)
private$.can_internal_tuning = some(graph$pipeops, function(po) "internal_tuning" %in% po$properties)
124125

125126
properties = setdiff(mlr_reflections$learner_properties[[task_type]],
126127
c("validation", "internal_tuning")[!c(private$.can_validate, private$.can_internal_tuning)])
@@ -139,6 +140,9 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
139140
if (!is.null(predict_type)) self$predict_type = predict_type
140141
},
141142
base_learner = function(recursive = Inf) {
  # Locate the PipeOp holding the base learner first, then hand back the learner it wraps
  # (via `$learner_model`).
  base_po = self$base_pipeop(recursive = recursive)
  base_po$learner_model
},
145+
base_pipeop = function(recursive = Inf) {
142146
assert(check_numeric(recursive, lower = Inf), check_int(recursive))
143147
if (recursive <= 0) return(self)
144148
gm = self$graph_model
@@ -158,7 +162,7 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
158162
if (length(last_pipeop_id) > 1) stop("Graph has no unique PipeOp containing a Learner")
159163
if (length(last_pipeop_id) == 0) stop("No Learner PipeOp found.")
160164
}
161-
learner_model$base_learner(recursive - 1)
165+
learner_model$base_pipeop(recursive - 1)
162166
},
163167
marshal = function(...) {
164168
learner_marshal(.learner = self, ...)
@@ -179,7 +183,7 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
179183
validate = function(rhs) {
180184
if (!missing(rhs)) {
181185
if (!private$.can_validate) {
182-
stopf("None of the Learners wrapped by GraphLearner '%s' support validation.", self$id)
186+
stopf("None of the PipeOps in Graph '%s' supports validation.", self$id)
183187
}
184188
private$.validate = assert_validate(rhs)
185189
}
@@ -232,30 +236,17 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
232236
.can_internal_tuning = NULL,
233237
.extract_internal_tuned_values = function() {
  # Internally tuned values can only exist when some PipeOp supports internal tuning;
  # whether validation is supported is irrelevant here.
  if (!private$.can_internal_tuning) return(NULL)
  # Collect from the trained graph (`$graph_model`): the state of `$graph` is reset after
  # training, so its PipeOps no longer carry the fitted learners.
  tuning_pos = pos_with_property(self$graph_model, "internal_tuning")
  itvs = unlist(map(tuning_pos, "internal_tuned_values"), recursive = FALSE)
  # unlist() of an empty list yields NULL; normalize to an empty named list
  if (!length(itvs)) return(named_list())
  itvs
},
245243
.extract_internal_valid_scores = function() {
  # Validation scores can only exist when some PipeOp supports validation.
  if (!private$.can_validate) return(NULL)
  # Collect from the trained graph (`$graph_model`): the state of `$graph` is reset after
  # training, so its PipeOps no longer carry the fitted learners.
  validating_pos = pos_with_property(self$graph_model, "validation")
  ivs = unlist(map(validating_pos, "internal_valid_scores"), recursive = FALSE)
  # unlist() of an empty list yields NULL; normalize to an empty named list
  if (is.null(ivs) || !length(ivs)) return(named_list())
  ivs
},
257249
deep_clone = function(name, value) {
258-
private$.param_set = NULL
259250
# FIXME this repairs the mlr3::Learner deep_clone() method which is broken.
260251
if (is.environment(value) && !is.null(value[[".__enclos_env__"]])) {
261252
return(value$clone(deep = TRUE))
@@ -268,17 +259,10 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
268259

269260
.train = function(task) {
270261
if (!is.null(get0("validate", self))) {
271-
some_pipeops_validate = some(learner_wrapping_pipeops(self), function(po) !is.null(get0("validate", po$learner)))
262+
some_pipeops_validate = some(pos_with_property(self, "validation"), function(po) !is.null(po$validate))
272263
if (!some_pipeops_validate) {
273264
lg$warn("GraphLearner '%s' specifies a validation set, but none of its Learners use it.", self$id)
274265
}
275-
} else {
276-
# otherwise the pipeops will preprocess this unnecessarily
277-
if (!is.null(task$internal_valid_task)) {
278-
prev_itv = task$internal_valid_task
279-
on.exit({task$internal_valid_task = prev_itv}, add = TRUE)
280-
task$internal_valid_task = NULL
281-
}
282266
}
283267

284268
on.exit({self$graph$state = NULL})
@@ -350,6 +334,9 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
350334
#' For which pipeops to enable validation.
351335
#' This parameter is ignored when `validate` is set to `NULL`.
352336
#' By default, validation is enabled for the base learner.
337+
#' @param args_all (`list()`)\cr
338+
#' Rarely needed. A named list of parameter values that are passed to all subsequent [`set_validate()`] calls on the individual
339+
#' `PipeOp`s.
353340
#' @param args (named `list()`)\cr
354341
#' Rarely needed.
355342
#' A named list of lists, specifying additional arguments to be passed to [`set_validate()`] for the respective learners.
@@ -376,31 +363,35 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
376363
#' glrn$validate
377364
#' glrn$graph$pipeops$classif.debug$learner$validate
378365
#' glrn$graph$pipeops$final$learner$validate
379-
set_validate.GraphLearner = function(learner, validate, ids = NULL, args = list(), ...) {
366+
set_validate.GraphLearner = function(learner, validate, ids = NULL, args_all = list(), args = list(), ...) {
380367
if (is.null(validate)) {
381368
learner$validate = NULL
382-
walk(learner_wrapping_pipeops(learner), function(po) {
383-
po$learner$validate = NULL
369+
walk(pos_with_property(learner$graph$pipeops, "validation"), function(po) {
370+
# disabling needs no extra arguments
371+
invoke(set_validate, po, validate = NULL, args_all = args_all, args = args[[po$id]] %??% list())
384372
})
385373
return(invisible(learner))
386374
}
387375

388376
if (is.null(ids)) {
389-
ids = base_pipeop(learner)$id
377+
ids = learner$base_pipeop(recursive = 1)$id
390378
} else {
391-
assert_subset(ids, ids(keep(learner_wrapping_pipeops(learner), function(po) "validation" %in% po$learner$properties)))
379+
assert_subset(ids, ids(pos_with_property(learner$graph$pipeops, "validation")))
392380
}
393381

394382
assert_list(args, types = "list")
383+
assert_list(args_all, types = "list")
395384
assert_subset(names(args), ids)
396385

397-
prev_validate_pos = discard(map(learner_wrapping_pipeops(learner), function(po) get0("validate", po$learner, ifnotfound = NA)),
398-
function(x) identical(x, NA))
399-
386+
prev_validate_pos = map(pos_with_property(learner$graph$pipeops, "validation"), "validate")
400387
prev_validate = learner$validate
401-
402388
on.exit({
403-
iwalk(prev_validate_pos, function(val, poid) learner$graph$pipeops[[poid]]$learner$validate = val)
389+
iwalk(prev_validate_pos, function(val, poid) {
390+
# passing the args here is just a heuristic that can in principle fail, but this should be extremely
391+
# rare
392+
args = args[[poid]] %??% list()
393+
set_validate(learner$graph$pipeops[[poid]], validate = val, args = args, args_all = args_all)
394+
})
404395
learner$validate = prev_validate
405396
}, add = TRUE)
406397

@@ -409,7 +400,8 @@ set_validate.GraphLearner = function(learner, validate, ids = NULL, args = list(
409400
walk(ids, function(poid) {
410401
# learner might be another GraphLearner / AutoTuner so we call into set_validate() again
411402
withCallingHandlers({
412-
invoke(set_validate, learner = learner$graph$pipeops[[poid]]$learner, .args = insert_named(list(validate = "predefined"), args[[poid]]))
403+
args = c(args[[poid]], args_all) %??% list()
404+
set_validate(learner$graph$pipeops[[poid]], .args = insert_named(list(validate = "predefined"), args))
413405
}, error = function(e) {
414406
e$message = sprintf("Failed to set validate for PipeOp '%s':\n%s", poid, e$message)
415407
stop(e)

R/PipeOp.R

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,16 @@
135135
#' and done, if requested, by the [`Graph`] backend itself; it should *not* be done explicitly by `private$.train()` or `private$.predict()`.
136136
#' * `man` :: `character(1)`\cr
137137
#' Identifying string of the help page that shows with `help()`.
138+
#' * `properties` :: `character()`\cr
139+
#' The properties of the pipeop.
140+
#' Currently supported values are:
141+
#' * `"validation"`: the `PipeOp` can make use of the `$internal_valid_task` of an [`mlr3::Task`], see [`mlr3::Learner`] for more information.
142+
#' `PipeOp`s that have this property also have a `$validate` field, which controls whether to use the validation task,
143+
#' as well as a `$internal_valid_scores` field, which allows access to the internal validation scores after training.
144+
#' * `"internal_tuning"`: the `PipeOp` is able to internally optimize hyperparameters, see [`mlr3::Learner`] for an explanation.
145+
#' `PipeOp`s with that property also implement the standardized accessor `$internal_tuned_values`.
146+
#'
147+
#' Programmatic access to all available properties is possible via `mlr_reflections$pipeops$properties`.
138148
#'
139149
#' @section Methods:
140150
#' * `train(input)`\cr
@@ -235,8 +245,9 @@ PipeOp = R6Class("PipeOp",
235245
output = NULL,
236246
.result = NULL,
237247
tags = NULL,
248+
properties = NULL,
238249

239-
initialize = function(id, param_set = ps(), param_vals = list(), input, output, packages = character(0), tags = "abstract") {
250+
initialize = function(id, param_set = ps(), param_vals = list(), input, output, packages = character(0), tags = "abstract", properties = character(0)) {
240251
if (inherits(param_set, "ParamSet")) {
241252
private$.param_set = assert_param_set(param_set)
242253
private$.param_set_source = NULL
@@ -246,6 +257,7 @@ PipeOp = R6Class("PipeOp",
246257
}
247258
self$id = assert_string(id)
248259

260+
self$properties = assert_subset(properties, mlr_reflections$pipeops$properties)
249261
self$param_set$values = insert_named(self$param_set$values, param_vals)
250262
self$input = assert_connection_table(input)
251263
self$output = assert_connection_table(output)
@@ -596,4 +608,3 @@ evaluate_multiplicities = function(self, unpacked, evalcall, instate) {
596608
map(transpose_list(map(result, "output")), as.Multiplicity)
597609
}
598610
}
599-

R/PipeOpLearner.R

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,17 @@
6262
#' [`Learner`][mlr3::Learner] that is being wrapped. Read-only.
6363
#' * `learner_model` :: [`Learner`][mlr3::Learner]\cr
6464
#' [`Learner`][mlr3::Learner] that is being wrapped. This learner contains the model if the `PipeOp` is trained. Read-only.
65-
#'
65+
#' * `validate` :: `"predefined"` or `NULL`\cr
66+
#' This field can only be set for `Learner`s that have the `"validation"` property.
67+
#' Setting the field to `"predefined"` means that the wrapped `Learner` will use the internal validation task,
68+
#' otherwise it will be ignored.
69+
#' Note that specifying *how* the validation data is created is possible via the `$validate` field of the [`GraphLearner`].
70+
#' For each `PipeOp` it is then only possible to either use it (`"predefined"`) or not use it (`NULL`).
71+
#' Also see [`set_validate.GraphLearner`] for more information.
72+
#' * `internal_tuned_values` :: named `list()` or `NULL`\cr
73+
#' The internally tuned values if the wrapped `Learner` supports internal tuning, `NULL` otherwise.
74+
#' * `internal_valid_scores` :: named `list()` or `NULL`\cr
75+
#' The internal validation scores if the wrapped `Learner` supports internal validation, `NULL` otherwise.
6676
#' @section Methods:
6777
#' Methods inherited from [`PipeOp`].
6878
#'
@@ -91,13 +101,38 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp,
91101
type = private$.learner$task_type
92102
task_type = mlr_reflections$task_types[type, mult = "first"]$task
93103
out_type = mlr_reflections$task_types[type, mult = "first"]$prediction
104+
properties = c("validation", "internal_tuning")
105+
properties = properties[properties %in% learner$properties]
94106
super$initialize(id, param_set = alist(private$.learner$param_set), param_vals = param_vals,
95107
input = data.table(name = "input", train = task_type, predict = task_type),
96108
output = data.table(name = "output", train = "NULL", predict = out_type),
97-
tags = "learner", packages = learner$packages)
109+
tags = "learner", packages = learner$packages, properties = properties)
98110
}
99111
),
100112
active = list(
113+
internal_tuned_values = function(rhs) {
  # Read-only: the internally tuned values of the wrapped Learner, or NULL when the
  # PipeOp does not have the "internal_tuning" property.
  assert_ro_binding(rhs)
  if ("internal_tuning" %nin% self$properties) return(NULL)
  # `$learner_model` carries the trained state; `.train()` resets the state of the plain
  # wrapped learner on exit, so reading from `$learner` would miss the fitted model.
  self$learner_model$internal_tuned_values
},
118+
internal_valid_scores = function(rhs) {
  # Read-only: the internal validation scores of the wrapped Learner, or NULL when the
  # PipeOp does not have the "validation" property.
  assert_ro_binding(rhs)
  if ("validation" %nin% self$properties) return(NULL)
  # `$learner_model` carries the trained state; `.train()` resets the state of the plain
  # wrapped learner on exit, so reading from `$learner` would miss the fitted model.
  self$learner_model$internal_valid_scores
},
123+
validate = function(rhs) {
  # Active binding controlling whether the wrapped Learner uses the internal validation task.
  # Only `"predefined"` or `NULL` are accepted (checked by assert_po_validate()).
  supports_validation = "validation" %in% self$properties
  if (!supports_validation) {
    # reading yields NULL, writing is an error
    if (!missing(rhs)) {
      stopf("PipeOp '%s' does not support validation, because the wrapped Learner doesn't.", self$id)
    }
    return(NULL)
  }
  if (!missing(rhs)) {
    private$.validate = assert_po_validate(rhs)
    # forward the setting to the wrapped learner, which is what is reported back below
    self$learner$validate = rhs
  }
  private$.learner$validate
},
101136
id = function(val) {
102137
if (!missing(val)) {
103138
private$.id = val
@@ -137,6 +172,7 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp,
137172
),
138173
private = list(
139174
.learner = NULL,
175+
.validate = NULL,
140176

141177
.train = function(inputs) {
142178
on.exit({private$.learner$state = NULL})
@@ -157,3 +193,15 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp,
157193
)
158194

159195
mlr_pipeops$add("learner", PipeOpLearner, list(R6Class("Learner", public = list(id = "learner", task_type = "classif", param_set = ps(), packages = "mlr3pipelines"))$new()))
196+
197+
#' @export
198+
# Configure validation for a PipeOpLearner and the Learner it wraps.
#
# `learner` is the PipeOpLearner (the argument name is dictated by the `set_validate()` generic),
# `validate` must be "predefined" or NULL, and `...` is forwarded to the `set_validate()` method
# of the wrapped Learner. Returns the PipeOp invisibly.
set_validate.PipeOpLearner = function(learner, validate, ...) {
  assert_po_validate(validate)
  prev_validate = learner$validate
  # Errors for PipeOps without the "validation" property; nothing has been changed at that
  # point, so the rollback handler is only registered afterwards.
  learner$validate = validate
  on.exit({learner$validate = prev_validate})
  # Forward to the *wrapped* Learner. Calling set_validate() on the PipeOp itself would
  # dispatch to this very method again and recurse without end.
  set_validate(learner$learner, validate = validate, ...)
  # success: cancel the rollback
  on.exit()
  learner$validate = validate
  invisible(learner)
}

R/utils.R

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,3 +170,18 @@ base_pipeop = function(self) {
170170
# New movie idea: "The Last PipeOp"
171171
last_pipeop
172172
}
173+
174+
# Return those PipeOps whose `$properties` contain `property`.
# `x` may be a GraphLearner, a Graph, or anything else, which is then used as-is as the
# collection of PipeOps.
pos_with_property = function(x, property) {
  if (test_class(x, "GraphLearner")) {
    pipeops = x$graph$pipeops
  } else if (test_class(x, "Graph")) {
    pipeops = x$pipeops
  } else {
    pipeops = x
  }
  has_property = function(po) property %in% po$properties
  keep(pipeops, has_property)
}
184+
185+
# Check a PipeOp-level `validate` value: the only admissible values are "predefined" and NULL.
assert_po_validate = function(rhs) {
  assert_choice(rhs, choices = "predefined", null.ok = TRUE)
}

R/zzz.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ register_mlr3 = function() {
1616
c("abstract", "meta", "missings", "feature selection", "imbalanced data",
1717
"data transform", "target transform", "ensemble", "robustify", "learner", "encode",
1818
"multiplicity")))
19+
x$pipeops$properties = c("validation", "internal_tuning")
1920
}
2021

2122
paradox_info <- list2env(list(is_old = FALSE), parent = emptyenv())

inst/testthat/helper_functions.R

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,14 @@ expect_pipeop = function(po, check_ps_default_values = TRUE) {
108108
expect_int(po$innum, lower = 1)
109109
expect_int(po$outnum, lower = 1)
110110
expect_valid_pipeop_param_set(po, check_ps_default_values = check_ps_default_values)
111+
if ("validation" %in% po$properties) {
112+
testthat::expect_true(exists("validate", po))
113+
testthat::expect_true(exists("internal_valid_scores", envir = po))
114+
checkmate::expect_function(mlr3misc::get_private(po)$.extract_internal_valid_scores)
115+
}
116+
if ("internal_tuning" %in% po$properties) {
117+
checkmate::assert_false(exists("internal_tuning", po))
118+
}
111119
}
112120

113121
# autotest for the parmset of a pipeop

tests/testthat/test_PipeOp.R

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,3 +123,18 @@ test_that("Informative error and warning messages", {
123123
expect_warning(potest$predict(list(1)), NA)
124124

125125
})
126+
127+
test_that("properties", {
  # helper: construct a bare PipeOp that only differs in its properties
  new_po = function(properties) {
    PipeOp$new(
      id = "potest",
      input = data.table(name = "input", train = "*", predict = "*"),
      output = data.table(name = "input", train = "*", predict = "*"),
      properties = properties
    )
  }

  # a property that is not registered must be rejected at construction
  expect_error(new_po("abc"))

  # a registered property is stored as given
  po_valid = new_po("validation")
  expect_equal(po_valid$properties, "validation")
})

0 commit comments

Comments
 (0)