Skip to content

Commit ae16775

Browse files
committed
merge bundling
2 parents 6d851c6 + b031b22 commit ae16775

21 files changed

+357
-14
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ Imports:
5252
data.table,
5353
digest,
5454
lgr,
55-
mlr3 (>= 0.6.0),
55+
mlr3 (>= 0.19.0),
5656
mlr3misc (>= 0.9.0),
5757
paradox,
5858
R6,

NAMESPACE

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ S3method(as_pipeop,Learner)
1212
S3method(as_pipeop,PipeOp)
1313
S3method(as_pipeop,default)
1414
S3method(disable_internal_tuning,GraphLearner)
15+
S3method(marshal_model,Multiplicity)
16+
S3method(marshal_model,graph_learner_model)
17+
S3method(marshal_model,pipeop_impute_learner_state)
18+
S3method(marshal_model,pipeop_learner_cv_state)
1519
S3method(po,"NULL")
1620
S3method(po,Filter)
1721
S3method(po,Learner)
@@ -25,6 +29,10 @@ S3method(predict,Graph)
2529
S3method(print,Multiplicity)
2630
S3method(print,Selector)
2731
S3method(set_validate,GraphLearner)
32+
S3method(unmarshal_model,Multiplicity_marshaled)
33+
S3method(unmarshal_model,graph_learner_model_marshaled)
34+
S3method(unmarshal_model,pipeop_impute_learner_state_marshaled)
35+
S3method(unmarshal_model,pipeop_learner_cv_state_marshaled)
2836
export("%>>!%")
2937
export("%>>%")
3038
export(Graph)
@@ -155,5 +163,6 @@ importFrom(data.table,as.data.table)
155163
importFrom(digest,digest)
156164
importFrom(stats,setNames)
157165
importFrom(utils,bibentry)
166+
importFrom(utils,head)
158167
importFrom(utils,tail)
159168
importFrom(withr,with_options)

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
* Minor documentation fixes.
77
* Test helpers are now available in `inst/`. These are considered experimental and unstable.
88

9+
* Added marshaling support to `GraphLearner`
10+
911
# mlr3pipelines 0.5.1
1012

1113
* Changed the ID of `PipeOpFeatureUnion` used in `ppl("robustify")` and `ppl("stacking")`.

R/GraphLearner.R

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,16 @@
5656
#' * `validate` :: `numeric(1)`, `"predefined"`, `"test"` or `NULL`\cr
5757
#' How to construct the validation data. This also has to be configured in the individual learners wrapped by
5858
#' `PipeOpLearner`, see [`set_validate.GraphLearner`] on how to configure this.
59+
#' * `marshaled` :: `logical(1)`\cr
60+
#' Whether the learner is marshaled.
5961
#'
62+
#' @section Methods:
63+
#' * `marshal(...)`\cr
64+
#' (any) -> `self`\cr
65+
#' Marshal the model.
66+
#' * `unmarshal(...)`\cr
67+
#' (any) -> `self`\cr
68+
#' Unmarshal the model.
6069
#'
6170
#' @section Internals:
6271
#' [`as_graph()`] is called on the `graph` argument, so it can technically also be a `list` of things, which is
@@ -150,6 +159,12 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
150159
if (length(last_pipeop_id) == 0) stop("No Learner PipeOp found.")
151160
}
152161
learner_model$base_learner(recursive - 1)
162+
},
163+
marshal = function(...) {
164+
learner_marshal(.learner = self, ...)
165+
},
166+
unmarshal = function(...) {
167+
learner_unmarshal(.learner = self, ...)
153168
}
154169
),
155170
active = list(
@@ -169,7 +184,9 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
169184
private$.validate = assert_validate(rhs)
170185
}
171186
private$.validate
172-
187+
},
188+
marshaled = function() {
189+
learner_marshaled(self)
173190
},
174191
hash = function() {
175192
digest(list(class(self), self$id, self$graph$hash, private$.predict_type, private$.validate,
@@ -260,6 +277,7 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
260277
on.exit({self$graph$state = NULL})
261278
self$graph$train(task)
262279
state = self$graph$state
280+
class(state) = c("graph_learner_model", class(state))
263281
state
264282
},
265283
.predict = function(task) {
@@ -414,6 +432,27 @@ disable_internal_tuning.GraphLearner = function(learner, ids, ...) {
414432
}
415433

416434

435+
#' @export
436+
marshal_model.graph_learner_model = function(model, inplace = FALSE, ...) {
437+
xm = map(.x = model, .f = marshal_model, inplace = inplace, ...)
438+
# if none of the states required any marshaling we return the model as-is
439+
if (!some(xm, is_marshaled_model)) return(model)
440+
441+
structure(list(
442+
marshaled = xm,
443+
packages = "mlr3pipelines"
444+
), class = c("graph_learner_model_marshaled", "list_marshaled", "marshaled"))
445+
}
446+
447+
#' @export
448+
unmarshal_model.graph_learner_model_marshaled = function(model, inplace = FALSE, ...) {
449+
# need to re-create the class as it gets lost during marshaling
450+
structure(
451+
map(.x = model$marshaled, .f = unmarshal_model, inplace = inplace, ...),
452+
class = gsub(x = head(class(model), n = -1), pattern = "_marshaled$", replacement = "")
453+
)
454+
}
455+
417456
#' @export
418457
as_learner.Graph = function(x, clone = FALSE, ...) {
419458
GraphLearner$new(x, clone_graph = clone)

R/PipeOpImputeLearner.R

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@
4444
#' for each column. If a column consists of missing values only during training, the `model` is `0` or the levels of the
4545
#' feature; these are used for sampling during prediction.
4646
#'
47+
#' This state is given the class `"pipeop_impute_learner_state"`.
48+
#'
4749
#' @section Parameters:
4850
#' The parameters are the parameters inherited from [`PipeOpImpute`], in addition to the parameters of the [`Learner`][mlr3::Learner]
4951
#' used for imputation.
@@ -116,6 +118,13 @@ PipeOpImputeLearner = R6Class("PipeOpImputeLearner",
116118
)
117119
super$initialize(id, param_set = alist(private$.learner$param_set), param_vals = param_vals,
118120
whole_task_dependent = TRUE, feature_types = feature_types)
121+
},
122+
train = function(inputs) {
123+
outputs = super$train(inputs)
124+
self$state = multiplicity_recurse(self$state, function(state) {
125+
structure(state, class = c("pipeop_impute_learner_state", class(state)))
126+
})
127+
return(outputs)
119128
}
120129
),
121130
active = list(
@@ -206,3 +215,25 @@ mlr_pipeops$add("imputelearner", PipeOpImputeLearner, list(R6Class("Learner", pu
206215
convert_to_task = function(id = "imputing", data, target, task_type, ...) {
207216
get(mlr_reflections$task_types[task_type, mult = "first"]$task)$new(id = id, backend = data, target = target, ...)
208217
}
218+
219+
#' @export
220+
marshal_model.pipeop_impute_learner_state = function(model, inplace = FALSE, ...) {
221+
prev_class = class(model)
222+
model$model = map(model$model, marshal_model, inplace = inplace, ...)
223+
224+
if (!some(model$model, is_marshaled_model)) {
225+
return(model)
226+
}
227+
228+
structure(
229+
list(marshaled = model, packages = "mlr3pipelines"),
230+
class = c(paste0(prev_class, "_marshaled"), "marshaled")
231+
)
232+
}
233+
234+
#' @export
235+
unmarshal_model.pipeop_impute_learner_state_marshaled = function(model, inplace = FALSE, ...) {
236+
state = model$marshaled
237+
state$model = map(state$model, unmarshal_model, inplace = inplace, ...)
238+
return(state)
239+
}

R/PipeOpLearner.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,7 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp,
9494
super$initialize(id, param_set = alist(private$.learner$param_set), param_vals = param_vals,
9595
input = data.table(name = "input", train = task_type, predict = task_type),
9696
output = data.table(name = "output", train = "NULL", predict = out_type),
97-
tags = "learner", packages = learner$packages
98-
)
97+
tags = "learner", packages = learner$packages)
9998
}
10099
),
101100
active = list(

R/PipeOpLearnerCV.R

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@
6161
#' * `predict_time` :: `NULL` | `numeric(1)`
6262
#' Prediction time, in seconds.
6363
#'
64+
#' This state is given the class `"pipeop_learner_cv_state"`.
65+
#'
6466
#' @section Parameters:
6567
#' The parameters are the parameters inherited from the [`PipeOpTaskPreproc`], as well as the parameters of the [`Learner`][mlr3::Learner] wrapped by this object.
6668
#' Besides that, parameters introduced are:
@@ -144,8 +146,14 @@ PipeOpLearnerCV = R6Class("PipeOpLearnerCV",
144146
# private$.crossval_param_set$add_dep("folds", "method", CondEqual$new("cv")) # don't do this.
145147

146148
super$initialize(id, alist(resampling = private$.crossval_param_set, private$.learner$param_set), param_vals = param_vals, can_subset_cols = TRUE, task_type = task_type, tags = c("learner", "ensemble"))
149+
},
150+
train = function(inputs) {
151+
outputs = super$train(inputs)
152+
self$state = multiplicity_recurse(self$state, function(state) {
153+
structure(state, class = c("pipeop_learner_cv_state", class(state)))
154+
})
155+
return(outputs)
147156
}
148-
149157
),
150158
active = list(
151159
learner = function(val) {
@@ -224,4 +232,29 @@ PipeOpLearnerCV = R6Class("PipeOpLearnerCV",
224232
)
225233
)
226234

235+
#' @export
236+
marshal_model.pipeop_learner_cv_state = function(model, inplace = FALSE, ...) {
237+
# Note that a Learner state contains other reference objects, but we don't clone them here, even when inplace
238+
# is FALSE. For our use-case this is just not necessary and would cause unnecessary overhead in the mlr3
239+
# workhorse function
240+
model$model = marshal_model(model$model, inplace = inplace)
241+
# only wrap this in a marshaled class if the model was actually marshaled above
242+
# (the default marshal method does nothing)
243+
if (is_marshaled_model(model$model)) {
244+
model = structure(
245+
list(marshaled = model, packages = "mlr3pipelines"),
246+
class = c(paste0(class(model), "_marshaled"), "marshaled")
247+
)
248+
}
249+
model
250+
}
251+
252+
#' @export
253+
unmarshal_model.pipeop_learner_cv_state_marshaled = function(model, inplace = FALSE, ...) {
254+
state_marshaled = model$marshaled
255+
state_marshaled$model = unmarshal_model(state_marshaled$model, inplace = inplace)
256+
state_marshaled
257+
}
258+
259+
227260
mlr_pipeops$add("learner_cv", PipeOpLearnerCV, list(R6Class("Learner", public = list(id = "learner_cv", task_type = "classif", param_set = ps()))$new()))

R/PipeOpTaskPreproc.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,7 @@ PipeOpTaskPreproc = R6Class("PipeOpTaskPreproc",
187187
super$initialize(id = id, param_set = param_set, param_vals = param_vals,
188188
input = data.table(name = "input", train = task_type, predict = task_type),
189189
output = data.table(name = "output", train = task_type, predict = task_type),
190-
packages = packages, tags = c(tags, "data transform")
191-
)
190+
packages = packages, tags = c(tags, "data transform"))
192191
}
193192
),
194193
active = list(

R/multiplicity.R

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,16 @@ multiplicity_nests_deeper_than = function(x, cutoff) {
115115
}
116116
ret
117117
}
118+
119+
#' @export
120+
marshal_model.Multiplicity = function(model, inplace = FALSE, ...) {
121+
structure(list(
122+
marshaled = multiplicity_recurse(model, marshal_model, inplace = inplace, ...),
123+
packages = "mlr3pipelines"
124+
), class = c("Multiplicity_marshaled", "marshaled"))
125+
}
126+
127+
#' @export
128+
unmarshal_model.Multiplicity_marshaled = function(model, inplace = FALSE, ...) {
129+
multiplicity_recurse(model$marshaled, unmarshal_model, inplace = inplace, ...)
130+
}

R/zzz.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#' @import paradox
55
#' @import mlr3misc
66
#' @importFrom R6 R6Class
7-
#' @importFrom utils tail
7+
#' @importFrom utils tail head
88
#' @importFrom digest digest
99
#' @importFrom withr with_options
1010
#' @importFrom stats setNames

0 commit comments

Comments
 (0)