Skip to content

Commit 5e7b85e

Browse files
committed
more progress
1 parent 8ec274a commit 5e7b85e

File tree

11 files changed

+207
-99
lines changed

11 files changed

+207
-99
lines changed

NAMESPACE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ S3method(predict,Graph)
2424
S3method(print,Multiplicity)
2525
S3method(print,Selector)
2626
S3method(set_inner_tuning,GraphLearner)
27+
S3method(set_validate,GraphLearner)
2728
export("%>>!%")
2829
export("%>>%")
2930
export(Graph)
@@ -143,7 +144,6 @@ export(selector_none)
143144
export(selector_setdiff)
144145
export(selector_type)
145146
export(selector_union)
146-
export(set_validate.GraphLearner)
147147
import(checkmate)
148148
import(data.table)
149149
import(mlr3)

R/GraphLearner.R

Lines changed: 83 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,12 @@
4747
#' contain the model. Use `graph_model` to access the trained [`Graph`] after `$train()`. Read-only.
4848
#' * `graph_model` :: [`Learner`][mlr3::Learner]\cr
4949
#' [`Graph`] that is being wrapped. This [`Graph`] contains a trained state after `$train()`. Read-only.
50+
#' * `inner_tuned_values` :: named `list()` or `NULL`\cr
51+
#' The inner tuned parameter values.
52+
#' `NULL` is returned if the learner is not trained or none of the wrapped learners supports inner tuning.
53+
#' * `inner_valid_scores` :: named `list()` or `NULL`\cr
54+
#' The inner tuned parameter values.
55+
#' `NULL` is returned if the learner is not trained or none of the wrapped learners supports internal validation.
5056
#'
5157
#' @section Internals:
5258
#' [`as_graph()`] is called on the `graph` argument, so it can technically also be a `list` of things, which is
@@ -110,7 +116,7 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
110116
)
111117

112118
properties = setdiff(mlr_reflections$learner_properties[[task_type]],
113-
c("validation", "inner_tuning")[!c(private$.validate, inner_tuning)])
119+
c("validation", "inner_tuning")[!c(private$.can_validate, inner_tuning)])
114120

115121
super$initialize(id = id, task_type = task_type,
116122
feature_types = mlr_reflections$task_feature_types,
@@ -128,9 +134,6 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
128134
if (!is.null(predict_type)) self$predict_type = predict_type
129135
},
130136
base_learner = function(recursive = Inf) {
131-
self$base_pipeop(recursive = recursive)$learner_model
132-
},
133-
base_pipeop = function(recursive = Inf) {
134137
assert(check_numeric(recursive, lower = Inf), check_int(recursive))
135138
if (recursive <= 0) return(self)
136139
gm = self$graph_model
@@ -150,30 +153,18 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
150153
if (length(last_pipeop_id) > 1) stop("Graph has no unique PipeOp containing a Learner")
151154
if (length(last_pipeop_id) == 0) stop("No Learner PipeOp found.")
152155
}
153-
last_pipeop$base_pipeop(recursive - 1)
154-
155-
},
156-
157-
#' @description
158-
#' Retrieves the inner validation scores as a named `list()`.
156+
learner_model$base_learner(recursive - 1)
157+
}
158+
),
159+
active = list(
159160
inner_valid_scores = function(rhs) {
160161
assert_ro_binding(rhs)
161-
if (is.null(self$state)) {
162-
stopf("Learner not trained")
163-
}
164162
self$state$inner_valid_scores
165163
},
166-
#' @description
167-
#' Retrieves the inner tuned values as a named `list()`.
168164
inner_tuned_values = function(rhs) {
169165
assert_ro_binding(rhs)
170-
if (is.null(self$state)) {
171-
stopf("Learner not trained")
172-
}
173166
self$state$inner_tuned_values
174-
}
175-
),
176-
active = list(
167+
},
177168
validate = function(rhs) {
178169
if (!missing(rhs)) {
179170
if (!private$.can_validate) {
@@ -185,11 +176,11 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
185176

186177
},
187178
hash = function() {
188-
digest(list(class(self), self$id, self$graph$hash, private$.predict_type,
179+
digest(list(class(self), self$id, self$graph$hash, private$.predict_type, private$.validate,
189180
self$fallback$hash, self$parallel_predict), algo = "xxhash64")
190181
},
191182
phash = function() {
192-
digest(list(class(self), self$id, self$graph$phash, private$.predict_type,
183+
digest(list(class(self), self$id, self$graph$phash, private$.predict_type, private$.validate,
193184
self$fallback$hash, self$parallel_predict), algo = "xxhash64")
194185
},
195186
predict_type = function(rhs) {
@@ -226,21 +217,34 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
226217
.validate = NULL,
227218
.can_validate = NULL,
228219
.extract_inner_tuned_values = function() {
220+
itvs = unlist(map(
221+
learner_wrapping_pipeops(self$graph_model), function(po) {
222+
if (exists("inner_tuned_values", po$learner)) {
223+
po$learner_model$inner_tuned_values
224+
}
225+
}
226+
), recursive = FALSE)
229227

230-
231-
warningf("Implementthis")
232-
list()
228+
if (is.null(itvs) || !length(itvs)) {
229+
return(named_list())
230+
}
231+
itvs
233232

234233
},
235234
.extract_inner_valid_scores = function() {
236-
warningf("Implementthis")
237-
list()
238-
# map(
239-
# keep(self$graph$pipeops, function(po) inherits(po, "PipeOpLearnerCV") || inherits(po, "PipeOpLearner")),
240-
# function(po) {
241-
# po$inner_
242-
# }
243-
# )
235+
ivs = unlist(map(
236+
learner_wrapping_pipeops(self$graph_model), function(po) {
237+
if (exists("inner_valid_scores", po$learner)) {
238+
po$learner_model$inner_valid_scores
239+
}
240+
}
241+
), recursive = FALSE)
242+
243+
if (is.null(ivs) || !length(ivs)) {
244+
return(named_list())
245+
}
246+
ivs
247+
244248
},
245249
deep_clone = function(name, value) {
246250
private$.param_set = NULL
@@ -256,8 +260,8 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
256260

257261
.train = function(task) {
258262
if (!is.null(get0("validate", self))) {
259-
some_pipeops_validate = map(
260-
filter(self$graph$pipeops, function(po) inherits(po, "PipeOpLearner") || inherits(po, "PipeOpLearnerCV")),
263+
some_pipeops_validate = map_lgl(
264+
keep(self$graph$pipeops, function(po) inherits(po, "PipeOpLearner") || inherits(po, "PipeOpLearnerCV")),
261265
function(po) !is.null(get0("validate", po$learner))
262266
)
263267

@@ -319,30 +323,30 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
319323
#' Configure validation for a graph learner.
320324
#'
321325
#' In a [`GraphLearner`], validation can be configured on two levels:
322-
#' 1. On the [`GraphLearner`] level.
323-
#' 2. On the level of the [`Learner`]s that are wrapped by [`PipeOpLearner`] and [`PipeOpLearnerCV`].
324-
#'
325-
#' Therefore, enabling validation requires to specify not only how to create the validation set (1), but also which
326-
#' pipeops should actually use it.
327-
#' Only the [`GraphLearner`] can specify **how** to create the validation set.
328-
#' All learners wrapped by [`PipeOpLearner`] and [`PipeOpLearnerCV`] can only set it to `NULL` (disable) or
329-
#' `"inner_valid"` (enable).
326+
#' 1. On the [`GraphLearner`] level, which specifies **how** the validation set is constructed.
327+
#' 2. On the level of the [`Learner`]s that are wrapped by [`PipeOpLearner`] and [`PipeOpLearnerCV`], which specifies
328+
#' which pipeops actually make use of the validation set.
329+
#' All learners wrapped by [`PipeOpLearner`] and [`PipeOpLearnerCV`] can only set it to `NULL` (disable) or
330+
#' `"inner_valid"` (enable).
330331
#'
331332
#' @param learner ([`GraphLearner`])\cr
332333
#' The graph learner to configure.
333334
#' @param validate (`numeric(1)`, `"inner_valid"` or `NULL`)\cr
334335
#' How to set the `$validate` field of the learner.
335-
#' If set to `NULL` all validation is disabled.
336+
#' If set to `NULL` all validation is disabled, both on the graph learner level, but also for all pipeops.
336337
#' @param ids (`NULL` or `character()`)\cr
337338
#' For which pipeops to enable validation.
338339
#' This parameter is ignored when `validate` is set to `NULL`.
339340
#' By default, validation is enabled for the base learner.
340341
#' @param args (named `list()`)\cr
341342
#' Rarely needed.
342-
#' A named list of lists, specifying additional argments to be passed to [`set_validate()`] for the pipeops.
343-
#' The names must be a subset of `ids`.
343+
#' A named list of lists, specifying additional argments to be passed to [`set_validate()`] for the respective pipeops.
344+
#' @param ... (any)\cr
345+
#' Currently unused.
346+
#'
344347
#' @export
345348
#' @examples
349+
#' library(mlr3)
346350
#' # simple
347351
#' glrn = as_learner(po("pca") %>>% lrn("classif.debug"))
348352
#' set_validate(glrn, 0.3)
@@ -353,13 +357,14 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
353357
#' glrn$graph$pipeops$classif.debug$learner$validate
354358
#'
355359
#' # complex
356-
#' glrn = as_learner(ppl("stacking", lrns(c("classif.debug", "classif.featureless")), lrn("classif.debug", id = "final")))
360+
#' glrn = as_learner(ppl("stacking", lrns(c("classif.debug", "classif.featureless")),
361+
#' lrn("classif.debug", id = "final")))
357362
#' set_validate(glrn, 0.2, which = c("classif.debug", "final"))
358363
#' glrn$validate
359364
#' glrn$graph$pipeops$classif.debug$learner$validate
360365
#' glrn$graph$pipeops$final$learner$validate
361-
set_validate.GraphLearner = function(learner, validate, ids = NULL, args = list()) {
362-
if (is.null(learner$validate)) {
366+
set_validate.GraphLearner = function(learner, validate, ids = NULL, args = list(), ...) {
367+
if (is.null(validate)) {
363368
learner$validate = NULL
364369
walk(learner_wrapping_pipeops(learner), function(po) {
365370
if (exists("validate", po$learner)) {
@@ -370,7 +375,7 @@ set_validate.GraphLearner = function(learner, validate, ids = NULL, args = list(
370375
}
371376

372377
if (is.null(ids)) {
373-
which = learner$base_pipeop()$id
378+
ids = base_pipeop(learner)$id
374379
} else {
375380
assert_subset(ids, ids(keep(learner_wrapping_pipeops(learner), function(po) "validation" %in% po$learner$properties)))
376381
assert_true(length(ids) > 0)
@@ -379,7 +384,7 @@ set_validate.GraphLearner = function(learner, validate, ids = NULL, args = list(
379384
assert_list(args, types = "list")
380385
assert_subset(names(args), ids)
381386

382-
prev_validate_pos = discard(map(learner_wrapping_pipeops(learner), function(po) get0("validate", po$learner), is.null))
387+
prev_validate_pos = discard(map(learner_wrapping_pipeops(learner), function(po) get0("validate", po$learner)), is.null)
383388
prev_validate = learner$validate
384389

385390
on.exit({
@@ -391,18 +396,29 @@ set_validate.GraphLearner = function(learner, validate, ids = NULL, args = list(
391396

392397
walk(ids, function(poid) {
393398
# learner might be another GraphLearner / AutoTuner
394-
invoke(set_validate learner = learner$graph$pipeops[[poid]]$learner, validate = "inner_valid", .args = args[[poid]])
399+
withCallingHandlers({
400+
invoke(set_validate, learner = learner$graph$pipeops[[poid]]$learner, validate = "inner_valid", .args = args[[poid]])
401+
}, error = function(e) {
402+
e$message = sprintf("Failed to set validate for PipeOp '%s':\n%s", poid, e$message)
403+
stop(e)
404+
}, warning = function(w) {
405+
w$message = sprintf("Failed to set validate for PipeOp '%s':\n%s", po$id, w$message)
406+
warning(w)
407+
invokeRestart("muffleWarning")
408+
})
395409
})
396410
on.exit()
397411

398412
invisible(learner)
399413
}
400414

401415

402-
#' @title Set Inner Tuning of a GraphLearner
416+
#' @title Set Inner Tuning for a Graph Learner
403417
#' @description
404418
#' First, all values specified by `...` are
405419
#' All [`PipeOpLearner`] and [`PipeOpLearnerCV`]
420+
#'
421+
#' @inheritParams mlr3::set_inner_tuning
406422
#' @param validate (`numeric(1)`, `"inner_valid"`, or `NULL`)\cr
407423
#' How to set the `$validate` field of the learner.
408424
#' @param args (named `list()`)\cr
@@ -415,7 +431,7 @@ set_validate.GraphLearner = function(learner, validate, ids = NULL, args = list(
415431
#' @export
416432
set_inner_tuning.GraphLearner = function(.learner, .disable = FALSE, validate = NA, args = NULL, ...) {
417433
if (is.null(args)) {
418-
args = set_names(list(list()), .learner$base_pipeops()$id
434+
args = set_names(list(list()), base_pipeop(.learner)$id)
419435
}
420436
all_pipeops = .learner$graph$pipeops
421437
lrn_pipeops = learner_wrapping_pipeops(all_pipeops)
@@ -433,10 +449,18 @@ set_inner_tuning.GraphLearner = function(.learner, .disable = FALSE, validate =
433449
}, add = TRUE)
434450

435451
walk(lrn_pipeops[names(args)], function(po) {
436-
browser()
437-
invoke(set_inner_tuning, .learner = po$learner,
438-
.args = insert_named(list(validate = validate, .disable = .disable), args[[po$id]])
439-
)
452+
withCallingHandlers({
453+
invoke(set_inner_tuning, .learner = po$learner,
454+
.args = insert_named(list(validate = validate, .disable = .disable), args[[po$id]])
455+
)
456+
}, error = function(e) {
457+
e$message = sprintf("Failed to set inner tuning for PipeOp '%s':\n%s", po$id, e$message)
458+
stop(e)
459+
}, warning = function(w) {
460+
w$message = sprintf("Failed to set inner tuning for PipeOp '%s':\n%s", po$id, w$message)
461+
warning(w)
462+
invokeRestart("muffleWarning")
463+
})
440464
})
441465

442466
# Now:

R/PipeOpImpute.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ PipeOpImpute = R6Class("PipeOpImpute",
193193

194194
intask$select(setdiff(intask$feature_names, colnames(imputanda)))$cbind(imputanda)
195195

196-
self$state$outtasklayouto = copy(intask$feature_types)
196+
self$state$outtasklayout = copy(intask$feature_types)
197197

198198
if (!is.null(intask$inner_valid_task)) {
199199
intask$inner_valid_task = private$.predict(list(intask$inner_valid_task))[[1L]]

R/PipeOpLearner.R

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -97,15 +97,6 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp,
9797
}
9898
),
9999
active = list(
100-
validate = function(rhs) {
101-
if (!missing(rhs)) {
102-
if (is.null(rhs) || identical(rhs, "inner_valid")) {
103-
stopf("The validate field of PipeOpLearner can only be set to NULL or inner_valid. You probably meant to configure the validate field of the GraphLearner")
104-
}
105-
private$.learner$validate = rhs
106-
}
107-
private$.learner$validate
108-
},
109100
id = function(val) {
110101
if (!missing(val)) {
111102
private$.id = val

R/utils.R

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -135,14 +135,7 @@ dictionary_sugar_inc_mget = function(dict, .keys, ...) {
135135
objs
136136
}
137137

138-
check_validate = function(x) {
139-
if (test_numeric(x, lower = 0, upper = 1, len = 1L)) {
140-
return(TRUE)
141-
}
142-
check_choice(x, c("inner_valid", "test"), null.ok = TRUE)
143-
}
144-
145-
learning_wrapping_pipeops = function(x) {
138+
learner_wrapping_pipeops = function(x) {
146139
if (inherits(x, "Graph")) {
147140
x = x$pipeops
148141
} else if (inherits(x, "GraphLearner")) {
@@ -153,3 +146,27 @@ learning_wrapping_pipeops = function(x) {
153146

154147
keep(x, function(po) inherits(po, "PipeOpLearner") || inherits(po, "PipeOpLearnerCV"))
155148
}
149+
150+
151+
# get the last PipeOpLearner
152+
base_pipeop = function(self) {
153+
gm = self$graph_model
154+
gm_output = gm$output
155+
if (nrow(gm_output) != 1) stop("Graph has no unique output.")
156+
last_pipeop_id = gm_output$op.id
157+
158+
# pacify static checks
159+
src_id = NULL
160+
dst_id = NULL
161+
162+
repeat {
163+
last_pipeop = gm$pipeops[[last_pipeop_id]]
164+
learner_model = if ("learner_model" %in% names(last_pipeop)) last_pipeop$learner_model
165+
if (!is.null(learner_model)) break
166+
last_pipeop_id = gm$edges[dst_id == last_pipeop_id]
167+
if (length(last_pipeop_id) > 1) stop("Graph has no unique PipeOp containing a Learner")
168+
if (length(last_pipeop_id) == 0) stop("No Learner PipeOp found.")
169+
}
170+
# New movie idea: "The Last PipeOp"
171+
last_pipeop
172+
}

man/mlr_learners_graph.Rd

Lines changed: 6 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)