Skip to content

Commit 1755c1b

Browse files
committed
smarter predict_type handling
1 parent 26c7d17 commit 1755c1b

File tree

4 files changed

+33
-3
lines changed

4 files changed

+33
-3
lines changed

R/PipeOp.R

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -364,9 +364,7 @@ PipeOp = R6Class("PipeOp",
364364
},
365365
predict_type = function(val) {
366366
if (!missing(val)) {
367-
if (!identical(val, private$.learner)) {
368-
stop("$predict_type is read-only.")
369-
}
367+
stop("$predict_type is read-only.")
370368
}
371369
return(NULL)
372370
},

R/PipeOpLearnerCV.R

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,13 @@ PipeOpLearnerCV = R6Class("PipeOpLearnerCV",
161161
} else {
162162
multiplicity_recurse(self$state, clone_with_state, learner = private$.learner)
163163
}
164+
},
165+
predict_type = function(val) {
166+
if (!missing(val)) {
167+
assert_subset(val, names(mlr_reflections$learner_predict_types[[private$.learner$task_type]]))
168+
private$.learner$predict_type = val
169+
}
170+
private$.learner$predict_type
164171
}
165172
),
166173
private = list(

R/PipeOpTuneThreshold.R

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,9 @@ PipeOpTuneThreshold = R6Class("PipeOpTuneThreshold",
8989
)
9090
}
9191
),
92+
active = list(
93+
predict_type = function() "response" # we are predict type "response" for now, so we don't break things. See discussion in #712
94+
),
9295
private = list(
9396
.train = function(input) {
9497
if(!all(input[[1]]$feature_types$type == "numeric")) {

tests/testthat/test_pipeop_learnercv.R

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,25 @@ test_that("PipeOpLearnerCV - model active binding to state", {
9898
expect_null(po$learner$state)
9999
expect_equal(po$learner_model$state, po$state)
100100
})
101+
102+
test_that("predict_type", {
103+
expect_equal(po("learner_cv", lrn("classif.rpart", predict_type = "response"))$predict_type, "response")
104+
expect_equal(po("learner_cv", lrn("classif.rpart", predict_type = "prob"))$predict_type, "prob")
105+
106+
lcv <- po("learner_cv", lrn("classif.rpart", predict_type = "prob"))
107+
108+
lcv$predict_type = "response"
109+
expect_equal(lcv$predict_type, "response")
110+
expect_equal(lcv$learner$predict_type, "response")
111+
112+
expect_equal(lcv$train(list(tsk("iris")))[[1]]$feature_names, "classif.rpart.response")
113+
114+
lcv$predict_type = "prob"
115+
116+
expect_equal(lcv$predict_type, "prob")
117+
expect_equal(lcv$learner$predict_type, "prob")
118+
119+
expect_subset(c("classif.rpart.prob.virginica", "classif.rpart.prob.setosa", "classif.rpart.prob.versicolor"),
120+
lcv$train(list(tsk("iris")))[[1]]$feature_names)
121+
122+
})

0 commit comments

Comments
 (0)