
Commit 71003d7

tests hopefully pass

1 parent 4c06919, commit 71003d7

File tree

12 files changed: +184, -68 lines


NAMESPACE

Lines changed: 1 addition & 0 deletions
@@ -28,6 +28,7 @@ S3method(predict,Graph)
 S3method(print,Multiplicity)
 S3method(print,Selector)
 S3method(set_validate,GraphLearner)
+S3method(set_validate,PipeOpLearner)
 S3method(unmarshal_model,Multiplicity_marshaled)
 S3method(unmarshal_model,graph_learner_model_marshaled)
 S3method(unmarshal_model,pipeop_impute_learner_state_marshaled)
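
With `set_validate` now registered as an S3 method for `PipeOpLearner` (in addition to the existing `GraphLearner` method), a plain `set_validate()` call dispatches to the new method defined in R/PipeOpLearner.R below. A minimal sketch of the dispatch, assuming mlr3learners is installed and `lrn("classif.xgboost")` (a learner with the "validation" property) is used as the wrapped learner:

library(mlr3)
library(mlr3pipelines)
library(mlr3learners)  # assumption: provides lrn("classif.xgboost")

po_xgb = po("learner", lrn("classif.xgboost"))
# dispatches to set_validate.PipeOpLearner(); only NULL or "predefined" are accepted here
set_validate(po_xgb, validate = "predefined")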

R/GraphLearner.R

Lines changed: 28 additions & 23 deletions
@@ -51,10 +51,11 @@
 #'   The internal tuned parameter values.
 #'   `NULL` is returned if the learner is not trained or none of the wrapped learners supports internal tuning.
 #' * `internal_valid_scores` :: named `list()` or `NULL`\cr
-#'   The internal tuned parameter values.
+#'   The internal validation scores as retrieved from the `PipeOps`.
+#'   The names are prefixed with the respective IDs of the `PipeOp`s.
 #'   `NULL` is returned if the learner is not trained or none of the wrapped learners supports internal validation.
 #' * `validate` :: `numeric(1)`, `"predefined"`, `"test"` or `NULL`\cr
-#'   How to construct the validation data. This also has to be configured in the individual learners wrapped by
+#'   How to construct the validation data. This also has to be configured in the individual `PipeOp`s such as
 #'   `PipeOpLearner`, see [`set_validate.GraphLearner`] on how to configure this.
 #'   For more details on the possible values, see [`mlr3::Learner`].
 #' * `marshaled` :: `logical(1)`\cr

@@ -121,7 +122,7 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
       assert_subset(task_type, mlr_reflections$task_types$type)
 
       private$.can_validate = some(graph$pipeops, function(po) "validation" %in% po$properties)
-      private$.can_validate = some(graph$pipeops, function(po) "internal_tuning" %in% po$properties)
+      private$.can_internal_tuning = some(graph$pipeops, function(po) "internal_tuning" %in% po$properties)
 
       properties = setdiff(mlr_reflections$learner_properties[[task_type]],
         c("validation", "internal_tuning")[!c(private$.can_validate, private$.can_internal_tuning)])

@@ -139,11 +140,9 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
       }
       if (!is.null(predict_type)) self$predict_type = predict_type
     },
-    base_learner = function(recursive = Inf) {
-      self$base_pipeop(recursive = recursive)$learner_model
-    },
-    base_pipeop = function(recursive = Inf) {
+    base_learner = function(recursive = Inf, return_po = FALSE) {
       assert(check_numeric(recursive, lower = Inf), check_int(recursive))
+      assert_flag(return_po)
       if (recursive <= 0) return(self)
       gm = self$graph_model
       gm_output = gm$output

@@ -162,7 +161,11 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
         if (length(last_pipeop_id) > 1) stop("Graph has no unique PipeOp containing a Learner")
         if (length(last_pipeop_id) == 0) stop("No Learner PipeOp found.")
       }
-      learner_model$base_pipeop(recursive - 1)
+      if (return_po) {
+        last_pipeop
+      } else {
+        learner_model$base_learner(recursive - 1)
+      }
     },
     marshal = function(...) {
       learner_marshal(.learner = self, ...)

@@ -236,13 +239,13 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
     .can_internal_tuning = NULL,
     .extract_internal_tuned_values = function() {
       if (!private$.can_validate) return(NULL)
-      itvs = unlist(map(pos_with_property(self, "internal_tuning"), "internal_tuned_values"), recursive = FALSE)
+      itvs = unlist(map(pos_with_property(self$graph_model, "internal_tuning"), "internal_tuned_values"), recursive = FALSE)
       if (!length(itvs)) return(named_list())
       itvs
     },
     .extract_internal_valid_scores = function() {
       if (!private$.can_internal_tuning) return(NULL)
-      its = unlist(map(pos_with_property(self, "validation"), "internal_valid_scores"), recursive = FALSE)
+      ivs = unlist(map(pos_with_property(self$graph_model, "validation"), "internal_valid_scores"), recursive = FALSE)
       if (is.null(ivs) || !length(ivs)) return(named_list())
       ivs
     },

@@ -367,30 +370,28 @@ set_validate.GraphLearner = function(learner, validate, ids = NULL, args_all = l
   if (is.null(validate)) {
     learner$validate = NULL
     walk(pos_with_property(learner$graph$pipeops, "validation"), function(po) {
-      # disabling needs no extra arguments
       invoke(set_validate, po, validate = NULL, args_all = args_all, args = args[[po$id]] %??% list())
     })
     return(invisible(learner))
   }
 
   if (is.null(ids)) {
-    ids = learner$base_pipeop(recursive = 1)$id
+    ids = learner$base_learner(recursive = 1, return_po = TRUE)$id
   } else {
     assert_subset(ids, ids(pos_with_property(learner$graph$pipeops, "validation")))
   }
 
   assert_list(args, types = "list")
-  assert_list(args_all, types = "list")
+  assert_list(args_all)
   assert_subset(names(args), ids)
 
   prev_validate_pos = map(pos_with_property(learner$graph$pipeops, "validation"), "validate")
   prev_validate = learner$validate
   on.exit({
-    iwalk(prev_validate_pos, function(val, poid) {
-      # passing the args here is just a heuristic that can in principle fail, but this should be extremely
-      # rare
-      args = args[[poid]] %??% list()
-      set_validate(learner$graph$pipeops[[poid]], validate = val, args = args, args_all = args_all)
+    iwalk(prev_validate_pos, function(prev_val, poid) {
+      # Here we don't call into set_validate() as this also does not ensure that we are able to correctly
+      # reset the configuration to the previous state (e.g. for AutoTuner) and is less transparent
+      learner$graph$pipeops[[poid]]$validate = prev_val
     })
     learner$validate = prev_validate
   }, add = TRUE)

@@ -400,13 +401,17 @@ set_validate.GraphLearner = function(learner, validate, ids = NULL, args_all = l
   walk(ids, function(poid) {
     # learner might be another GraphLearner / AutoTuner so we call into set_validate() again
     withCallingHandlers({
-      args = c(args[[poid]], args_all) %??% list()
-      set_validate(learner$graph$pipeops[[poid]], .args = insert_named(list(validate = "predefined"), args))
+      args = insert_named(c(list(validate = "predefined"), args_all), args[[poid]])
+      invoke(set_validate, learner$graph$pipeops[[poid]], .args = args)
     }, error = function(e) {
-      e$message = sprintf("Failed to set validate for PipeOp '%s':\n%s", poid, e$message)
+      e$message = sprintf(paste0(
+        "Failed to set validate for PipeOp '%s':\n%s\n",
+        "Trying to heuristically reset validation to its previous state, please check the results"), poid, e$message)
       stop(e)
     }, warning = function(w) {
-      w$message = sprintf("Failed to set validate for PipeOp '%s':\n%s", po$id, w$message)
+      w$message = sprintf(paste0(
+        "Failed to set validate for PipeOp '%s':\n%s\n",
+        "Trying to heuristically reset validation to its previous state, please check the results"), poid, w$message)
       warning(w)
       invokeRestart("muffleWarning")
     })

@@ -487,4 +492,4 @@ infer_task_type = function(graph) {
     task_type = get_po_task_type(graph$pipeops[[graph$rhs]])
   }
   c(task_type, "classif")[[1]] # "classif" as final fallback
-}
+}
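
Taken together, these changes fold the former base_pipeop() method into base_learner(return_po = TRUE), correct the documentation of $internal_valid_scores, and make set_validate.GraphLearner() restore the previous $validate values of the PipeOps directly rather than through set_validate() when configuration fails. A minimal usage sketch, assuming mlr3learners provides lrn("classif.xgboost") (which supports validation and internal tuning) and that the sonar example task ships with mlr3:

library(mlr3)
library(mlr3pipelines)
library(mlr3learners)  # assumption: provides lrn("classif.xgboost")

glrn = as_learner(po("pca") %>>% po("learner", lrn("classif.xgboost")))

# use 20% of the training data for validation; the wrapped PipeOpLearner is
# switched to "predefined" so it picks up the data constructed by the GraphLearner
set_validate(glrn, validate = 0.2)

glrn$base_learner(return_po = TRUE)  # the PipeOpLearner (replaces the old base_pipeop())
glrn$base_learner()                  # the wrapped Learner itself

glrn$train(tsk("sonar"))
glrn$internal_valid_scores  # names are prefixed with the PipeOp id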

R/PipeOp.R

Lines changed: 2 additions & 1 deletion
@@ -143,7 +143,8 @@
 #'   as well as a `$internal_valid_scores` field, which allows to access the internal validation scores after training.
 #' * `"internal_tuning"`: the `PipeOp` is able to internally optimize hyperparameters, see [`mlr3::Learner`] for an explanation.
 #'   `PipeOp`s with that property also implement the standardized accessor `$internal_tuned_values`.
-#'
+#'   An example for such a `PipeOp` is a `PipeOpLearner` that wraps a `Learner` with the `"internal_tuning"` property.
+#'
 #' Programatic access to all available properties is possible via `mlr_reflections$pipeops$properties`.
 #'
 #' @section Methods:
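
A short sketch of how these two properties surface in practice, assuming lrn("classif.xgboost") from mlr3learners (a learner with both the validation and internal tuning capabilities) is available:

library(mlr3)
library(mlr3pipelines)
library(mlr3learners)  # assumption: provides lrn("classif.xgboost")

po_xgb = po("learner", lrn("classif.xgboost"))
"validation" %in% po_xgb$properties       # TRUE: exposes $validate and $internal_valid_scores
"internal_tuning" %in% po_xgb$properties  # TRUE: exposes $internal_tuned_values

# programmatic access to all properties a PipeOp can have
mlr_reflections$pipeops$properties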

R/PipeOpLearner.R

Lines changed: 39 additions & 13 deletions
@@ -97,12 +97,18 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp,
     initialize = function(learner, id = NULL, param_vals = list()) {
       private$.learner = as_learner(learner, clone = TRUE)
       id = id %??% private$.learner$id
+      if (!test_po_validate(get0("validate", private$.learner))) {
+        stopf(paste0(
+          "Validate field of PipeOp '%s' must either be NULL or 'predefined'.\nTo configure how ",
+          "the validation data is created, set the $validate field of the GraphLearner, e.g. using set_validate()."
+        ), id) # nolint
+      }
       # FIXME: can be changed when mlr-org/mlr3#470 has an answer
       type = private$.learner$task_type
       task_type = mlr_reflections$task_types[type, mult = "first"]$task
       out_type = mlr_reflections$task_types[type, mult = "first"]$prediction
       properties = c("validation", "internal_tuning")
-      properties = properties[properties %in% learner$properties]
+      properties = properties[properties %in% private$.learner$properties]
       super$initialize(id, param_set = alist(private$.learner$param_set), param_vals = param_vals,
         input = data.table(name = "input", train = task_type, predict = task_type),
         output = data.table(name = "output", train = "NULL", predict = out_type),

@@ -112,13 +118,13 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp,
   active = list(
     internal_tuned_values = function(rhs) {
       assert_ro_binding(rhs)
-      if ("validate" %nin% self$properties) return(NULL)
-      self$learner$internal_tuned_values
+      if ("internal_tuning" %nin% self$properties) return(NULL)
+      self$learner_model$internal_tuned_values
     },
     internal_valid_scores = function(rhs) {
       assert_ro_binding(rhs)
-      if ("internal_tuning" %nin% self$properties) return(NULL)
-      self$learner$internal_valid_scores
+      if ("validation" %nin% self$properties) return(NULL)
+      self$learner_model$internal_valid_scores
     },
     validate = function(rhs) {
       if ("validation" %nin% self$properties) {

@@ -128,8 +134,7 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp,
         return(NULL)
       }
       if (!missing(rhs)) {
-        private$.validate = assert_po_validate(rhs)
-        self$learner$validate = rhs
+        private$.learner$validate = assert_po_validate(rhs)
       }
       private$.learner$validate
     },

@@ -147,6 +152,14 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp,
         if (!identical(val, private$.learner)) {
           stop("$learner is read-only.")
         }
+        validate = get0("validate", private$.learner)
+        if (!test_po_validate(validate)) {
+          warningf(paste(sep = "\n",
+            "PipeOpLearner '%s' has its validate field set to a value that is neither NULL nor 'predefined'.",
+            "This will likely lead to unexpected behaviour.",
+            "Configure the $validate field of the GraphLearner to define how the validation data is created."
+          ), self$id)
+        }
       }
       private$.learner
     },

@@ -172,7 +185,6 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp,
   ),
   private = list(
     .learner = NULL,
-    .validate = NULL,
 
     .train = function(inputs) {
       on.exit({private$.learner$state = NULL})

@@ -192,16 +204,30 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp,
   )
 )
 
-mlr_pipeops$add("learner", PipeOpLearner, list(R6Class("Learner", public = list(id = "learner", task_type = "classif", param_set = ps(), packages = "mlr3pipelines"))$new()))
+mlr_pipeops$add("learner", PipeOpLearner, list(R6Class("Learner", public = list(properties = character(0), id = "learner", task_type = "classif", param_set = ps(), packages = "mlr3pipelines"))$new())) # nolint
 
 #' @export
 set_validate.PipeOpLearner = function(learner, validate, ...) {
   assert_po_validate(validate)
-  on.exit({learner$validate = prev_validate})
+  on.exit({
+    # also does not work in general (e.g. for AutoTuner) and is even less transparent
+    learner$validate = prev_validate
+  })
   prev_validate = learner$validate
-  learner$validate = validate
-  set_validate(learner, validate = validate, ...)
+  withCallingHandlers({
+    set_validate(learner$learner, validate = validate, ...)
+  }, error = function(e) {
+    e$message = sprintf(paste0(
+      "Failed to set validate for Learner '%s':\n%s\n",
+      "Trying to heuristically reset validation to its previous state, please check the results"), learner$id, e$message)
+    stop(e)
+  }, warning = function(w) {
+    w$message = sprintf(paste0(
+      "Failed to set validate for PipeOp '%s':\n%s\n",
+      "Trying to heuristically reset validation to its previous state, please check the results"), learner$id, w$message)
+    warning(w)
+    invokeRestart("muffleWarning")
+  })
   on.exit()
-  learner$validate = validate
   invisible(learner)
 }
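
The net effect: PipeOpLearner no longer keeps a private .validate field but delegates to the wrapped learner, errors at construction if that learner's $validate is anything other than NULL or "predefined", and warns when $learner is accessed after the field has been changed behind its back. A sketch of the resulting behaviour, again assuming lrn("classif.xgboost") from mlr3learners is available:

library(mlr3)
library(mlr3pipelines)
library(mlr3learners)  # assumption: provides lrn("classif.xgboost")

l = lrn("classif.xgboost")
l$validate = 0.3
# po("learner", l)  # now errors: validate must be NULL or "predefined";
#                   # the validation ratio is configured on the GraphLearner instead

l$validate = NULL
po_xgb = po("learner", l)
set_validate(po_xgb, validate = "predefined")  # forwarded to the wrapped learner
po_xgb$validate                                # "predefined", read back from the learner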

R/utils.R

Lines changed: 4 additions & 24 deletions
@@ -147,30 +147,6 @@ learner_wrapping_pipeops = function(x) {
   keep(x, function(po) inherits(po, "PipeOpLearner") || inherits(po, "PipeOpLearnerCV"))
 }
 
-
-# get the last PipeOpLearner
-base_pipeop = function(self) {
-  gm = self$graph_model
-  gm_output = gm$output
-  if (nrow(gm_output) != 1) stop("Graph has no unique output.")
-  last_pipeop_id = gm_output$op.id
-
-  # pacify static checks
-  src_id = NULL
-  dst_id = NULL
-
-  repeat {
-    last_pipeop = gm$pipeops[[last_pipeop_id]]
-    learner_model = if ("learner_model" %in% names(last_pipeop)) last_pipeop$learner_model
-    if (!is.null(learner_model)) break
-    last_pipeop_id = gm$edges[dst_id == last_pipeop_id]
-    if (length(last_pipeop_id) > 1) stop("Graph has no unique PipeOp containing a Learner")
-    if (length(last_pipeop_id) == 0) stop("No Learner PipeOp found.")
-  }
-  # New movie idea: "The Last PipeOp"
-  last_pipeop
-}
-
 pos_with_property = function(x, property) {
   x = if (test_class(x, "GraphLearner")) {
     x$graph$pipeops

@@ -185,3 +161,7 @@ pos_with_property = function(x, property) {
 assert_po_validate = function(rhs) {
   assert_choice(rhs, "predefined", null.ok = TRUE)
 }
+
+test_po_validate = function(x) {
+  test_choice(x, "predefined", null.ok = TRUE)
+}
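
test_po_validate() is the non-throwing counterpart of assert_po_validate() and backs the new checks in PipeOpLearner; both are internal helpers and not exported, shown here only to illustrate the accepted values (test_choice() comes from checkmate):

test_po_validate(NULL)          # TRUE
test_po_validate("predefined")  # TRUE
test_po_validate("test")        # FALSE
test_po_validate(0.2)           # FALSE -> PipeOpLearner construction errors, $learner access warns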

man/PipeOp.Rd

Lines changed: 13 additions & 0 deletions
(generated documentation file; diff not rendered)

man/mlr_learners_graph.Rd

Lines changed: 4 additions & 2 deletions
(generated documentation file; diff not rendered)

man/mlr_pipeops_learner.Rd

Lines changed: 11 additions & 0 deletions
(generated documentation file; diff not rendered)

man/set_validate.GraphLearner.Rd

Lines changed: 12 additions & 1 deletion
(generated documentation file; diff not rendered)
