Skip to content

Commit c10b3c8

Browse files
authored
Merge pull request #724 from mlr-org/fix_tunethreshold
Fix tunethreshold
2 parents e83859d + 1755c1b commit c10b3c8

File tree

7 files changed

+64
-11
lines changed

7 files changed

+64
-11
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: mlr3pipelines
22
Title: Preprocessing Operators and Pipelines for 'mlr3'
3-
Version: 0.5.0-9000
3+
Version: 0.5.0-1
44
Authors@R:
55
c(person(given = "Martin",
66
family = "Binder",

NEWS.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
# mlr3pipelines 0.5.0-9000
1+
# mlr3pipelines 0.5.0-1
2+
3+
* Bugfix: `PipeOpTuneThreshold` was not overloading the correct `.train` and `.predict` functions.
24

35
# mlr3pipelines 0.5.0
46

R/PipeOp.R

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -364,9 +364,7 @@ PipeOp = R6Class("PipeOp",
364364
},
365365
predict_type = function(val) {
366366
if (!missing(val)) {
367-
if (!identical(val, private$.learner)) {
368-
stop("$predict_type is read-only.")
369-
}
367+
stop("$predict_type is read-only.")
370368
}
371369
return(NULL)
372370
},

R/PipeOpLearnerCV.R

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,13 @@ PipeOpLearnerCV = R6Class("PipeOpLearnerCV",
161161
} else {
162162
multiplicity_recurse(self$state, clone_with_state, learner = private$.learner)
163163
}
164+
},
165+
predict_type = function(val) {
166+
if (!missing(val)) {
167+
assert_subset(val, names(mlr_reflections$learner_predict_types[[private$.learner$task_type]]))
168+
private$.learner$predict_type = val
169+
}
170+
private$.learner$predict_type
164171
}
165172
),
166173
private = list(

R/PipeOpTuneThreshold.R

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,13 @@ PipeOpTuneThreshold = R6Class("PipeOpTuneThreshold",
8787
output = data.table(name = "output", train = "NULL", predict = "Prediction"),
8888
tags = "target transform"
8989
)
90-
},
91-
train = function(input) {
90+
}
91+
),
92+
active = list(
93+
predict_type = function() "response" # we are predict type "response" for now, so we don't break things. See discussion in #712
94+
),
95+
private = list(
96+
.train = function(input) {
9297
if(!all(input[[1]]$feature_types$type == "numeric")) {
9398
stop("PipeOpTuneThreshold requires predicted probabilities! Set learner predict_type to 'prob'")
9499
}
@@ -97,13 +102,11 @@ PipeOpTuneThreshold = R6Class("PipeOpTuneThreshold",
97102
self$state = list("threshold" = th)
98103
return(list(NULL))
99104
},
100-
predict = function(input) {
105+
.predict = function(input) {
101106
pred = private$.task_to_prediction(input[[1]])
102107
pred$set_threshold(self$state$threshold)
103108
return(list(pred))
104-
}
105-
),
106-
private = list(
109+
},
107110
.objfun = function(xs, pred, measure) {
108111
lvls = colnames(pred$prob)
109112
res = pred$set_threshold(unlist(xs))$score(measure)

tests/testthat/test_pipeop_learnercv.R

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,25 @@ test_that("PipeOpLearnerCV - model active binding to state", {
9898
expect_null(po$learner$state)
9999
expect_equal(po$learner_model$state, po$state)
100100
})
101+
102+
test_that("predict_type", {
103+
expect_equal(po("learner_cv", lrn("classif.rpart", predict_type = "response"))$predict_type, "response")
104+
expect_equal(po("learner_cv", lrn("classif.rpart", predict_type = "prob"))$predict_type, "prob")
105+
106+
lcv <- po("learner_cv", lrn("classif.rpart", predict_type = "prob"))
107+
108+
lcv$predict_type = "response"
109+
expect_equal(lcv$predict_type, "response")
110+
expect_equal(lcv$learner$predict_type, "response")
111+
112+
expect_equal(lcv$train(list(tsk("iris")))[[1]]$feature_names, "classif.rpart.response")
113+
114+
lcv$predict_type = "prob"
115+
116+
expect_equal(lcv$predict_type, "prob")
117+
expect_equal(lcv$learner$predict_type, "prob")
118+
119+
expect_subset(c("classif.rpart.prob.virginica", "classif.rpart.prob.setosa", "classif.rpart.prob.versicolor"),
120+
lcv$train(list(tsk("iris")))[[1]]$feature_names)
121+
122+
})

tests/testthat/test_pipeop_tunethreshold.R

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,24 @@ test_that("threshold works for binary", {
3434
po("tunethreshold")
3535
expect_error(po_cv$train(t), "prob")
3636
})
37+
38+
test_that("tunethreshold graph works", {
39+
40+
graph = po("learner_cv", lrn("classif.rpart", predict_type = "prob")) %>>% po("tunethreshold")
41+
42+
out = graph$train(tsk("pima"))
43+
44+
expect_null(out$tunethreshold.output)
45+
46+
out = graph$predict(tsk("pima"))
47+
48+
expect_prediction(out$tunethreshold.output)
49+
50+
glrn = as_learner(graph)
51+
52+
glrn$train(tsk("pima"))
53+
54+
expect_prediction(glrn$predict(tsk("pima")))
55+
56+
57+
})

0 commit comments

Comments
 (0)