Skip to content

Commit fc7db82

Browse files
committed
Add tests for multiple inputs, learner_model, and predictions in multiplicities
1 parent 2237f14 commit fc7db82

File tree

2 files changed

+44
-18
lines changed

2 files changed

+44
-18
lines changed

R/PipeOpLearnerPICVPlus.R

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -110,14 +110,14 @@ PipeOpLearnerPICVPlus = R6Class("PipeOpLearnerPICVPlus",
110110
out_type = mlr_reflections$task_types[type, mult = "first"]$prediction
111111

112112
# paradox requirements 1.0
113-
private$.cvplus_param_set = ps(
113+
private$.picvplus_param_set = ps(
114114
folds = p_int(lower = 2L, upper = Inf, tags = c("train", "required")),
115115
alpha = p_dbl(lower = 0L, upper = 1L, tags = c("predict", "required"))
116116
)
117117

118-
private$.cvplus_param_set$values = list(folds = 3, alpha = 0.05) # default
118+
private$.picvplus_param_set$values = list(folds = 3, alpha = 0.05) # default
119119

120-
super$initialize(id, param_set = alist(cvplus = private$.cvplus_param_set, private$.learner$param_set),
120+
super$initialize(id, param_set = alist(picvplus = private$.picvplus_param_set, private$.learner$param_set),
121121
param_vals = param_vals,
122122
input = data.table(name = "input", train = task_type, predict = task_type),
123123
output = data.table(name = "output", train = "NULL", predict = out_type),
@@ -162,7 +162,7 @@ PipeOpLearnerPICVPlus = R6Class("PipeOpLearnerPICVPlus",
162162

163163
.train = function(inputs) {
164164
task = inputs[[1L]]
165-
pv = private$.cvplus_param_set$values
165+
pv = private$.picvplus_param_set$values
166166

167167
# Compute CV Predictions
168168
rdesc = rsmp("cv", folds = pv$folds)
@@ -179,7 +179,7 @@ PipeOpLearnerPICVPlus = R6Class("PipeOpLearnerPICVPlus",
179179

180180
.predict = function(inputs) {
181181
task = inputs[[1L]]
182-
pv = private$.cvplus_param_set$values
182+
pv = private$.picvplus_param_set$values
183183

184184
mu_hat = map(self$state$cv_model_states, function(state) {
185185
on.exit({private$.learner$state = NULL})
@@ -209,7 +209,7 @@ PipeOpLearnerPICVPlus = R6Class("PipeOpLearnerPICVPlus",
209209
))
210210
},
211211

212-
.cvplus_param_set = NULL,
212+
.picvplus_param_set = NULL,
213213
.learner = NULL,
214214
.additional_phash_input = function() private$.learner$phash
215215
)

tests/testthat/test_pipeop_learnerpicvplus.R

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -54,17 +54,17 @@ test_that("PipeOpLearnerPICVPlus - param set and values", {
5454
lrn = mlr_learners$get("regr.rpart")
5555
po = PipeOpLearnerPICVPlus$new(lrn)
5656

57-
expect_subset(c("minsplit", "cvplus.folds", "cvplus.alpha"), po$param_set$ids())
58-
expect_equal(po$param_set$values, list(cvplus.folds = 3, cvplus.alpha = 0.05, xval = 0))
57+
expect_subset(c("minsplit", "picvplus.folds", "picvplus.alpha"), po$param_set$ids())
58+
expect_equal(po$param_set$values, list(picvplus.folds = 3, picvplus.alpha = 0.05, xval = 0))
5959

6060
po$param_set$values$minsplit = 2
61-
expect_equal(po$param_set$values, list(cvplus.folds = 3, cvplus.alpha = 0.05, minsplit = 2, xval = 0))
61+
expect_equal(po$param_set$values, list(picvplus.folds = 3, picvplus.alpha = 0.05, minsplit = 2, xval = 0))
6262

63-
po$param_set$values$cvplus.folds = 5
64-
expect_equal(po$param_set$values, list(cvplus.folds = 5, cvplus.alpha = 0.05, minsplit = 2, xval = 0))
63+
po$param_set$values$picvplus.folds = 5
64+
expect_equal(po$param_set$values, list(picvplus.folds = 5, picvplus.alpha = 0.05, minsplit = 2, xval = 0))
6565

66-
expect_error(PipeOpLearnerPICVPlus$new(lrn, param_vals = list(cvplus.folds = 1)), "is not >= 1")
67-
expect_error(PipeOpLearnerPICVPlus$new(lrn, param_vals = list(cvplus.alpha = -1)), "is not >= -")
66+
expect_error(PipeOpLearnerPICVPlus$new(lrn, param_vals = list(picvplus.folds = 1)), "is not >= 1")
67+
expect_error(PipeOpLearnerPICVPlus$new(lrn, param_vals = list(picvplus.alpha = -1)), "is not >= -")
6868

6969
lrn_classif = mlr_learners$get("classif.featureless")
7070
expect_error(PipeOpLearnerPICVPlus$new(lrn_classif), "only supports regression")
@@ -109,9 +109,9 @@ test_that("PipeOpLearnerPICVPlus - integration with larger graph", {
109109
task = mlr_tasks$get("mtcars")
110110
lrn = mlr_learners$get("regr.rpart")
111111

112-
po_cvplus = PipeOpLearnerPICVPlus$new(lrn)
112+
po_picvplus = PipeOpLearnerPICVPlus$new(lrn)
113113
po_nop = PipeOpNOP$new()
114-
graph = po_cvplus %>>% po_nop
114+
graph = po_picvplus %>>% po_nop
115115

116116
graph$train(task)
117117
predictions = graph$predict(task)[[1]]
@@ -163,14 +163,40 @@ test_that("state class and multiplicity", {
163163
lrn = lrn("regr.debug")
164164
lrn$properties = c(lrn$properties, "marshal")
165165
po = PipeOpLearnerPICVPlus$new(lrn)
166-
po$train(list(Multiplicity(tsk("mtcars"))))
166+
167+
task1 = mlr_tasks$get("mtcars")
168+
task2 = mlr_tasks$get("boston_housing")
169+
input = Multiplicity(task1, task2)
170+
171+
po$train(list(input))
167172
expect_class(po$state, "Multiplicity")
168173
expect_class(po$state[[1L]], "pipeop_learner_pi_cvplus_state")
174+
expect_class(po$learner_model, "Multiplicity")
175+
expect_class(po$learner_model[[1L]][[1L]], "LearnerRegr")
176+
expect_equal(length(po$learner_model), length(input))
177+
expect_equal(length(po$learner_model[[1]]), po$param_set$values$picvplus.folds)
178+
179+
prds = po$predict(list(input))
180+
expect_class(prds$output, "Multiplicity")
181+
expect_equal(length(prds$output), length(input))
182+
expect_class(prds$output[[1L]], "PredictionRegr")
169183

170184
# recursive
171-
po1 = po("learner_pi_cvplus", learner = lrn("regr.debug"))
172-
po1$train(list(Multiplicity(Multiplicity(tsk("mtcars")))))
185+
po1 = PipeOpLearnerPICVPlus$new(lrn)
186+
po1$train(list(Multiplicity(input)))
173187
expect_class(po1$state, "Multiplicity")
174188
expect_class(po1$state[[1L]], "Multiplicity")
175189
expect_class(po1$state[[1L]][[1L]], "pipeop_learner_pi_cvplus_state")
190+
191+
expect_class(po1$learner_model, "Multiplicity")
192+
expect_class(po1$learner_model[[1L]], "Multiplicity")
193+
expect_class(po1$learner_model[[1L]][[1L]][[1]], "LearnerRegr")
194+
expect_equal(length(po1$learner_model[[1L]]), length(input))
195+
expect_equal(length(po1$learner_model[[1L]][[1L]]), po1$param_set$values$picvplus.folds)
196+
197+
prds1 = po1$predict(list(Multiplicity(input)))
198+
expect_class(prds1$output, "Multiplicity")
199+
expect_class(prds1$output[[1L]], "Multiplicity")
200+
expect_class(prds1$output[[1L]][[1L]], "PredictionRegr")
201+
expect_equal(length(prds1$output[[1L]]), length(input))
176202
})

0 commit comments

Comments
 (0)