Skip to content

Commit 8946fcd

Browse files
committed
add replace argument to ppl("bagging")
1 parent dae03af commit 8946fcd

File tree

2 files changed

+30
-5
lines changed

2 files changed

+30
-5
lines changed

R/pipeline_bagging.R

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@
2828
#' predictions respectively.
2929
#' If `NULL` (default), no averager is added to the end of the graph.
3030
#' Note that setting `collect_multipliciy = TRUE` during construction of the averager is required.
31+
#' @param replace `logical(1)` \cr
32+
#' Whether to sample with replacement.
33+
#' Default `FALSE`.
3134
#' @return [`Graph`]
3235
#' @export
3336
#' @examples
@@ -36,9 +39,15 @@
3639
#' lrn_po = po("learner", lrn("regr.rpart"))
3740
#' task = mlr_tasks$get("boston_housing")
3841
#' gr = pipeline_bagging(lrn_po, 3, averager = po("regravg", collect_multiplicity = TRUE))
39-
#' resample(task, GraphLearner$new(gr), rsmp("holdout"))
42+
#' resample(task, GraphLearner$new(gr), rsmp("holdout"))$aggregate()
43+
#'
44+
#' # The original bagging method uses boosting by sampling with replacement.
45+
#' # This may give better performance but is also slower.
46+
#' gr = ppl("bagging", lrn_po, frac = 1, replace = TRUE,
47+
#' averager = po("regravg", collect_multiplicity = TRUE))
48+
#' resample(task, GraphLearner$new(gr), rsmp("holdout"))$aggregate()
4049
#' }
41-
pipeline_bagging = function(graph, iterations = 10, frac = 0.7, averager = NULL) {
50+
pipeline_bagging = function(graph, iterations = 10, frac = 0.7, averager = NULL, replace = FALSE) {
4251
g = as_graph(graph)
4352
assert_count(iterations)
4453
assert_number(frac, lower = 0, upper = 1)
@@ -50,7 +59,7 @@ pipeline_bagging = function(graph, iterations = 10, frac = 0.7, averager = NULL)
5059
}
5160

5261
po("replicate", param_vals = list(reps = iterations)) %>>!%
53-
po("subsample", param_vals = list(frac = frac)) %>>!%
62+
po("subsample", param_vals = list(frac = frac, replace = replace)) %>>!%
5463
g %>>!%
5564
averager
5665
}

man/mlr_graphs_bagging.Rd

Lines changed: 18 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)