Skip to content

Commit 0d8a29d

Browse files
committed
tests
1 parent ac5a09a commit 0d8a29d

File tree

1 file changed

+33
-0
lines changed

1 file changed

+33
-0
lines changed

tests/testthat/test_mlr_graphs_bagging.R

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,36 @@ test_that("Bagging Pipeline", {
3939
expect_true(all(map_lgl(predict_out, function(x) "PredictionClassif" %in% class(x))))
4040
})
4141

42+
test_that("Bagging with replacement", {
  task = tsk("iris")
  base_learner = lrn("classif.rpart")

  # Part 1: a bagging pipeline built with replace = TRUE is a valid graph
  # and can be resampled when wrapped in a GraphLearner.
  bagging_graph = ppl("bagging",
    graph = po(base_learner),
    replace = TRUE,
    averager = po("classifavg", collect_multiplicity = TRUE)
  )
  expect_graph(bagging_graph)

  rr = resample(task, GraphLearner$new(bagging_graph), rsmp("holdout"))
  expect_resample_result(rr)

  # Part 2: inspect the data each bagging iteration was actually trained on.
  # Restrict the task to rows 1:140 and confirm these rows are all distinct,
  # so any duplicate seen later must come from the sampling step.
  task$filter(1:140)
  expect_equal(anyDuplicated(task$data()), 0)

  # classif.debug with save_tasks = TRUE keeps the training task in its model,
  # which lets us look at the rows drawn in each of the two iterations.
  bagging_graph = ppl("bagging", iterations = 2,
    graph = lrn("classif.debug", save_tasks = TRUE),
    replace = TRUE,
    averager = po("classifavg", collect_multiplicity = TRUE)
  )
  bagging_graph$train(task)

  debug_states = bagging_graph$pipeops$classif.debug$state
  bag_1 = debug_states[[1]]$model$task_train$data()
  bag_2 = debug_states[[2]]$model$task_train$data()

  # Sampling with replacement must yield at least one repeated row.
  expect_gt(anyDuplicated(bag_1), 0)

  # Map every row of a bagged data set back to its row position in the
  # (filtered) task data by joining on all task columns.
  lookup_orig_rows = function(bagged) {
    reference = task$data()
    join_cols = colnames(reference)  # captured before the helper column is added
    reference[, origline := .I]
    reference[bagged, on = join_cols, origline]
  }
  orig_id_1 = lookup_orig_rows(bag_1)
  orig_id_2 = lookup_orig_rows(bag_2)

  # Each bag has as many rows as the filtered task.
  expect_length(orig_id_1, 140)
  expect_length(orig_id_2, 140)

  # The two iterations must not draw the same sample: all.equal() returns a
  # description string when its arguments differ, and TRUE when they match.
  expect_string(all.equal(orig_id_1, orig_id_2))

  # With replacement, fewer than 140 distinct original rows appear in each bag.
  expect_lt(length(unique(orig_id_1)), 140)
  expect_lt(length(unique(orig_id_2)), 140)
})

0 commit comments

Comments
 (0)