Skip to content

Commit 1652570

Browse files
committed
...
1 parent a4933e1 commit 1652570

File tree

83 files changed

+662
-448
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

83 files changed

+662
-448
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ Config/testthat/edition: 3
9393
Config/testthat/parallel: true
9494
NeedsCompilation: no
9595
Roxygen: list(markdown = TRUE, r6 = FALSE)
96-
RoxygenNote: 7.2.3.9000
96+
RoxygenNote: 7.3.1
9797
VignetteBuilder: knitr
9898
Collate:
9999
'Graph.R'

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ S3method(pos,list)
2323
S3method(predict,Graph)
2424
S3method(print,Multiplicity)
2525
S3method(print,Selector)
26+
S3method(set_inner_tuning,GraphLearner)
2627
export("%>>!%")
2728
export("%>>%")
2829
export(Graph)

R/GraphLearner.R

Lines changed: 143 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,14 +98,30 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
9898
}
9999
assert_subset(task_type, mlr_reflections$task_types$type)
100100

101+
102+
private$.validate = some(
103+
keep(graph$pipeops, function(x) inherits(x, "PipeOpLearner") || inherits(x, "PipeOpLearnerCV")),
104+
function(po) "validation" %in% po$learner$properties
105+
)
106+
107+
inner_tuning = some(
108+
keep(graph$pipeops, function(x) inherits(x, "PipeOpLearner") || inherits(x, "PipeOpLearnerCV")),
109+
function(po) "inner_tuning" %in% po$learner$properties
110+
)
111+
112+
properties = setdiff(mlr_reflections$learner_properties[[task_type]],
113+
c("validation", "inner_tuning")[c(!private$.validate, !inner_tuning)])
114+
101115
super$initialize(id = id, task_type = task_type,
102116
feature_types = mlr_reflections$task_feature_types,
103117
predict_types = names(mlr_reflections$learner_predict_types[[task_type]]),
104118
packages = graph$packages,
105-
properties = mlr_reflections$learner_properties[[task_type]],
119+
properties = properties,
106120
man = "mlr3pipelines::GraphLearner"
107121
)
108122

123+
private$.param_set = NULL
124+
109125
if (length(param_vals)) {
110126
private$.graph$param_set$values = insert_named(private$.graph$param_set$values, param_vals)
111127
}
@@ -132,6 +148,25 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
132148
if (length(last_pipeop_id) == 0) stop("No Learner PipeOp found.")
133149
}
134150
learner_model$base_learner(recursive - 1)
151+
},
152+
153+
#' @description
154+
#' Retrieves the inner validation scores as a named `list()`.
155+
inner_valid_scores = function(rhs) {
156+
assert_ro_binding(rhs)
157+
if (is.null(self$state)) {
158+
stopf("Learner not trained")
159+
}
160+
self$state$inner_valid_scores
161+
},
162+
#' @description
163+
#' Retrieves the inner tuned values as a named `list()`.
164+
inner_tuned_values = function(rhs) {
165+
assert_ro_binding(rhs)
166+
if (is.null(self$state)) {
167+
stopf("Learner not trained")
168+
}
169+
self$state$inner_tuned_values
135170
}
136171
),
137172
active = list(
@@ -153,7 +188,12 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
153188
if (!missing(rhs) && !identical(rhs, self$graph$param_set)) {
154189
stop("param_set is read-only.")
155190
}
156-
self$graph$param_set
191+
if (is.null(private$.param_set)) {
192+
private$.param_set = ParamSetCollection$new(sets = c(list(self$graph$param_set),
193+
if (private$.validate) ps(validate = p_uty(default = NULL, tags = "train", custom_check = check_validate)
194+
)))
195+
}
196+
private$.param_set
157197
},
158198
graph = function(rhs) {
159199
if (!missing(rhs) && !identical(rhs, private$.graph)) stop("graph is read-only")
@@ -174,7 +214,22 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
174214
),
175215
private = list(
176216
.graph = NULL,
217+
.validate = NULL,
218+
.param_set = NULL,
219+
.extract_inner_tuned_values = function() {
220+
221+
},
222+
.extract_inner_valid_scores = function() {
223+
.NotYetImplemented()
224+
# map(
225+
# keep(self$graph$pipeops, function(po) inherits(po, "PipeOpLearnerCV") || inherits(po, "PipeOpLearner")),
226+
# function(po) {
227+
# po$inner_
228+
# }
229+
# )
230+
},
177231
deep_clone = function(name, value) {
232+
private$.param_set = NULL
178233
# FIXME this repairs the mlr3::Learner deep_clone() method which is broken.
179234
if (is.environment(value) && !is.null(value[[".__enclos_env__"]])) {
180235
return(value$clone(deep = TRUE))
@@ -233,6 +288,92 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
233288
)
234289
)
235290

291+
292+
#' @param learner ([`GraphLearner`])\cr
#'   The learner whose inner tuning is configured.
#'   Its parameter values are modified in place and it is returned invisibly.
#' @param disable (`logical(1)`)\cr
#'   Whether to disable (`TRUE`) or enable (`FALSE`, default) inner tuning.
#' @param ids (`character()`)\cr
#'   The ids of the parameters to enable or disable, as they appear in `learner$param_set$ids()`
#'   (i.e. prefixed with the id of the `PipeOp`).
#'   When enabling, the inner tuning of the `$base_learner()` is enabled by default.
#'   When disabling, all inner tuning is disabled by default.
#' @param param_vals (named `list()`)\cr
#'   Parameter values that are set on `learner$param_set` before the wrapped learners are configured.
#' @param ... (any)\cr
#'   Currently unused.
#' @export
set_inner_tuning.GraphLearner = function(learner, disable = FALSE, ids = NULL, param_vals = list(), ...) {
  # `inherits()` is not vectorized over a list, so the class test has to be applied per PipeOp.
  is_lrn_pipeop = function(po) inherits(po, "PipeOpLearner") || inherits(po, "PipeOpLearnerCV")
  # names of those parameters of a PipeOp that carry the "inner_tuning" tag
  inner_tuning_params = function(po) {
    tags = po$param_set$tags
    names(tags[map_lgl(tags, function(t) "inner_tuning" %in% t)])
  }

  lrn_pipeops = keep(learner$graph$pipeops, is_lrn_pipeop)

  if (is.null(ids) && disable) {
    # default when disabling: every inner-tuning parameter of every learner-wrapping PipeOp
    ids = as.character(unlist(imap(lrn_pipeops, function(po, prefix) {
      sprintf("%s.%s", prefix, inner_tuning_params(po))
    })))
  } else if (is.null(ids) && !disable) {
    lrn_base = learner$base_learner()

    # need to find the pipeop that is the base learner. Cannot directly use id, because id of pipeop might
    # differ from learner id
    po_baselrn = NULL
    for (po in keep(lrn_pipeops, function(x) inherits(x, "PipeOpLearner"))) {
      if (identical(po$learner, lrn_base)) {
        po_baselrn = po
        break
      }
    }
    if (is.null(po_baselrn)) {
      stopf("Could not find the PipeOp wrapping the base learner of GraphLearner '%s'.", learner$id)
    }
    ids = paste0(po_baselrn$id, ".", inner_tuning_params(po_baselrn))
  }
  assert_subset(ids, learner$param_set$ids())
  pv_prev = learner$param_set$values

  # reset to previous pvs if anything goes wrong.
  # Assigning `$values` (instead of merging via `$set_values()`) also removes values that were newly added.
  on.exit({learner$param_set$values = pv_prev}, add = TRUE)

  learner$param_set$set_values(.values = param_vals)

  # the affected pipeops are those that wrap a learner and have a parameter that is contained in ids
  affected_pipeops = keep(lrn_pipeops, function(po) {
    any(ids %in% sprintf("%s.%s", po$id, po$param_set$ids()))
  })

  # now we walk through the learners and call set_inner_tuning() WITHOUT passing the parameters, as we have already
  # set them above
  walk(affected_pipeops, set_inner_tuning, disable = disable)

  # now put up some extra guardrails because it is not intuitive how to configure validation in the GraphLearner

  if (disable) {
    some_pipeops_validate = any(map_lgl(lrn_pipeops, function(po) !is.null(po$param_set$values[["validate"]])))
    # if none of the pipeops does validation, we also disable it in the GraphLearner
    # (unless a value was explicitly passed via param_vals)
    if (!some_pipeops_validate && is.null(param_vals[["validate"]])) {
      learner$param_set$set_values(validate = NULL)
    }
  } else {
    some_pipeops_validate = FALSE
    for (po in lrn_pipeops) {
      po_validate = po$param_set$values[["validate"]]
      if (is.null(po_validate)) {
        next
      }
      if (is.null(learner$param_set$values[["validate"]])) {
        warningf("PipeOp '%s' from GraphLearner '%s' wants a validation set but GraphLearner does not specify one. This is likely not what you want.",
          po$id, learner$id)
      }
      if (!identical(po_validate, "inner_valid")) {
        warningf("PipeOp '%s' from GraphLearner '%s' specifies validation set other than 'inner_valid'. This is likely not what you want.",
          po$id, learner$id)
      }
      some_pipeops_validate = TRUE
    }
    if (!is.null(learner$param_set$values[["validate"]]) && !some_pipeops_validate) {
      warningf("GraphLearner '%s' specifies a validation set, but none of its Learners use it.", learner$id)
    }
  }

  # everything went fine: drop the rollback handler registered above
  on.exit()
  invisible(learner)
}
376+
236377
#' @export
237378
as_learner.Graph = function(x, clone = FALSE, ...) {
238379
GraphLearner$new(x, clone_graph = clone)

R/PipeOpImpute.R

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,11 @@ PipeOpImpute = R6Class("PipeOpImpute",
193193

194194
intask$select(setdiff(intask$feature_names, colnames(imputanda)))$cbind(imputanda)
195195

196-
self$state$outtasklayout = copy(intask$feature_types)
196+
# record the feature layout of the output task in the state under the key `outtasklayout`
self$state$outtasklayout = copy(intask$feature_types)
197+
198+
if (!is.null(intask$inner_valid_task)) {
199+
intask$inner_valid_task = private$.predict(list(intask$inner_valid_task))[[1L]]
200+
}
197201

198202
list(intask)
199203
},

R/PipeOpTaskPreproc.R

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,10 +221,17 @@ PipeOpTaskPreproc = R6Class("PipeOpTaskPreproc",
221221
self$state$outtasklayout = copy(intask$feature_types)
222222
self$state$outtaskshell = intask$data(rows = intask$row_ids[0])
223223

224+
if (!is.null(intask$inner_valid_task)) {
225+
# we call into .predict() and not .predict_task() to not put the burden
226+
# of subsetting the features etc. on the PipeOp overwriting .predict_task
227+
intask$inner_valid_task = private$.predict(list(intask$inner_valid_task))[[1L]]
228+
}
229+
224230
if (do_subset) {
225231
# FIXME: this fails if .train_task added a column with the same name
226232
intask$col_roles$feature = union(intask$col_roles$feature, y = remove_cols)
227233
}
234+
228235
list(intask)
229236
},
230237

R/utils.R

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,10 @@ dictionary_sugar_inc_mget = function(dict, .keys, ...) {
134134
names(objs) = map_chr(objs, "id")
135135
objs
136136
}
137+
138+
# Check function for a `validate` parameter value.
#
# Accepts either a single non-missing number in [0, 1], one of the strings
# "inner_valid" or "test", or NULL.
# Returns TRUE if `x` is valid, otherwise the message produced by `check_choice()`.
check_validate = function(x) {
  # `any.missing = FALSE`: without it, NA would pass as a valid numeric value
  if (test_numeric(x, lower = 0, upper = 1, len = 1L, any.missing = FALSE)) {
    return(TRUE)
  }
  check_choice(x, c("inner_valid", "test"), null.ok = TRUE)
}

man/Graph.Rd

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/PipeOp.Rd

Lines changed: 8 additions & 8 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/PipeOpEnsemble.Rd

Lines changed: 6 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)