Skip to content

Commit 8ec274a

Browse files
committed
...
1 parent 1652570 commit 8ec274a

File tree

7 files changed

+283
-77
lines changed

7 files changed

+283
-77
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ export(selector_none)
143143
export(selector_setdiff)
144144
export(selector_type)
145145
export(selector_union)
146+
export(set_validate.GraphLearner)
146147
import(checkmate)
147148
import(data.table)
148149
import(mlr3)

R/GraphLearner.R

Lines changed: 181 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
9999
assert_subset(task_type, mlr_reflections$task_types$type)
100100

101101

102-
private$.validate = some(
102+
private$.can_validate = some(
103103
keep(graph$pipeops, function(x) inherits(x, "PipeOpLearner") || inherits(x, "PipeOpLearnerCV")),
104104
function(po) "validation" %in% po$learner$properties
105105
)
@@ -110,7 +110,7 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
110110
)
111111

112112
properties = setdiff(mlr_reflections$learner_properties[[task_type]],
113-
c("validation", "inner_tuning")[c(!private$.validate, !inner_tuning)])
113+
c("validation", "inner_tuning")[!c(private$.validate, inner_tuning)])
114114

115115
super$initialize(id = id, task_type = task_type,
116116
feature_types = mlr_reflections$task_feature_types,
@@ -128,6 +128,9 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
128128
if (!is.null(predict_type)) self$predict_type = predict_type
129129
},
130130
base_learner = function(recursive = Inf) {
131+
self$base_pipeop(recursive = recursive)$learner_model
132+
},
133+
base_pipeop = function(recursive = Inf) {
131134
assert(check_numeric(recursive, lower = Inf), check_int(recursive))
132135
if (recursive <= 0) return(self)
133136
gm = self$graph_model
@@ -147,7 +150,8 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
147150
if (length(last_pipeop_id) > 1) stop("Graph has no unique PipeOp containing a Learner")
148151
if (length(last_pipeop_id) == 0) stop("No Learner PipeOp found.")
149152
}
150-
learner_model$base_learner(recursive - 1)
153+
last_pipeop$base_pipeop(recursive - 1)
154+
151155
},
152156

153157
#' @description
@@ -170,6 +174,16 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
170174
}
171175
),
172176
active = list(
177+
# Active binding for the validation configuration of the GraphLearner.
# Read access returns the stored value (`private$.validate`).
# Write access is rejected when `private$.can_validate` is FALSE; otherwise
# the new value is passed through `assert_validate()` before being stored.
validate = function(rhs) {
  # plain read: nothing to check
  if (missing(rhs)) {
    return(private$.validate)
  }
  if (!private$.can_validate) {
    stopf("None of the Learners wrapped by GraphLearner '%s' support validation.", self$id)
  }
  private$.validate = assert_validate(rhs)
  private$.validate
},
173187
hash = function() {
174188
digest(list(class(self), self$id, self$graph$hash, private$.predict_type,
175189
self$fallback$hash, self$parallel_predict), algo = "xxhash64")
@@ -188,12 +202,7 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
188202
if (!missing(rhs) && !identical(rhs, self$graph$param_set)) {
189203
stop("param_set is read-only.")
190204
}
191-
if (is.null(private$.param_set)) {
192-
private$.param_set = ParamSetCollection$new(sets = c(list(self$graph$param_set),
193-
if (private$.validate) ps(validate = p_uty(default = NULL, tags = "train", custom_check = check_validate)
194-
)))
195-
}
196-
private$.param_set
205+
self$graph$param_set
197206
},
198207
graph = function(rhs) {
199208
if (!missing(rhs) && !identical(rhs, private$.graph)) stop("graph is read-only")
@@ -215,12 +224,17 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
215224
private = list(
216225
.graph = NULL,
217226
.validate = NULL,
218-
.param_set = NULL,
227+
.can_validate = NULL,
219228
.extract_inner_tuned_values = function() {
220229

230+
231+
warningf("Implementthis")
232+
list()
233+
221234
},
222235
.extract_inner_valid_scores = function() {
223-
.NotYetImplemented()
236+
warningf("Implementthis")
237+
list()
224238
# map(
225239
# keep(self$graph$pipeops, function(po) inherits(po, "PipeOpLearnerCV") || inherits(po, "PipeOpLearner")),
226240
# function(po) {
@@ -241,6 +255,17 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
241255
},
242256

243257
.train = function(task) {
258+
if (!is.null(get0("validate", self))) {
259+
some_pipeops_validate = map(
260+
filter(self$graph$pipeops, function(po) inherits(po, "PipeOpLearner") || inherits(po, "PipeOpLearnerCV")),
261+
function(po) !is.null(get0("validate", po$learner))
262+
)
263+
264+
if (!some_pipeops_validate) {
265+
lg$warn("GraphLearner '%s' specifies a validation set, but none of its Learners use it.", self$id)
266+
}
267+
}
268+
244269
on.exit({self$graph$state = NULL})
245270
self$graph$train(task)
246271
state = self$graph$state
@@ -288,90 +313,173 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
288313
)
289314
)
290315

316+
#' @title Configure Validation for a GraphLearner
#'
#' @description
#' Configure validation for a graph learner.
#'
#' In a [`GraphLearner`], validation can be configured on two levels:
#' 1. On the [`GraphLearner`] level.
#' 2. On the level of the [`Learner`]s that are wrapped by [`PipeOpLearner`] and [`PipeOpLearnerCV`].
#'
#' Therefore, enabling validation requires specifying not only how to create the validation set (1), but also which
#' pipeops should actually use it (2).
#' Only the [`GraphLearner`] can specify **how** to create the validation set.
#' All learners wrapped by [`PipeOpLearner`] and [`PipeOpLearnerCV`] can only set it to `NULL` (disable) or
#' `"inner_valid"` (enable).
#'
#' If configuring one of the wrapped learners fails, the previous `$validate` values of the [`GraphLearner`] and of
#' its wrapped learners are restored before the error is propagated.
#'
#' @param learner ([`GraphLearner`])\cr
#'   The graph learner to configure.
#' @param validate (`numeric(1)`, `"inner_valid"` or `NULL`)\cr
#'   How to set the `$validate` field of the learner.
#'   If set to `NULL` all validation is disabled.
#' @param ids (`NULL` or `character()`)\cr
#'   For which pipeops to enable validation.
#'   This parameter is ignored when `validate` is set to `NULL`.
#'   By default, validation is enabled for the base learner.
#' @param args (named `list()`)\cr
#'   Rarely needed.
#'   A named list of lists, specifying additional arguments to be passed to [`set_validate()`] for the pipeops.
#'   The names must be a subset of `ids`.
#' @return The modified `learner`, invisibly.
#' @export
#' @examples
#' # simple
#' glrn = as_learner(po("pca") %>>% lrn("classif.debug"))
#' set_validate(glrn, 0.3)
#' glrn$validate
#' glrn$graph$pipeops$classif.debug$learner$validate
#' set_validate(glrn, NULL)
#' glrn$validate
#' glrn$graph$pipeops$classif.debug$learner$validate
#'
#' # complex
#' glrn = as_learner(ppl("stacking", lrns(c("classif.debug", "classif.featureless")), lrn("classif.debug", id = "final")))
#' set_validate(glrn, 0.2, ids = c("classif.debug", "final"))
#' glrn$validate
#' glrn$graph$pipeops$classif.debug$learner$validate
#' glrn$graph$pipeops$final$learner$validate
set_validate.GraphLearner = function(learner, validate, ids = NULL, args = list()) {
  # all pipeops of the graph that wrap a learner
  lrn_pipeops = keep(learner$graph$pipeops,
    function(po) inherits(po, "PipeOpLearner") || inherits(po, "PipeOpLearnerCV"))

  # Disabling is decided by the *requested* value, not by the current state of the learner:
  # switch validation off on the GraphLearner and on every wrapped learner that has the field.
  if (is.null(validate)) {
    learner$validate = NULL
    walk(lrn_pipeops, function(po) {
      if (exists("validate", po$learner)) {
        po$learner$validate = NULL
      }
    })
    return(invisible(learner))
  }

  if (is.null(ids)) {
    # default: only the base learner's pipeop uses the validation set
    ids = learner$base_pipeop()$id
  } else {
    # `ids(...)` still resolves to the function here: R skips the non-function argument `ids` when looking up a call
    assert_subset(ids, ids(keep(lrn_pipeops, function(po) "validation" %in% po$learner$properties)))
    assert_true(length(ids) > 0)
  }

  assert_list(args, types = "list")
  assert_subset(names(args), ids)

  # Snapshot the current configuration so it can be restored if anything below fails.
  # Values are wrapped in a list so that a previous value of NULL is remembered as well.
  prev_validate_pos = map(
    keep(lrn_pipeops, function(po) exists("validate", po$learner)),
    function(po) list(value = po$learner$validate)
  )
  prev_validate = learner$validate

  on.exit({
    iwalk(prev_validate_pos, function(prev, poid) {
      po = learner$graph$pipeops[[poid]]
      po$learner$validate = prev$value
    })
    learner$validate = prev_validate
  }, add = TRUE)

  learner$validate = validate

  walk(ids, function(poid) {
    # learner might be another GraphLearner / AutoTuner
    extra_args = args[[poid]]
    if (is.null(extra_args)) extra_args = list()
    invoke(set_validate, learner = learner$graph$pipeops[[poid]]$learner, validate = "inner_valid",
      .args = extra_args)
  })

  # everything succeeded: drop the rollback handler
  on.exit()

  invisible(learner)
}
400+
401+
402+
#' @title Set Inner Tuning of a GraphLearner
403+
#' @description
404+
#' First, all values specified by `...` are
405+
#' All [`PipeOpLearner`] and [`PipeOpLearnerCV`]
406+
#' @param validate (`numeric(1)`, `"inner_valid"`, or `NULL`)\cr
407+
#' How to set the `$validate` field of the learner.
408+
#' @param args (named `list()`)\cr
409+
#' Names are ids of the [`GraphLearner`]'s `PipeOps` and values are lists containing arguments passed to the
410+
#' respective wrapped [`Learner`].
411+
#' By default, the values `.disable` and `validate` are used, but can be overwritten on a per-pipeop basis.
412+
#'
294413
#' When enabling, the inner tuning of the `$base_learner()` is enabled by default.
295414
#' When disabling, all inner tuning is disabled by default.
296415
#' @export
297-
set_inner_tuning.GraphLearner = function(learner, disable = FALSE, ids = NULL, param_vals = list(), ...) {
298-
all_pipeops = learner$graph$pipeops
299-
lrn_pipeops = all_pipeops[inherits(all_pipeops, "PipeOpLearner") | inherits(all_pipeops, "PipeOpLearnerCV")]
300-
301-
if (is.null(ids) && disable) {
302-
ids = as.character(unlist(imap(lrn_pipeops, function(po, prefix) {
303-
sprintf("%s.%s", prefix, names(po$param_set$tags[map_lgl(po$param_set$tags, function(t) "inner_tuning" %in% t)]))
304-
})))
305-
} else if (is.null(ids) && !disable) {
306-
lrn_base = learner$base_learner()
307-
308-
# need to find the pipeop that is the base learner. Cannot directly use id, because id of pipeop might
309-
# differ from learner id
310-
po_baselrn = NULL
311-
for (po in lrn_pipeops[inherits(po, "PipeOpLearner")]) {
312-
if (identical(po$learner, lrn_base)) {
313-
po_baselrn = po
314-
break
315-
}
316-
}
317-
ids = paste0(
318-
po_baselrn$id, ".",
319-
names(po_baselrn$param_set$tags[map_lgl(po_baselrn$param_set$tags, function(tags) "inner_tuning" %in% tags)])
320-
)
416+
set_inner_tuning.GraphLearner = function(.learner, .disable = FALSE, validate = NA, args = NULL, ...) {
417+
if (is.null(args)) {
418+
args = set_names(list(list()), .learner$base_pipeops()$id
321419
}
322-
assert_subset(ids, learner$param_set$ids())
323-
pv_prev = learner$param_set$values
420+
all_pipeops = .learner$graph$pipeops
421+
lrn_pipeops = learner_wrapping_pipeops(all_pipeops)
324422

325-
# reset to previous pvs if anything goes wrong
326-
on.exit({learner$param_set$set_values(.values = pv_prev)}, add = TRUE)
423+
assert_list(args, names = "unique")
424+
assert_subset(names(args), ids(lrn_pipeops))
327425

328-
learner$param_set$set_values(.values = param_vals)
329426

427+
# clean up when something goes wrong
428+
prev_pvs = .learner$param_set$values
429+
prev_validate = discard(map(lrn_pipeops, function(po) if (exists("validate", po$learner)) po$learner$validate), is.null)
430+
on.exit({
431+
.learner$param_set$set_values(.values = prev_pvs)
432+
iwalk(prev_validate, function(val, poid) .learner$graph$pipeops[[poid]]$learner$validate = val)
433+
}, add = TRUE)
330434

331-
# pipeop_ids are those learners that wrap a learner and have a parameter that is containes in ids
332-
po_ids = as.character(unlist(discard(map(lrn_pipeops, function(po) {
333-
if (some(names(param_vals) %in% sprintf("%s.%s", po$id, po$param_set$ids()))) po$id
334-
}), is.null)))
335-
336-
# now we walk through the learners and call set_inner_tuning() WITHOUT passing the parameters, as we have already
337-
# set them above
338-
walk(lrn_pipeops[po_ids], set_inner_tuning, disable = disable)
435+
walk(lrn_pipeops[names(args)], function(po) {
436+
browser()
437+
invoke(set_inner_tuning, .learner = po$learner,
438+
.args = insert_named(list(validate = validate, .disable = .disable), args[[po$id]])
439+
)
440+
})
339441

340-
# now put up some extra guardrails because it is not intuitive how to configure validation in the GraphLearner
442+
# Now:
443+
# Set validate for GraphLearner and verify that the configuration is reasonable
341444

342-
some_pipeops_validate = FALSE
343-
if (disable) {
344-
for (po in lrn_pipeops) {
345-
if (!is.null(po$param_set$values$validate)) {
346-
some_pipeops_validate = TRUE
347-
break
445+
if (.disable) {
446+
.learner$validate = if (identical(validate, NA)) NULL else validate
447+
some_pipeops_validate = some(lrn_pipeops, function(po) {
448+
if (!exists("validate", po$learner)) {
449+
return(FALSE)
348450
}
349-
}
451+
!is.null(po$learner$validate)
452+
})
350453
# if none of the pipeops does validation, we also disable it in the GraphLearner
351-
# (unless a value was explicitly passed via param_vals)
352-
if (!some_pipeops_validate && is.null(param_vals$validate)) {
353-
learner$param_set$set_values(validate = NULL)
454+
# (unless a value was explicitly specified)
455+
if (!some_pipeops_validate && identical(validate, NA)) {
456+
.learner$validate = NULL
354457
}
355458
} else {
356-
for (po in lrn_pipeops) {
357-
if (!is.null(po$param_set$values$validate) && is.null(learner$param_set$values$validate)) {
459+
if (!identical(validate, NA)) {
460+
.learner$validate = validate
461+
}
462+
463+
some_pipeops_validate = some(lrn_pipeops, function(po) {
464+
if (is.null(get0("validate", po$learner))) return(FALSE)
465+
if (is.null(.learner$validate)) {
358466
warningf("PipeOp '%s' from GraphLearner '%s' wants a validation set but GraphLearner does not specify one. This likely not what you want.",
359-
po$id, learner$id)
467+
po$id, .learner$id)
360468
}
361-
if (!is.null(po$param_set$values$validate)) {
362-
if (!identical(po$param_set$values$validate, "inner_valid")) {
363-
warningf("PipeOp '%s' from GraphLearner '%s' specifies validation set other than 'inner_valid'. This is likely not what you want.")
364-
}
365-
some_pipeops_validate = TRUE
469+
if (!identical(po$learner$validate, "inner_valid")) {
470+
warningf("PipeOp '%s' from GraphLearner '%s' specifies validation set other than 'inner_valid'. This is likely not what you want.",
471+
po$id, .learner$id)
366472
}
367-
}
368-
if (!is.null(learner$param_set$values$validate) && !some_pipeops_validate) {
369-
warningf("GraphLearner '%s' specifies a validation set, but none of its Learners use it.", learner$id)
473+
TRUE
474+
})
475+
476+
if (!is.null(.learner$param_set$values$validate) && !some_pipeops_validate) {
477+
warningf("GraphLearner '%s' specifies a validation set, but none of its Learners use it. This is likely not what you want.", .learner$id)
370478
}
371479
}
372480

373481
on.exit()
374-
invisible(learner)
482+
invisible(.learner)
375483
}
376484

377485
#' @export

R/PipeOpLearner.R

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,15 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp,
9797
}
9898
),
9999
active = list(
100+
# Active binding that forwards to the `$validate` field of the wrapped learner.
# Read access returns the wrapped learner's current value.
# Write access only accepts `NULL` (disable) or `"inner_valid"` (enable); any other
# value raises an error that points the user to the GraphLearner's `validate` field.
validate = function(rhs) {
  if (!missing(rhs)) {
    # reject everything that is neither NULL nor "inner_valid"
    if (!is.null(rhs) && !identical(rhs, "inner_valid")) {
      stopf("The validate field of PipeOpLearner can only be set to NULL or inner_valid. You probably meant to configure the validate field of the GraphLearner")
    }
    private$.learner$validate = rhs
  }
  private$.learner$validate
},
100109
id = function(val) {
101110
if (!missing(val)) {
102111
private$.id = val

R/utils.R

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,3 +141,15 @@ check_validate = function(x) {
141141
}
142142
check_choice(x, c("inner_valid", "test"), null.ok = TRUE)
143143
}
144+
145+
# Returns the PipeOps that wrap a Learner, i.e. those inheriting from
# "PipeOpLearner" or "PipeOpLearnerCV".
#
# x: a Graph (its `$pipeops` are used), a GraphLearner (its `$graph$pipeops`
#    are used), or a list of PipeOps (checked with `assert_list()`).
# Returns the subset of the PipeOp list selected by `keep()`.
learner_wrapping_pipeops = function(x) {
  if (inherits(x, "Graph")) {
    x = x$pipeops
  } else if (inherits(x, "GraphLearner")) {
    x = x$graph$pipeops
  } else {
    assert_list(x, types = "PipeOp")
  }

  keep(x, function(po) inherits(po, "PipeOpLearner") || inherits(po, "PipeOpLearnerCV"))
}

# The helper was originally defined under this misspelled name while all callers
# use `learner_wrapping_pipeops()`; the old name is kept as an alias.
learning_wrapping_pipeops = learner_wrapping_pipeops

0 commit comments

Comments
 (0)