Skip to content

Commit 8218e07

Browse files
committed
multiplicities
1 parent d0331e4 commit 8218e07

12 files changed

+203
-40
lines changed

NAMESPACE

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@ S3method(as_pipeop,Filter)
1111
S3method(as_pipeop,Learner)
1212
S3method(as_pipeop,PipeOp)
1313
S3method(as_pipeop,default)
14+
S3method(marshal_model,Multiplicity)
1415
S3method(marshal_model,graph_learner_model)
16+
S3method(marshal_model,pipeop_impute_learner_state)
1517
S3method(marshal_model,pipeop_learner_cv_state)
16-
S3method(marshal_model,pipeop_learner_state)
1718
S3method(po,"NULL")
1819
S3method(po,Filter)
1920
S3method(po,Learner)
@@ -26,9 +27,10 @@ S3method(pos,list)
2627
S3method(predict,Graph)
2728
S3method(print,Multiplicity)
2829
S3method(print,Selector)
30+
S3method(unmarshal_model,Multiplicity_marshaled)
2931
S3method(unmarshal_model,graph_learner_model_marshaled)
32+
S3method(unmarshal_model,pipeop_impute_learner_state_marshaled)
3033
S3method(unmarshal_model,pipeop_learner_cv_state_marshaled)
31-
S3method(unmarshal_model,pipeop_learner_state_marshaled)
3234
export("%>>!%")
3335
export("%>>%")
3436
export(Graph)

R/GraphLearner.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@ marshal_model.graph_learner_model = function(model, inplace = FALSE, ...) {
267267

268268
#' @export
269269
unmarshal_model.graph_learner_model_marshaled = function(model, inplace = FALSE, ...) {
270+
# need to re-create the class as it gets lost during marshaling
270271
structure(
271272
map(.x = model$marshaled, .f = unmarshal_model, inplace = inplace, ...),
272273
class = gsub(x = head(class(model), n = -1), pattern = "_marshaled$", replacement = "")

R/PipeOpImputeLearner.R

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@
4444
#' for each column. If a column consists of missing values only during training, the `model` is `0` or the levels of the
4545
#' feature; these are used for sampling during prediction.
4646
#'
47+
#' This state is given the class `"pipeop_impute_learner_state"`.
48+
#'
4749
#' @section Parameters:
4850
#' The parameters are the parameters inherited from [`PipeOpImpute`], in addition to the parameters of the [`Learner`][mlr3::Learner]
4951
#' used for imputation.
@@ -114,6 +116,13 @@ PipeOpImputeLearner = R6Class("PipeOpImputeLearner",
114116
)
115117
super$initialize(id, param_set = alist(private$.learner$param_set), param_vals = param_vals,
116118
whole_task_dependent = TRUE, feature_types = feature_types)
119+
},
120+
train = function(inputs) {
121+
outputs = super$train(inputs)
122+
self$state = multiplicity_recurse(self$state, function(state) {
123+
structure(state, class = c("pipeop_impute_learner_state", class(state)))
124+
})
125+
return(outputs)
117126
}
118127
),
119128
active = list(
@@ -204,3 +213,25 @@ mlr_pipeops$add("imputelearner", PipeOpImputeLearner, list(R6Class("Learner", pu
204213
convert_to_task = function(id = "imputing", data, target, task_type, ...) {
205214
get(mlr_reflections$task_types[task_type, mult = "first"]$task)$new(id = id, backend = data, target = target, ...)
206215
}
216+
217+
#' @export
218+
marshal_model.pipeop_impute_learner_state = function(model, inplace = FALSE, ...) {
219+
prev_class = class(model)
220+
model$model = map(model$model, marshal_model, inplace = inplace, ...)
221+
222+
if (!some(model$model, is_marshaled_model)) {
223+
return(model)
224+
}
225+
226+
structure(
227+
list(marshaled = model, packages = "mlr3pipelines"),
228+
class = c(paste0(prev_class, "_marshaled"), "marshaled")
229+
)
230+
}
231+
232+
#' @export
233+
unmarshal_model.pipeop_impute_learner_state_marshaled = function(model, inplace = FALSE, ...) {
234+
state = model$marshaled
235+
state$model = map(state$model, unmarshal_model, inplace = inplace, ...)
236+
return(state)
237+
}

R/PipeOpLearner.R

Lines changed: 1 addition & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,7 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp,
139139
.train = function(inputs) {
140140
on.exit({private$.learner$state = NULL})
141141
task = inputs[[1L]]
142-
learner_state = private$.learner$train(task)$state
143-
self$state = structure(learner_state, class = c("pipeop_learner_state", class(learner_state)))
142+
self$state = private$.learner$train(task)$state
144143

145144
list(NULL)
146145
},
@@ -155,30 +154,4 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp,
155154
)
156155
)
157156

158-
#' @export
159-
marshal_model.pipeop_learner_state = function(model, inplace = FALSE, ...) {
160-
# Note that a Learner state contains other objects with reference semantics, but we don't clone them here, even when inplace
161-
# is FALSE. For our use-case this is just not necessary and would cause unnecessary overhead in the mlr3
162-
# workhorse function
163-
prev_class = class(model)
164-
model$model = marshal_model(model$model, inplace = inplace)
165-
# only wrap this in a marshaled class if the model was actually marshaled above
166-
# (the default marshal method does nothing)
167-
if (!is_marshaled_model(model$model)) return(model)
168-
structure(
169-
list(marshaled = model, packages = "mlr3pipelines"),
170-
class = c(paste0(prev_class, "_marshaled"), "marshaled")
171-
)
172-
}
173-
174-
#' @export
175-
unmarshal_model.pipeop_learner_state_marshaled = function(model, inplace = FALSE, ...) {
176-
prev_class = head(class(model), n = -1)
177-
state_marshaled = model$marshaled
178-
state_marshaled$model = unmarshal_model(state_marshaled$model, inplace = inplace)
179-
class(state_marshaled) = gsub(x = prev_class, pattern = "_marshaled$", replacement = "")
180-
state_marshaled
181-
}
182-
183-
184157
mlr_pipeops$add("learner", PipeOpLearner, list(R6Class("Learner", public = list(id = "learner", task_type = "classif", param_set = ps(), packages = "mlr3pipelines"))$new()))

R/PipeOpLearnerCV.R

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@
6161
#' * `predict_time` :: `NULL` | `numeric(1)`
6262
#' Prediction time, in seconds.
6363
#'
64+
#' This state is given the class `"pipeop_learner_cv_state"`.
65+
#'
6466
#' @section Parameters:
6567
#' The parameters are the parameters inherited from the [`PipeOpTaskPreproc`], as well as the parameters of the [`Learner`][mlr3::Learner] wrapped by this object.
6668
#' Besides that, parameters introduced are:
@@ -142,8 +144,14 @@ PipeOpLearnerCV = R6Class("PipeOpLearnerCV",
142144
# private$.crossval_param_set$add_dep("folds", "method", CondEqual$new("cv")) # don't do this.
143145

144146
super$initialize(id, alist(resampling = private$.crossval_param_set, private$.learner$param_set), param_vals = param_vals, can_subset_cols = TRUE, task_type = task_type, tags = c("learner", "ensemble"))
147+
},
148+
train = function(inputs) {
149+
outputs = super$train(inputs)
150+
self$state = multiplicity_recurse(self$state, function(state) {
151+
structure(state, class = c("pipeop_learner_cv_state", class(state)))
152+
})
153+
return(outputs)
145154
}
146-
147155
),
148156
active = list(
149157
learner = function(val) {
@@ -177,7 +185,6 @@ PipeOpLearnerCV = R6Class("PipeOpLearnerCV",
177185
private = list(
178186
.train = function(inputs) {
179187
out = super$.train(inputs)
180-
self$state = structure(self$state, class = c("pipeop_learner_cv_state", class(self$state)))
181188
return(out)
182189
},
183190
.train_task = function(task) {
@@ -232,25 +239,22 @@ marshal_model.pipeop_learner_cv_state = function(model, inplace = FALSE, ...) {
232239
# Note that a Learner state contains other reference objects, but we don't clone them here, even when inplace
233240
# is FALSE. For our use-case this is just not necessary and would cause unnecessary overhead in the mlr3
234241
# workhorse function
235-
prev_class = class(model)
236242
model$model = marshal_model(model$model, inplace = inplace)
237243
# only wrap this in a marshaled class if the model was actually marshaled above
238244
# (the default marshal method does nothing)
239245
if (is_marshaled_model(model$model)) {
240246
model = structure(
241247
list(marshaled = model, packages = "mlr3pipelines"),
242-
class = c(paste0(prev_class, "_marshaled"), "marshaled")
248+
class = c(paste0(class(model), "_marshaled"), "marshaled")
243249
)
244250
}
245251
model
246252
}
247253

248254
#' @export
249255
unmarshal_model.pipeop_learner_cv_state_marshaled = function(model, inplace = FALSE, ...) {
250-
prev_class = head(class(model), n = -1)
251256
state_marshaled = model$marshaled
252257
state_marshaled$model = unmarshal_model(state_marshaled$model, inplace = inplace)
253-
class(state_marshaled) = gsub(x = prev_class, pattern = "_marshaled$", replacement = "")
254258
state_marshaled
255259
}
256260

R/multiplicity.R

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,16 @@ multiplicity_nests_deeper_than = function(x, cutoff) {
115115
}
116116
ret
117117
}
118+
119+
#' @export
120+
marshal_model.Multiplicity = function(model, inplace = FALSE, ...) {
121+
structure(list(
122+
marshaled = multiplicity_recurse(model, marshal_model, inplace = inplace, ...),
123+
packages = "mlr3pipelines"
124+
), class = c("Multiplicity_marshaled", "marshaled"))
125+
}
126+
127+
#' @export
128+
unmarshal_model.Multiplicity_marshaled = function(model, inplace = FALSE, ...) {
129+
multiplicity_recurse(model$marshaled, unmarshal_model, inplace = inplace, ...)
130+
}

man/mlr_pipeops_imputelearner.Rd

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/mlr_pipeops_learner_cv.Rd

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_pipeop_impute.R

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -398,5 +398,4 @@ test_that("More tests for Integers", {
398398
expect_false(any(is.na(result$data()$x)), info = po$id)
399399
expect_equal(result$missings(), c(t = 0, x = 0), info = po$id)
400400
}
401-
402401
})

tests/testthat/test_pipeop_imputelearner.R

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,3 +155,17 @@ test_that("PipeOpImputeLearner - model active binding to state", {
155155
expect_equal(names(models), names(po$learner_models))
156156
expect_true(all(pmap_lgl(list(map(models, .f = "model"), map(po$learner_models, .f = "model")), .f = all.equal)))
157157
})
158+
159+
test_that("marshal", {
160+
task = tsk("penguins")
161+
po_im = po("imputelearner", learner = lrn("classif.debug"))
162+
po_im$train(list(task))
163+
164+
s = po_im$state
165+
expect_class(s, "pipeop_impute_learner_state")
166+
sm = marshal_model(s)
167+
expect_class(sm, "marshaled")
168+
su = unmarshal_model(sm)
169+
expect_equal(s, su)
170+
})
171+

0 commit comments

Comments
 (0)