@@ -98,14 +98,30 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
9898 }
9999 assert_subset(task_type , mlr_reflections $ task_types $ type )
100100
101+
102+ private $ .validate = some(
103+ keep(graph $ pipeops , function (x ) inherits(x , " PipeOpLearner" ) || inherits(x , " PipeOpLearnerCV" )),
104+ function (po ) " validation" %in% po $ learner $ properties
105+ )
106+
107+ inner_tuning = some(
108+ keep(graph $ pipeops , function (x ) inherits(x , " PipeOpLearner" ) || inherits(x , " PipeOpLearnerCV" )),
109+ function (po ) " inner_tuning" %in% po $ learner $ properties
110+ )
111+
112+ properties = setdiff(mlr_reflections $ learner_properties [[task_type ]],
113+ c(" validation" , " inner_tuning" )[c(! private $ .validate , ! inner_tuning )])
114+
101115 super $ initialize(id = id , task_type = task_type ,
102116 feature_types = mlr_reflections $ task_feature_types ,
103117 predict_types = names(mlr_reflections $ learner_predict_types [[task_type ]]),
104118 packages = graph $ packages ,
105- properties = mlr_reflections $ learner_properties [[ task_type ]] ,
119+ properties = properties ,
106120 man = " mlr3pipelines::GraphLearner"
107121 )
108122
123+ private $ .param_set = NULL
124+
109125 if (length(param_vals )) {
110126 private $ .graph $ param_set $ values = insert_named(private $ .graph $ param_set $ values , param_vals )
111127 }
@@ -132,6 +148,25 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
132148 if (length(last_pipeop_id ) == 0 ) stop(" No Learner PipeOp found." )
133149 }
134150 learner_model $ base_learner(recursive - 1 )
151+ },
152+
153+ # ' @description
154+ # ' Retrieves the inner validation scores as a named `list()`.
155+ inner_valid_scores = function (rhs ) {
156+ assert_ro_binding(rhs )
157+ if (is.null(self $ state )) {
158+ stopf(" Learner not trained" )
159+ }
160+ self $ state $ inner_valid_scores
161+ },
162+ # ' @description
163+ # ' Retrieves the inner tuned values as a named `list()`.
164+ inner_tuned_values = function (rhs ) {
165+ assert_ro_binding(rhs )
166+ if (is.null(self $ state )) {
167+ stopf(" Learner not trained" )
168+ }
169+ self $ state $ inner_tuned_values
135170 }
136171 ),
137172 active = list (
@@ -153,7 +188,12 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
153188 if (! missing(rhs ) && ! identical(rhs , self $ graph $ param_set )) {
154189 stop(" param_set is read-only." )
155190 }
156- self $ graph $ param_set
191+ if (is.null(private $ .param_set )) {
192+ private $ .param_set = ParamSetCollection $ new(sets = c(list (self $ graph $ param_set ),
193+ if (private $ .validate ) ps(validate = p_uty(default = NULL , tags = " train" , custom_check = check_validate )
194+ )))
195+ }
196+ private $ .param_set
157197 },
158198 graph = function (rhs ) {
159199 if (! missing(rhs ) && ! identical(rhs , private $ .graph )) stop(" graph is read-only" )
@@ -174,7 +214,22 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
174214 ),
175215 private = list (
176216 .graph = NULL ,
217+ .validate = NULL ,
218+ .param_set = NULL ,
219+ .extract_inner_tuned_values = function () {
220+
221+ },
222+ .extract_inner_valid_scores = function () {
223+ .NotYetImplemented()
224+ # map(
225+ # keep(self$graph$pipeops, function(po) inherits(po, "PipeOpLearnerCV") || inherits(po, "PipeOpLearner")),
226+ # function(po) {
227+ # po$inner_
228+ # }
229+ # )
230+ },
177231 deep_clone = function (name , value ) {
232+ private $ .param_set = NULL
178233 # FIXME this repairs the mlr3::Learner deep_clone() method which is broken.
179234 if (is.environment(value ) && ! is.null(value [[" .__enclos_env__" ]])) {
180235 return (value $ clone(deep = TRUE ))
@@ -233,6 +288,92 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
233288 )
234289)
235290
291+
292+ # ' @param ids (`character(1)`)\cr
293+ # ' The ids of the parameters to disable.
294+ # ' When enabling, the inner tuning of the `$base_learner()` is enabled by default.
295+ # ' When disabling, all inner tuning is disable by default.
296+ # ' @export
297+ set_inner_tuning.GraphLearner = function (learner , disable = FALSE , ids = NULL , param_vals = list (), ... ) {
298+ all_pipeops = learner $ graph $ pipeops
299+ lrn_pipeops = all_pipeops [inherits(all_pipeops , " PipeOpLearner" ) | inherits(all_pipeops , " PipeOpLearnerCV" )]
300+
301+ if (is.null(ids ) && disable ) {
302+ ids = as.character(unlist(imap(lrn_pipeops , function (po , prefix ) {
303+ sprintf(" %s.%s" , prefix , names(po $ param_set $ tags [map_lgl(po $ param_set $ tags , function (t ) " inner_tuning" %in% t )]))
304+ })))
305+ } else if (is.null(ids ) && ! disable ) {
306+ lrn_base = learner $ base_learner()
307+
308+ # need to find the pipeop that is the base learner. Cannot directly use id, because id of pipeop might
309+ # differ from learner id
310+ po_baselrn = NULL
311+ for (po in lrn_pipeops [inherits(po , " PipeOpLearner" )]) {
312+ if (identical(po $ learner , lrn_base )) {
313+ po_baselrn = po
314+ break
315+ }
316+ }
317+ ids = paste0(
318+ po_baselrn $ id , " ." ,
319+ names(po_baselrn $ param_set $ tags [map_lgl(po_baselrn $ param_set $ tags , function (tags ) " inner_tuning" %in% tags )])
320+ )
321+ }
322+ assert_subset(ids , learner $ param_set $ ids())
323+ pv_prev = learner $ param_set $ values
324+
325+ # reset to previous pvs if anything goes wrong
326+ on.exit({learner $ param_set $ set_values(.values = pv_prev )}, add = TRUE )
327+
328+ learner $ param_set $ set_values(.values = param_vals )
329+
330+
331+ # pipeop_ids are those learners that wrap a learner and have a parameter that is containes in ids
332+ po_ids = as.character(unlist(discard(map(lrn_pipeops , function (po ) {
333+ if (some(names(param_vals ) %in% sprintf(" %s.%s" , po $ id , po $ param_set $ ids()))) po $ id
334+ }), is.null )))
335+
336+ # now we walk through the learners and call set_inner_tuning() WITHOUT passing the parameters, as we have already
337+ # set them above
338+ walk(lrn_pipeops [po_ids ], set_inner_tuning , disable = disable )
339+
340+ # now put up some extra guardrails because it is not intuitive how to configure validation in the GraphLearner
341+
342+ some_pipeops_validate = FALSE
343+ if (disable ) {
344+ for (po in lrn_pipeops ) {
345+ if (! is.null(po $ param_set $ values $ validate )) {
346+ some_pipeops_validate = TRUE
347+ break
348+ }
349+ }
350+ # if none of the pipeops does validation, we also disable it in the GraphLearner
351+ # (unless a value was explicitly passed via param_vals)
352+ if (! some_pipeops_validate && is.null(param_vals $ validate )) {
353+ learner $ param_set $ set_values(validate = NULL )
354+ }
355+ } else {
356+ for (po in lrn_pipeops ) {
357+ if (! is.null(po $ param_set $ values $ validate ) && is.null(learner $ param_set $ values $ validate )) {
358+ warningf(" PipeOp '%s' from GraphLearner '%s' wants a validation set but GraphLearner does not specify one. This likely not what you want." ,
359+ po $ id , learner $ id )
360+ }
361+ if (! is.null(po $ param_set $ values $ validate )) {
362+ if (! identical(po $ param_set $ values $ validate , " inner_valid" )) {
363+ warningf(" PipeOp '%s' from GraphLearner '%s' specifies validation set other than 'inner_valid'. This is likely not what you want." )
364+ }
365+ some_pipeops_validate = TRUE
366+ }
367+ }
368+ if (! is.null(learner $ param_set $ values $ validate ) && ! some_pipeops_validate ) {
369+ warningf(" GraphLearner '%s' specifies a validation set, but none of its Learners use it." , learner $ id )
370+ }
371+ }
372+
373+ on.exit()
374+ invisible (learner )
375+ }
376+
236377# ' @export
237378as_learner.Graph = function (x , clone = FALSE , ... ) {
238379 GraphLearner $ new(x , clone_graph = clone )
0 commit comments