|
61 | 61 | #' * `predict_time` :: `NULL` | `numeric(1)` |
62 | 62 | #' Prediction time, in seconds. |
63 | 63 | #' |
| 64 | +#' This state is given the class `"pipeop_learner_cv_state"`. |
| 65 | +#' |
64 | 66 | #' @section Parameters: |
65 | 67 | #' The parameters are the parameters inherited from the [`PipeOpTaskPreproc`], as well as the parameters of the [`Learner`][mlr3::Learner] wrapped by this object. |
66 | 68 | #' Besides that, parameters introduced are: |
@@ -142,8 +144,14 @@ PipeOpLearnerCV = R6Class("PipeOpLearnerCV", |
142 | 144 | # private$.crossval_param_set$add_dep("folds", "method", CondEqual$new("cv")) # don't do this. |
143 | 145 |
|
144 | 146 | super$initialize(id, alist(resampling = private$.crossval_param_set, private$.learner$param_set), param_vals = param_vals, can_subset_cols = TRUE, task_type = task_type, tags = c("learner", "ensemble")) |
| 147 | + }, |
| 148 | + train = function(inputs) { |
| 149 | + outputs = super$train(inputs) |
| 150 | + self$state = multiplicity_recurse(self$state, function(state) { |
| 151 | + structure(state, class = c("pipeop_learner_cv_state", class(state))) |
| 152 | + }) |
| 153 | + return(outputs) |
145 | 154 | } |
146 | | - |
147 | 155 | ), |
148 | 156 | active = list( |
149 | 157 | learner = function(val) { |
@@ -177,7 +185,6 @@ PipeOpLearnerCV = R6Class("PipeOpLearnerCV", |
177 | 185 | private = list( |
178 | 186 | .train = function(inputs) { |
179 | 187 | out = super$.train(inputs) |
180 | | - self$state = structure(self$state, class = c("pipeop_learner_cv_state", class(self$state))) |
181 | 188 | return(out) |
182 | 189 | }, |
183 | 190 | .train_task = function(task) { |
@@ -232,25 +239,22 @@ marshal_model.pipeop_learner_cv_state = function(model, inplace = FALSE, ...) { |
232 | 239 | # Note that a Learner state contains other reference objects, but we don't clone them here, even when inplace |
233 | 240 | # is FALSE. For our use-case this is just not necessary and would cause unnecessary overhead in the mlr3 |
234 | 241 | # workhorse function |
235 | | - prev_class = class(model) |
236 | 242 | model$model = marshal_model(model$model, inplace = inplace) |
237 | 243 | # only wrap this in a marshaled class if the model was actually marshaled above |
238 | 244 | # (the default marshal method does nothing) |
239 | 245 | if (is_marshaled_model(model$model)) { |
240 | 246 | model = structure( |
241 | 247 | list(marshaled = model, packages = "mlr3pipelines"), |
242 | | - class = c(paste0(prev_class, "_marshaled"), "marshaled") |
| 248 | + class = c(paste0(class(model), "_marshaled"), "marshaled") |
243 | 249 | ) |
244 | 250 | } |
245 | 251 | model |
246 | 252 | } |
247 | 253 |
|
248 | 254 | #' @export |
249 | 255 | unmarshal_model.pipeop_learner_cv_state_marshaled = function(model, inplace = FALSE, ...) { |
250 | | - prev_class = head(class(model), n = -1) |
251 | 256 | state_marshaled = model$marshaled |
252 | 257 | state_marshaled$model = unmarshal_model(state_marshaled$model, inplace = inplace) |
253 | | - class(state_marshaled) = gsub(x = prev_class, pattern = "_marshaled$", replacement = "") |
254 | 258 | state_marshaled |
255 | 259 | } |
256 | 260 |
|
|
0 commit comments