Skip to content

Commit 458170e

Browse files
authored
feat: store ranger oob error without storing models (#357)
* feat: store ranger oob error without storing models * ...
1 parent 994410f commit 458170e

File tree

4 files changed

+43
-7
lines changed

4 files changed

+43
-7
lines changed

R/LearnerClassifRanger.R

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,10 +112,15 @@ LearnerClassifRanger = R6Class("LearnerClassifRanger",
112112
#'
113113
#' @return `numeric(1)`.
114114
oob_error = function() {
115-
if (is.null(self$model)) {
116-
stopf("No model stored")
115+
if (!is.null(self$state$oob_error)) {
116+
return(self$state$oob_error)
117117
}
118-
self$model$prediction.error
118+
119+
if (!is.null(self$model)) {
120+
return(self$model$prediction.error)
121+
}
122+
123+
stopf("No model stored")
119124
},
120125

121126
#' @description
@@ -162,6 +167,10 @@ LearnerClassifRanger = R6Class("LearnerClassifRanger",
162167
model = self$model
163168
model$num.trees = self$param_set$values$num.trees
164169
model
170+
},
171+
172+
.extract_oob_error = function() {
173+
self$model$prediction.error
165174
}
166175
)
167176
)

R/LearnerRegrRanger.R

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,12 +92,17 @@ LearnerRegrRanger = R6Class("LearnerRegrRanger",
9292
#' @description
9393
#' The out-of-bag error, extracted from model slot `prediction.error`.
9494
#'
95-
#' @return `numeric(1)`.
95+
#' @return `numeric(1)`
9696
oob_error = function() {
97-
if (is.null(self$model)) {
98-
stopf("No model stored")
97+
if (!is.null(self$state$oob_error)) {
98+
return(self$state$oob_error)
9999
}
100-
self$model$prediction.error
100+
101+
if (!is.null(self$model)) {
102+
return(self$model$prediction.error)
103+
}
104+
105+
stopf("No model stored")
101106
},
102107

103108
#' @description
@@ -155,6 +160,10 @@ LearnerRegrRanger = R6Class("LearnerRegrRanger",
155160
model = self$models
156161
model$num.trees = self$param_set$values$num.trees
157162
model
163+
},
164+
165+
.extract_oob_error = function() {
166+
self$model$prediction.error
158167
}
159168
)
160169
)

tests/testthat/test_classif_ranger.R

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,3 +97,12 @@ test_that("selected_features", {
9797
learner$train(task)
9898
expect_set_equal(learner$selected_features(), c("Sepal.Length", "Sepal.Width", "Petal.Length", "Petal.Width"))
9999
})
100+
101+
test_that("oob_error available without stored model", {
102+
task = tsk("pima")
103+
learner = lrn("classif.ranger")
104+
105+
rr = resample(task, learner, rsmp("holdout"), store_models = FALSE)
106+
107+
expect_number(rr$aggregate(msr("oob_error")))
108+
})

tests/testthat/test_regr_ranger.R

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,12 @@ test_that("selected_features", {
9696
learner$train(task)
9797
expect_set_equal(learner$selected_features(), c("am", "cyl", "wt"))
9898
})
99+
100+
test_that("oob_error available without stored model", {
101+
task = tsk("mtcars")
102+
learner = lrn("regr.ranger")
103+
104+
rr = resample(task, learner, rsmp("holdout"), store_models = FALSE)
105+
106+
expect_number(rr$aggregate(msr("oob_error")))
107+
})

0 commit comments

Comments
 (0)