Commit 750636f

Merge pull request #836 from mlr-org/glrn_shortcuts
AB shortcuts for GraphLearner
2 parents c823f18 + 9e3289a commit 750636f

6 files changed, +135 -26 lines changed

NEWS.md

Lines changed: 2 additions & 0 deletions
@@ -1,6 +1,8 @@
 # mlr3pipelines 0.7.0-9000

 * New down-sampling PipeOps for inbalanced data: `PipeOpTomek` / `po("tomek")` and `PipeOpNearmiss` / `po("nearmiss")`
+* `GraphLearner` has new active bindings/methods as shortcuts for active bindings/methods of the underlying `Graph`:
+`$pipeops`, `$edges`, `$pipeops_param_set`, and `$pipeops_param_set_values` as well as `$ids()` and `$plot()`.

 # mlr3pipelines 0.7.0
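
The NEWS entry above lists the new shortcuts without showing them in use. The following is a minimal sketch, assuming an installed mlr3pipelines build that already contains this commit (that is, later than 0.7.0); the two-step graph is an arbitrary illustration and not taken from the commit.

```r
library(mlr3)
library(mlr3pipelines)

# Illustrative GraphLearner: scaling followed by the debug classifier.
gl = as_learner(po("scale") %>>% lrn("classif.debug"))

# Method shortcut: PipeOp IDs in the order in which they were added.
gl$ids()

# Active-binding shortcuts that forward to the underlying Graph.
names(gl$pipeops)            # named list of PipeOps
gl$edges                     # data.table of connections between PipeOps
names(gl$pipeops_param_set)  # one ParamSet per PipeOp
gl$pipeops_param_set_values  # currently set parameter values per PipeOp
```

Before this change the same information was reachable only through `gl$graph` or `gl$graph_model`.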

R/Graph.R

Lines changed: 3 additions & 2 deletions
@@ -91,11 +91,12 @@
 #' Takes a list of `Graph`s or [`PipeOp`]s (or objects that can be automatically converted into `Graph`s or [`PipeOp`]s,
 #' see [`as_graph()`] and [`as_pipeop()`]) as inputs and joins them in a serial `Graph` coming after `self`, as if
 #' connecting them using [`%>>%`].
-#' * `plot(html)` \cr
-#' (`logical(1)`) -> `NULL` \cr
+#' * `plot(html = FALSE, horizontal = FALSE)` \cr
+#' (`logical(1)`, `logical(1)`) -> `NULL` \cr
 #' Plot the [`Graph`], using either the \pkg{igraph} package (for `html = FALSE`, default) or
 #' the `visNetwork` package for `html = TRUE` producing a [`htmlWidget`][htmlwidgets::htmlwidgets].
 #' The [`htmlWidget`][htmlwidgets::htmlwidgets] can be rescaled using [`visOptions`][visNetwork::visOptions].
+#' For `html = FALSE`, the orientation of the plotted graph can be controlled through `horizontal`.
 #' * `print(dot = FALSE, dotname = "dot", fontsize = 24L)` \cr
 #' (`logical(1)`, `character(1)`, `integer(1)`) -> `NULL` \cr
 #' Print a representation of the [`Graph`] on the console. If `dot` is `FALSE`, output is a table with one row for each contained [`PipeOp`] and
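
The documentation change above describes the `horizontal` argument of `$plot()`. A short sketch of how it might be called follows; the graph is an arbitrary example, and the call is guarded because the non-HTML plot needs the igraph package, which may not be installed.

```r
library(mlr3)
library(mlr3pipelines)

# Illustrative two-step graph (not part of the commit).
gr = po("scale") %>>% po("pca")

# html = FALSE (the default) draws with igraph; horizontal = TRUE changes
# the orientation of that plot.
if (requireNamespace("igraph", quietly = TRUE)) {
  gr$plot(horizontal = TRUE)
}
```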

R/GraphLearner.R

Lines changed: 70 additions & 15 deletions
@@ -47,15 +47,25 @@
 #' contain the model. Use `graph_model` to access the trained [`Graph`] after `$train()`. Read-only.
 #' * `graph_model` :: [`Learner`][mlr3::Learner]\cr
 #' [`Graph`] that is being wrapped. This [`Graph`] contains a trained state after `$train()`. Read-only.
+#' * `pipeops` :: named `list` of [`PipeOp`] \cr
+#' Contains all [`PipeOp`]s in the underlying [`Graph`], named by the [`PipeOp`]'s `$id`s. Shortcut for `$graph_model$pipeops`. See [`Graph`] for details.
+#' * `edges` :: [`data.table`][data.table::data.table] with columns `src_id` (`character`), `src_channel` (`character`), `dst_id` (`character`), `dst_channel` (`character`)\cr
+#' Table of connections between the [`PipeOp`]s in the underlying [`Graph`]. Shortcut for `$graph$edges`. See [`Graph`] for details.
+#' * `param_set` :: [`ParamSet`][paradox::ParamSet]\cr
+#' Parameters of the underlying [`Graph`]. Shortcut for `$graph$param_set`. See [`Graph`] for details.
+#' * `pipeops_param_set` :: named `list()`\cr
+#' Named list containing the [`ParamSet`][paradox::ParamSet]s of all [`PipeOp`]s in the [`Graph`]. See there for details.
+#' * `pipeops_param_set_values` :: named `list()`\cr
+#' Named list containing the set parameter values of all [`PipeOp`]s in the [`Graph`]. See there for details.
 #' * `internal_tuned_values` :: named `list()` or `NULL`\cr
-#' The internal tuned parameter values collected from all `PipeOp`s.
+#' The internal tuned parameter values collected from all [`PipeOp`]s.
 #' `NULL` is returned if the learner is not trained or none of the wrapped learners supports internal tuning.
 #' * `internal_valid_scores` :: named `list()` or `NULL`\cr
-#' The internal validation scores as retrieved from the `PipeOps`.
-#' The names are prefixed with the respective IDs of the `PipeOp`s.
+#' The internal validation scores as retrieved from the [`PipeOp`]s.
+#' The names are prefixed with the respective IDs of the [`PipeOp`]s.
 #' `NULL` is returned if the learner is not trained or none of the wrapped learners supports internal validation.
 #' * `validate` :: `numeric(1)`, `"predefined"`, `"test"` or `NULL`\cr
-#' How to construct the validation data. This also has to be configured for the individual `PipeOp`s such as
+#' How to construct the validation data. This also has to be configured for the individual [`PipeOp`]s such as
 #' `PipeOpLearner`, see [`set_validate.GraphLearner`].
 #' For more details on the possible values, see [`mlr3::Learner`].
 #' * `marshaled` :: `logical(1)`\cr
@@ -75,6 +85,16 @@
 #'
 #' @section Methods:
 #' Methods inherited from [`Learner`][mlr3::Learner], as well as:
+#' * `ids(sorted = FALSE)` \cr
+#' (`logical(1)`) -> `character` \cr
+#' Get IDs of all [`PipeOp`]s. This is in order that [`PipeOp`]s were added if
+#' `sorted` is `FALSE`, and topologically sorted if `sorted` is `TRUE`.
+#' * `plot(html = FALSE, horizontal = FALSE)` \cr
+#' (`logical(1)`, `logical(1)`) -> `NULL` \cr
+#' Plot the [`Graph`], using either the \pkg{igraph} package (for `html = FALSE`, default) or
+#' the `visNetwork` package for `html = TRUE` producing a [`htmlWidget`][htmlwidgets::htmlwidgets].
+#' The [`htmlWidget`][htmlwidgets::htmlwidgets] can be rescaled using [`visOptions`][visNetwork::visOptions].
+#' For `html = FALSE`, the orientation of the plotted graph can be controlled through `horizontal`.
 #' * `marshal`\cr
 #' (any) -> `self`\cr
 #' Marshal the model.
@@ -104,11 +124,11 @@
 #' This works well for simple [`Graph`]s that do not modify features too much, but may give unexpected results for `Graph`s that
 #' add new features or move information between features.
 #'
-#' As an example, consider a feature `A`` with missing values, and a feature `B`` that is used for imputatoin, using a [`po("imputelearner")`][PipeOpImputeLearner].
-#' In a case where the following [`Learner`][mlr3::Learner] performs embedded feature selection and only selects feature A,
-#' the `selected_features()` method could return only feature `A``, and `$importance()` may even report 0 for feature `B`.
-#' This would not be entirbababababely accurate when considering the entire `GraphLearner`, as feature `B` is used for imputation and would therefore have an impact on predictions.
-#' The following should therefore only be used if the `Graph` is known to not have an impact on the relevant properties.
+#' As an example, consider a feature `A` with missing values, and a feature `B` that is used for imputation, using a [`po("imputelearner")`][PipeOpImputeLearner].
+#' In a case where the following [`Learner`][mlr3::Learner] performs embedded feature selection and only selects feature `A`,
+#' the `selected_features()` method could return only feature `A`, and `$importance()` may even report 0 for feature `B`.
+#' This would not be entirely accurate when considering the entire `GraphLearner`, as feature `B` is used for imputation and would therefore have an impact on predictions.
+#' The following should therefore only be used if the [`Graph`] is known to not have an impact on the relevant properties.
 #'
 #' * `importance()`\cr
 #' () -> `numeric`\cr
@@ -286,6 +306,12 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
       } else {
         stopf("Baselearner %s of %s does not implement '$loglik()'.", base_learner$id, self$id)
       }
+    },
+    ids = function(sorted = FALSE) {
+      private$.graph$ids(sorted = sorted)
+    },
+    plot = function(html = FALSE, horizontal = FALSE, ...) {
+      private$.graph$plot(html = html, horizontal = horizontal, ...)
     }
   ),
   active = list(
@@ -339,12 +365,6 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
       }
       pt
     },
-    param_set = function(rhs) {
-      if (!missing(rhs) && !identical(rhs, self$graph$param_set)) {
-        stop("param_set is read-only.")
-      }
-      self$graph$param_set
-    },
     graph = function(rhs) {
       if (!missing(rhs) && !identical(rhs, private$.graph)) stop("graph is read-only")
       private$.graph
@@ -360,6 +380,41 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
         g$state = self$model
         g
       }
+    },
+    pipeops = function(rhs) {
+      if (!missing(rhs) && (!identical(rhs, self$graph_model$pipeops))) {
+        stop("pipeops is read-only")
+      }
+      self$graph_model$pipeops
+    },
+    edges = function(rhs) {
+      if (!missing(rhs) && !identical(rhs, private$.graph$edges)) {
+        stop("edges is read-only")
+      }
+      private$.graph$edges
+    },
+    param_set = function(rhs) {
+      if (!missing(rhs) && !identical(rhs, self$graph$param_set)) {
+        stop("param_set is read-only.")
+      }
+      self$graph$param_set
+    },
+    pipeops_param_set = function(rhs) {
+      value = map(self$graph$pipeops, "param_set")
+      if (!missing(rhs) && !identical(value, rhs)) {
+        stop("pipeops_param_set is read-only")
+      }
+      value
+    },
+    pipeops_param_set_values = function(rhs) {
+      if (!missing(rhs)) {
+        assert_list(rhs)
+        assert_names(names(rhs), permutation.of = names(self$graph$pipeops))
+        for (n in names(rhs)) {
+          self$graph$pipeops[[n]]$param_set$values = rhs[[n]]
+        }
+      }
+      map(self$graph$pipeops, function(x) x$param_set$values)
     }
   ),
   private = list(
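
The read-only bindings above all use the same guard: an assignment is rejected unless the assigned value is `identical()` to what the binding already returns. The sketch below reproduces that pattern with plain R6 to show why nested assignments through such a binding still work. The classes `Inner` and `Wrapper` are hypothetical stand-ins and not part of mlr3pipelines.

```r
library(R6)

# Hypothetical stand-in for an object with reference semantics (e.g. a PipeOp).
Inner = R6Class("Inner", public = list(values = list()))

# Hypothetical stand-in for GraphLearner: exposes a guarded, read-only list.
Wrapper = R6Class("Wrapper",
  public = list(
    initialize = function() {
      private$.items = list(a = Inner$new())
    }
  ),
  active = list(
    items = function(rhs) {
      # Same guard as in the diff: only an identical value may be assigned back.
      if (!missing(rhs) && !identical(rhs, private$.items)) {
        stop("items is read-only")
      }
      private$.items
    }
  ),
  private = list(.items = NULL)
)

w = Wrapper$new()

# Nested assignment: R modifies the Inner object in place and then assigns the
# list back to `items`. The list still holds the same object, so the guard passes.
w$items$a$values$x = 1

# Replacing an element changes the list, so the guard rejects the assignment.
replaced = tryCatch({
  w$items$a = Inner$new()
  TRUE
}, error = function(e) FALSE)
```

This mirrors the behaviour exercised by the tests in this commit: values inside a `PipeOp` can be changed through `$pipeops` or `$pipeops_param_set`, while swapping out a whole element raises an error.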

man/Graph.Rd

Lines changed: 3 additions & 2 deletions
Generated file; diff not rendered.

man/mlr_learners_graph.Rd

Lines changed: 29 additions & 7 deletions
Generated file; diff not rendered.

tests/testthat/test_GraphLearner.R

Lines changed: 28 additions & 0 deletions
@@ -127,6 +127,7 @@ test_that("graphlearner parameters behave as they should", {
   dblrn = mlr_learners$get("classif.debug")
   dblrn$param_set$values$save_tasks = TRUE

+  # Graph ParamSet
   dbgr = PipeOpScale$new() %>>% PipeOpLearner$new(dblrn)

   expect_subset(c("scale.center", "scale.scale", "classif.debug.x"), dbgr$param_set$ids())
@@ -163,6 +164,7 @@ test_that("graphlearner parameters behave as they should", {
   expect_equal(dbgr$pipeops$classif.debug$param_set$values$x, 0.5)
   expect_equal(dbgr$pipeops$classif.debug$learner$param_set$values$x, 0.5)

+  # Graph Learner ParamSet
   dblrn = mlr_learners$get("classif.debug")
   dblrn$param_set$values$message_train = 1
   dblrn$param_set$values$message_predict = 1
@@ -177,6 +179,32 @@ test_that("graphlearner parameters behave as they should", {

   expect_mapequal(gl$param_set$values,
     list(classif.debug.message_predict = 0, classif.debug.message_train = 1, classif.debug.warning_predict = 0, classif.debug.warning_train = 1))
+
+  # GraphLearner AB shortcuts
+  gl = GraphLearner$new(dbgr)
+
+  # GraphLearner AB $pipeops
+  expect_no_error({gl$pipeops$classif.debug$param_set$values$x = 0.5})
+  expect_equal(gl$pipeops$classif.debug$param_set$values$x, 0.5)
+  expect_equal(gl$graph_model$pipeops$classif.debug$param_set$values$x, 0.5)
+
+  # GraphLearner AB $pipeops_param_set
+  expect_no_error({gl$pipeops_param_set$classif.debug$values$x = 0})
+  expect_equal(gl$pipeops_param_set$classif.debug$values$x, 0)
+  expect_equal(gl$graph_model$pipeops$classif.debug$param_set$values$x, 0)
+
+  # GraphLearner AB $pipeops_param_set_values
+  expect_no_error({gl$pipeops_param_set_values$classif.debug$x = 1})
+  expect_equal(gl$pipeops_param_set_values$classif.debug$x, 1)
+  expect_equal(gl$graph_model$pipeops$classif.debug$param_set$values$x, 1)
+
+  # Change param_set pointer should throw error
+  expect_error({gl$pipeops$scale$param_set = ps()})
+  expect_error({gl$pipeops_param_set$scale = ps()})
+  # Lists with wrong properties should not be accepted
+  expect_error({gl$pipeops_param_set_values = list()})
+  expect_error({gl$pipeops_param_set_values = list(x = 5)})
+
 })

 test_that("graphlearner type inference", {
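
Unlike the other shortcuts, `$pipeops_param_set_values` accepts assignment, provided the assigned list has exactly one entry per `PipeOp` ID, as the tests above check. A minimal sketch of that usage outside testthat follows, assuming an mlr3pipelines build that contains this commit; the graph and the value `0.25` are illustrative choices.

```r
library(mlr3)
library(mlr3pipelines)

gl = as_learner(po("scale") %>>% lrn("classif.debug"))

# Read the complete list, change one entry, and write the whole list back.
# Its names remain a permutation of the PipeOp IDs, so it is accepted.
vals = gl$pipeops_param_set_values
vals$classif.debug$x = 0.25
gl$pipeops_param_set_values = vals

# The change is visible on the PipeOp itself.
gl$pipeops$classif.debug$param_set$values$x

# A list whose names are not the PipeOp IDs is rejected.
rejected = tryCatch({
  gl$pipeops_param_set_values = list(x = 5)
  FALSE
}, error = function(e) TRUE)
```

Starting from the current values avoids accidentally dropping values that a `PipeOp` already has set, since each entry replaces that `PipeOp`'s full `values` list.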
