Skip to content

Commit 11aa5ce

Browse files
committed
feat: implement marshaling
1 parent 273d44b commit 11aa5ce

File tree

11 files changed

+118
-60
lines changed

11 files changed

+118
-60
lines changed

NAMESPACE

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ S3method(as_pipeop,Learner)
1212
S3method(as_pipeop,PipeOp)
1313
S3method(as_pipeop,default)
1414
S3method(marshal_model,graph_learner_model)
15+
S3method(marshal_model,pipeop_learner_cv_state)
16+
S3method(marshal_model,pipeop_learner_state)
1517
S3method(po,"NULL")
1618
S3method(po,Filter)
1719
S3method(po,Learner)
@@ -25,6 +27,8 @@ S3method(predict,Graph)
2527
S3method(print,Multiplicity)
2628
S3method(print,Selector)
2729
S3method(unmarshal_model,graph_learner_model_marshaled)
30+
S3method(unmarshal_model,pipeop_learner_cv_state_marshaled)
31+
S3method(unmarshal_model,pipeop_learner_state_marshaled)
2832
export("%>>!%")
2933
export("%>>%")
3034
export(Graph)
@@ -154,5 +158,6 @@ importFrom(data.table,as.data.table)
154158
importFrom(digest,digest)
155159
importFrom(stats,setNames)
156160
importFrom(utils,bibentry)
161+
importFrom(utils,head)
157162
importFrom(utils,tail)
158163
importFrom(withr,with_options)

R/GraphLearner.R

Lines changed: 11 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,10 @@
5151
#' Whether the learner is marshaled. Read-only.
5252
#'
5353
#' @section Methods:
54-
#' * `marshal_model(...)`\cr
54+
#' * `marshal(...)`\cr
5555
#' (any) -> `self`\cr
5656
#' Marshal the model.
57-
#' * `unmarshal_model(...)`\cr
57+
#' * `unmarshal(...)`\cr
5858
#' (any) -> `self`\cr
5959
#' Unmarshal the model.
6060
#'
@@ -253,41 +253,24 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
253253
)
254254
)
255255

256-
#' @title (Un-)Marshal GraphLearner Model
257-
#' @name marshal_graph_learner
258-
#' @description
259-
#' (Un-)marshal the model of a [`GraphLearner`].
260-
#' @param model (model of [`GraphLearner`])\cr
261-
#' The model to be marshaled.
262-
#' @param ... (any)\cr
263-
#' Currently unused.
264-
#' @param inplace (`logical(1)`)\cr
265-
#' Whether to marshal in-place.
266-
#' If `FALSE` (default), all R6-objects are cloned.
267-
#' @keywords internal
268256
#' @export
269257
marshal_model.graph_learner_model = function(model, inplace = FALSE, ...) {
270-
x = map(model, function(po_state) {
271-
po_state$model = if (!is.null(po_state$model)) marshal_model(po_state$model, inplace = inplace, ...)
272-
po_state
273-
})
274-
if (!some(map(x, "model"), is_marshaled_model)) {
275-
return(structure(x, class = c("graph_learner_model", "list")))
276-
}
258+
xm = map(.x = model, .f = marshal_model, inplace = inplace, ...)
259+
# if none of the states required any marshaling we return the model as-is
260+
if (!some(xm, is_marshaled_model)) return(model)
261+
277262
structure(list(
278-
marshaled = x,
279-
packages = "mlr3pipelines"
263+
marshaled = xm,
264+
packages = "mlr3pipelines"
280265
), class = c("graph_learner_model_marshaled", "list_marshaled", "marshaled"))
281266
}
282267

283268
#' @export
284269
unmarshal_model.graph_learner_model_marshaled = function(model, inplace = FALSE, ...) {
285270
structure(
286-
map(model$marshaled, function(po_state) {
287-
po_state$model = if (!is.null(po_state$model)) unmarshal_model(po_state$model, inplace = inplace, ...)
288-
po_state
289-
}
290-
), class = c("graph_learner_model", "list"))
271+
map(.x = model$marshaled, .f = unmarshal_model, inplace = inplace, ...),
272+
class = gsub(x = head(class(model), n = -1), pattern = "_marshaled$", replacement = "")
273+
)
291274
}
292275

293276
#' @export

R/PipeOpLearner.R

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,8 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp,
139139
.train = function(inputs) {
140140
on.exit({private$.learner$state = NULL})
141141
task = inputs[[1L]]
142-
self$state = private$.learner$train(task)$state
142+
learner_state = private$.learner$train(task)$state
143+
self$state = structure(learner_state, class = c("pipeop_learner_state", class(learner_state)))
143144

144145
list(NULL)
145146
},
@@ -154,4 +155,30 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp,
154155
)
155156
)
156157

158+
#' @export
159+
marshal_model.pipeop_learner_state = function(model, inplace = FALSE, ...) {
160+
# Note that a Learner state contains other objects with reference semantics, but we don't clone them here, even when inplace
161+
# is FALSE. For our use-case this is just not necessary and would cause unnecessary overhead in the mlr3
162+
# workhorse function
163+
prev_class = class(model)
164+
model$model = marshal_model(model$model, inplace = inplace)
165+
# only wrap this in a marshaled class if the model was actually marshaled above
166+
# (the default marshal method does nothing)
167+
if (!is_marshaled_model(model$model)) return(model)
168+
structure(
169+
list(marshaled = model, packages = "mlr3pipelines"),
170+
class = c(paste0(prev_class, "_marshaled"), "marshaled")
171+
)
172+
}
173+
174+
#' @export
175+
unmarshal_model.pipeop_learner_state_marshaled = function(model, inplace = FALSE, ...) {
176+
prev_class = head(class(model), n = -1)
177+
state_marshaled = model$marshaled
178+
state_marshaled$model = unmarshal_model(state_marshaled$model, inplace = inplace)
179+
class(state_marshaled) = gsub(x = prev_class, pattern = "_marshaled$", replacement = "")
180+
state_marshaled
181+
}
182+
183+
157184
mlr_pipeops$add("learner", PipeOpLearner, list(R6Class("Learner", public = list(id = "learner", task_type = "classif", param_set = ps(), packages = "mlr3pipelines"))$new()))

R/PipeOpLearnerCV.R

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,11 @@ PipeOpLearnerCV = R6Class("PipeOpLearnerCV",
175175
}
176176
),
177177
private = list(
178+
.train = function(inputs) {
179+
out = super$.train(inputs)
180+
self$state = structure(self$state, class = c("pipeop_learner_cv_state", class(self$state)))
181+
return(out)
182+
},
178183
.train_task = function(task) {
179184
on.exit({private$.learner$state = NULL})
180185

@@ -222,4 +227,32 @@ PipeOpLearnerCV = R6Class("PipeOpLearnerCV",
222227
)
223228
)
224229

230+
#' @export
231+
marshal_model.pipeop_learner_cv_state = function(model, inplace = FALSE, ...) {
232+
# Note that a Learner state contains other reference objects, but we don't clone them here, even when inplace
233+
# is FALSE. For our use-case this is just not necessary and would cause unnecessary overhead in the mlr3
234+
# workhorse function
235+
prev_class = class(model)
236+
model$model = marshal_model(model$model, inplace = inplace)
237+
# only wrap this in a marshaled class if the model was actually marshaled above
238+
# (the default marshal method does nothing)
239+
if (is_marshaled_model(model$model)) {
240+
model = structure(
241+
list(marshaled = model, packages = "mlr3pipelines"),
242+
class = c(paste0(prev_class, "_marshaled"), "marshaled")
243+
)
244+
}
245+
model
246+
}
247+
248+
#' @export
249+
unmarshal_model.pipeop_learner_cv_state_marshaled = function(model, inplace = FALSE, ...) {
250+
prev_class = head(class(model), n = -1)
251+
state_marshaled = model$marshaled
252+
state_marshaled$model = unmarshal_model(state_marshaled$model, inplace = inplace)
253+
class(state_marshaled) = gsub(x = prev_class, pattern = "_marshaled$", replacement = "")
254+
state_marshaled
255+
}
256+
257+
225258
mlr_pipeops$add("learner_cv", PipeOpLearnerCV, list(R6Class("Learner", public = list(id = "learner_cv", task_type = "classif", param_set = ps()))$new()))

R/zzz.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#' @import paradox
55
#' @import mlr3misc
66
#' @importFrom R6 R6Class
7-
#' @importFrom utils tail
7+
#' @importFrom utils tail head
88
#' @importFrom digest digest
99
#' @importFrom withr with_options
1010
#' @importFrom stats setNames

man/marshal_graph_learner.Rd

Lines changed: 0 additions & 24 deletions
This file was deleted.

man/mlr_learners_graph.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_GraphLearner.R

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -575,12 +575,14 @@ test_that("marshal", {
575575
glrn$train(task)
576576
glrn$marshal()
577577
expect_true(glrn$marshaled)
578-
expect_true(is_marshaled_model(glrn$state$model$marshaled$classif.debug$model))
578+
expect_true(is_marshaled_model(glrn$state$model$marshaled$classif.debug))
579579
glrn$unmarshal()
580-
expect_false(is_marshaled_model(glrn$model))
580+
expect_false(is_marshaled_model(glrn$state$model$marshaled$classif.debug))
581581
expect_class(glrn$model, "graph_learner_model")
582582
expect_false(is_marshaled_model(glrn$state$model$marshaled$classif.debug$model))
583583

584+
glrn$predict(task)
585+
584586
# checks that it is marshalable
585587
glrn$train(task)
586588
expect_learner(glrn, task)

tests/testthat/test_mlr_graphs_bagging.R

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
context("ppl - pipeline_bagging")
22

3-
43
test_that("Bagging Pipeline", {
54
skip_on_cran() # takes too long
65

tests/testthat/test_pipeop_learner.R

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,31 @@ test_that("PipeOpLearner - model active binding to state", {
8484
})
8585

8686
test_that("packages", {
87-
8887
expect_set_equal(
8988
c("mlr3pipelines", lrn("classif.rpart")$packages),
9089
po("learner", learner = lrn("classif.rpart"))$packages
9190
)
91+
})
92+
93+
test_that("marshal", {
94+
task = tsk("iris")
95+
po_lrn = as_pipeop(lrn("classif.debug"))
96+
po_lrn$train(list(task))
97+
po_state = po_lrn$state
98+
expect_class(po_state, "pipeop_learner_state")
99+
po_state_marshaled = marshal_model(po_state, inplace = FALSE)
100+
expect_class(po_state_marshaled, "pipeop_learner_state_marshaled")
101+
expect_true(is_marshaled_model(po_state_marshaled))
102+
expect_equal(po_state, unmarshal_model(po_state_marshaled))
103+
})
104+
105+
test_that("multiple marshal round-trips", {
106+
task = tsk("iris")
107+
glrn = as_learner(as_graph(lrn("classif.debug")))
108+
glrn$train(task)
109+
glrn$marshal()$unmarshal()$marshal()$unmarshal()
110+
expect_class(glrn$model, "graph_learner_model")
111+
expect_class(glrn$model$classif.debug$model, "classif.debug_model")
92112

113+
expect_learner(glrn, task = task)
93114
})

0 commit comments

Comments
 (0)