Skip to content

Commit 5dc84f6

Browse files
authored
Merge pull request #725 from mlr-org/ppl_bagging_with_replacement
add `replace` argument to ppl("bagging")
2 parents 11a22ce + faab712 commit 5dc84f6

File tree

4 files changed

+65
-8
lines changed

4 files changed

+65
-8
lines changed

NEWS.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
# mlr3pipelines 0.5.0-9000
22

3-
* Feature: The `$add_pipeop()` method got an argument `clone` (old behaviour `TRUE` by default)
4-
* Bugfix: `PipeOpFeatureUnion` in some rare cases dropped variables called `"x"`
5-
* Compatibility with upcoming paradox release
3+
* Feature: `pipeline_bagging()` gets the `replace` argument (old behaviour `FALSE` by default).
4+
* Feature: The `$add_pipeop()` method got an argument `clone` (old behaviour `TRUE` by default).
5+
* Bugfix: `PipeOpFeatureUnion` in some rare cases dropped variables called `"x"`.
6+
* Compatibility with upcoming paradox release.
67

78
# mlr3pipelines 0.5.0-2
89

R/pipeline_bagging.R

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@
2828
#' predictions respectively.
2929
#' If `NULL` (default), no averager is added to the end of the graph.
3030
#' Note that setting `collect_multiplicity = TRUE` during construction of the averager is required.
31+
#' @param replace `logical(1)` \cr
32+
#' Whether to sample with replacement.
33+
#' Default `FALSE`.
3134
#' @return [`Graph`]
3235
#' @export
3336
#' @examples
@@ -36,9 +39,14 @@
3639
#' lrn_po = po("learner", lrn("regr.rpart"))
3740
#' task = mlr_tasks$get("boston_housing")
3841
#' gr = pipeline_bagging(lrn_po, 3, averager = po("regravg", collect_multiplicity = TRUE))
39-
#' resample(task, GraphLearner$new(gr), rsmp("holdout"))
42+
#' resample(task, GraphLearner$new(gr), rsmp("holdout"))$aggregate()
43+
#'
44+
#' # The original bagging method uses bootstrapping, i.e. sampling with replacement.
45+
#' gr = ppl("bagging", lrn_po, frac = 1, replace = TRUE,
46+
#' averager = po("regravg", collect_multiplicity = TRUE))
47+
#' resample(task, GraphLearner$new(gr), rsmp("holdout"))$aggregate()
4048
#' }
41-
pipeline_bagging = function(graph, iterations = 10, frac = 0.7, averager = NULL) {
49+
pipeline_bagging = function(graph, iterations = 10, frac = 0.7, averager = NULL, replace = FALSE) {
4250
g = as_graph(graph)
4351
assert_count(iterations)
4452
assert_number(frac, lower = 0, upper = 1)
@@ -50,7 +58,7 @@ pipeline_bagging = function(graph, iterations = 10, frac = 0.7, averager = NULL)
5058
}
5159

5260
po("replicate", param_vals = list(reps = iterations)) %>>!%
53-
po("subsample", param_vals = list(frac = frac)) %>>!%
61+
po("subsample", param_vals = list(frac = frac, replace = replace)) %>>!%
5462
g %>>!%
5563
averager
5664
}

man/mlr_graphs_bagging.Rd

Lines changed: 17 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_mlr_graphs_bagging.R

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,36 @@ test_that("Bagging Pipeline", {
3939
expect_true(all(map_lgl(predict_out, function(x) "PredictionClassif" %in% class(x))))
4040
})
4141

42+
test_that("Bagging with replacement", {
43+
tsk = tsk("iris")
44+
lrn = lrn("classif.rpart")
45+
p = ppl("bagging", graph = po(lrn), replace = TRUE, averager = po("classifavg", collect_multiplicity = TRUE))
46+
expect_graph(p)
47+
res = resample(tsk, GraphLearner$new(p), rsmp("holdout"))
48+
expect_resample_result(res)
49+
50+
tsk$filter(1:140)
51+
expect_equal(anyDuplicated(tsk$data()), 0) # make sure no duplicates
52+
53+
p = ppl("bagging", iterations = 2, frac = 1,
54+
graph = lrn("classif.debug", save_tasks = TRUE),
55+
replace = TRUE, averager = po("classifavg", collect_multiplicity = TRUE)
56+
)
57+
p$train(tsk)
58+
59+
expect_true(anyDuplicated(p$pipeops$classif.debug$state[[1]]$model$task_train$data()) != 0)
60+
61+
getOrigId = function(data) {
62+
tsk$data()[, origline := .I][data, on = colnames(tsk$data()), origline]
63+
}
64+
orig_id_1 = getOrigId(p$pipeops$classif.debug$state[[1]]$model$task_train$data())
65+
orig_id_2 = getOrigId(p$pipeops$classif.debug$state[[2]]$model$task_train$data())
66+
67+
expect_equal(length(orig_id_1), 140)
68+
expect_equal(length(orig_id_2), 140)
69+
# all.equal() returns TRUE for identical inputs and a character description otherwise, so expecting a string asserts that the two samples differ
70+
expect_string(all.equal(orig_id_1, orig_id_2))
71+
72+
expect_true(length(unique(orig_id_1)) < 140)
73+
expect_true(length(unique(orig_id_2)) < 140)
74+
})

0 commit comments

Comments
 (0)