Commit f94eb7b

be-marc and berndbischl (Bernd Bischl) authored
feat: add simple and law of total variance se methods to ranger (#347)
* feat: add simple and law of total variance se methods to ranger
* ...
* fix off by one error
* ...
---------
Co-authored-by: Bernd Bischl <[email protected]>
Co-authored-by: Bernd Bischl <[email protected]>
1 parent f6b81f1 commit f94eb7b

21 files changed (+774, -56 lines)

.Rbuildignore

Lines changed: 1 addition & 0 deletions
@@ -19,3 +19,4 @@ vignettes/learners/
 ^revdep$
 ^cran-comments\.md$
 ^CRAN-SUBMISSION$
+.clangd

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -181,3 +181,4 @@ revdep/
 
 # misc
 Meta/
+.clangd

.lintr

Lines changed: 1 addition & 1 deletion
@@ -5,5 +5,5 @@ linters: linters_with_defaults(
   object_name_linter = object_name_linter(c("snake_case", "CamelCase")), # only allow snake case and camel case object names
   cyclocomp_linter = NULL, # do not check function complexity
   commented_code_linter = NULL, # allow code in comments
-  line_length_linter = line_length_linter(120L)
+  line_length_linter = line_length_linter(300L)
 )

DESCRIPTION

Lines changed: 2 additions & 1 deletion
@@ -43,6 +43,7 @@ Suggests:
     knitr,
     lgr,
     MASS,
+    mirai,
     nnet,
     pracma,
     ranger,
@@ -54,7 +55,7 @@ Remotes:
     mlr-org/mlr3
 Config/testthat/edition: 3
 Encoding: UTF-8
-NeedsCompilation: no
+NeedsCompilation: yes
 Roxygen: list(markdown = TRUE)
 RoxygenNote: 7.3.3
 Collate:

NAMESPACE

Lines changed: 2 additions & 0 deletions
@@ -41,3 +41,5 @@ importFrom(stats,predict)
 importFrom(stats,reformulate)
 importFrom(utils,bibentry)
 importFrom(utils,packageVersion)
+useDynLib(mlr3learners,c_ranger_mu_sigma)
+useDynLib(mlr3learners,c_ranger_var)

NEWS.md

Lines changed: 2 additions & 0 deletions
@@ -1,5 +1,7 @@
 # mlr3learners (development version)
 
+* feat: Add new uncertainty estimation methods `ensemble_standard_deviation` and `law_of_total_variance` to the `regr.ranger` learner.
+
 # mlr3learners 0.12.0
 
 * feat: Add `classif.kknn` and `regr.kknn` learners.

R/LearnerClassifRanger.R

Lines changed: 1 addition & 1 deletion
@@ -128,7 +128,7 @@ LearnerClassifRanger = R6Class("LearnerClassifRanger",
     #'
     #' @return `character()`.
     selected_features = function() {
-      ranger_selected_features(self)
+      ranger_selected_features(self$model, self$state$feature_names)
     }
   ),
 
R/LearnerRegrRanger.R

Lines changed: 58 additions & 30 deletions
@@ -4,7 +4,18 @@
 #'
 #' @description
 #' Random regression forest.
-#' Calls [ranger::ranger()] from package \CRANpkg{ranger}.
+#' Calls `ranger()` from package \CRANpkg{ranger}.
+#'
+#' @details
+#' In addition to the uncertainty estimation methods provided by the ranger package, the learner offers ensemble standard deviation and law of total variance uncertainty estimation.
+#' Both methods compute the empirical mean and variance of the training data points that fall into the predicted leaf nodes.
+#' The ensemble standard deviation method takes the standard deviation of the per-tree leaf means.
+#' The law of total variance method adds the mean of the per-tree leaf variances to the variance of the per-tree leaf means.
+#' Formulas for both methods are given in Hutter et al. (2015).
+#'
+#' For these two methods, the parameter `sigma2.threshold` sets a lower bound on the leaf node variance:
+#' if the variance of a leaf node falls below this threshold, it is set to the threshold (as described in the paper).
+#' The default is 1e-2.
 #'
 #' @inheritSection mlr_learners_classif.ranger Custom mlr3 parameters
 #' @inheritSection mlr_learners_classif.ranger Initial parameter values
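
For orientation, here is a minimal sketch (not package code) of the two estimators described in the details above, written for a single test point given its per-tree leaf statistics. The variable names are illustrative, and whether sample or population variance is used follows Hutter et al. (2015) in the actual implementation.

```r
# Hypothetical per-tree leaf statistics for one test point:
# mean and variance of the training responses in the leaf the point falls into.
leaf_means = c(2.1, 1.8, 2.4)
leaf_vars = c(0.30, 0.25, 0.40)

# Floor the leaf variances at sigma2.threshold (default 1e-2).
leaf_vars = pmax(leaf_vars, 1e-2)

# Ensemble standard deviation: standard deviation of the per-tree leaf means.
se_ensemble_sd = sd(leaf_means)

# Law of total variance: mean of the leaf variances plus variance of the leaf means.
se_lotv = sqrt(mean(leaf_vars) + var(leaf_means))
```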
@@ -13,7 +24,7 @@
 #' @template learner
 #'
 #' @references
-#' `r format_bib("wright_2017", "breiman_2001")`
+#' `r format_bib("wright_2017", "breiman_2001", "hutter_2015")`
 #'
 #' @export
 #' @template seealso_learner
@@ -50,15 +61,16 @@ LearnerRegrRanger = R6Class("LearnerRegrRanger",
         sample.fraction = p_dbl(0L, 1L, tags = "train"),
         save.memory = p_lgl(default = FALSE, tags = "train"),
         scale.permutation.importance = p_lgl(default = FALSE, tags = "train", depends = quote(importance == "permutation")),
-        se.method = p_fct(c("jack", "infjack"), default = "infjack", tags = "predict"), # FIXME: only works if predict_type == "se". How to set dependency?
+        se.method = p_fct(c("jack", "infjack", "ensemble_standard_deviation", "law_of_total_variance"), default = "infjack", tags = "predict"),
+        sigma2.threshold = p_dbl(default = 1e-2, tags = "train"),
         seed = p_int(default = NULL, special_vals = list(NULL), tags = c("train", "predict")),
         split.select.weights = p_uty(default = NULL, tags = "train"),
         splitrule = p_fct(c("variance", "extratrees", "maxstat", "beta", "poisson"), default = "variance", tags = "train"),
         verbose = p_lgl(default = TRUE, tags = c("train", "predict")),
         write.forest = p_lgl(default = TRUE, tags = "train")
       )
 
-      ps$set_values(num.threads = 1L)
+      ps$set_values(num.threads = 1L, sigma2.threshold = 1e-2)
 
       super$initialize(
         id = "regr.ranger",
@@ -79,14 +91,14 @@ LearnerRegrRanger = R6Class("LearnerRegrRanger",
     #'
     #' @return Named `numeric()`.
     importance = function() {
-      if (is.null(self$model)) {
+      if (is.null(self$model$model)) {
         stopf("No model stored")
       }
-      if (self$model$importance.mode == "none") {
+      if (self$model$model$importance.mode == "none") {
         stopf("No importance stored")
       }
 
-      sort(self$model$variable.importance, decreasing = TRUE)
+      sort(self$model$model$variable.importance, decreasing = TRUE)
     },
 
     #' @description
@@ -98,8 +110,8 @@ LearnerRegrRanger = R6Class("LearnerRegrRanger",
         return(self$state$oob_error)
       }
 
-      if (!is.null(self$model)) {
-        return(self$model$prediction.error)
+      if (!is.null(self$model$model)) {
+        return(self$model$model$prediction.error)
       }
 
       stopf("No model stored")
@@ -110,14 +122,17 @@ LearnerRegrRanger = R6Class("LearnerRegrRanger",
     #'
     #' @return `character()`.
     selected_features = function() {
-      ranger_selected_features(self)
+      ranger_selected_features(self$model$model, self$state$feature_names)
     }
   ),
 
   private = list(
     .train = function(task) {
       pv = self$param_set$get_values(tags = "train")
       pv = convert_ratio(pv, "mtry", "mtry.ratio", length(task$feature_names))
+      pv$se.method = NULL
+      sigma2_threshold = pv$sigma2.threshold
+      pv$sigma2.threshold = NULL
       pv$case.weights = get_weights(task, private)
 
       if (self$predict_type == "se") {
@@ -127,43 +142,56 @@ LearnerRegrRanger = R6Class("LearnerRegrRanger",
       if (self$predict_type == "quantiles") {
         pv$quantreg = TRUE # nolint
       }
-
-      invoke(ranger::ranger,
+      data = task$data()
+      model = invoke(ranger::ranger,
         dependent.variable.name = task$target_names,
-        data = task$data(),
+        data = data,
         .args = pv
       )
+
+      if (isTRUE(self$param_set$values$se.method %in% c("ensemble_standard_deviation", "law_of_total_variance"))) {
+        # num.threads is the only thing from the param set we want to pass here and not set manually
+        prediction_nodes = mlr3misc::invoke(predict, model, data = data, type = "terminalNodes", predict.all = TRUE, num.threads = pv$num.threads)
+        storage.mode(prediction_nodes$predictions) = "integer"
+        mu_sigma = .Call("c_ranger_mu_sigma", prediction_nodes$predictions, task$truth(), sigma2_threshold)
+        list(model = model, mu_sigma = mu_sigma)
+      } else {
+        list(model = model)
+      }
     },
 
     .predict = function(task) {
       pv = self$param_set$get_values(tags = "predict")
       newdata = ordered_features(task, self)
 
-      prediction = invoke(predict, self$model,
-        data = newdata,
-        type = self$predict_type,
-        quantiles = private$.quantiles,
-        .args = pv)
-
-      if (self$predict_type == "quantiles") {
-        assert_quantiles(self, quantile_response = TRUE)
-        quantiles = prediction$predictions
-        setattr(quantiles, "probs", private$.quantiles)
-        setattr(quantiles, "response", private$.quantile_response)
-        return(list(quantiles = quantiles))
+      if (isTRUE(pv$se.method %in% c("ensemble_standard_deviation", "law_of_total_variance"))) {
+        prediction_nodes = mlr3misc::invoke(predict, self$model$model, data = newdata, type = "terminalNodes", .args = pv[setdiff(names(pv), "se.method")], predict.all = TRUE)
+        storage.mode(prediction_nodes$predictions) = "integer"
+        method = if (pv$se.method == "ensemble_standard_deviation") 0 else 1
+        .Call("c_ranger_var", prediction_nodes$predictions, self$model$mu_sigma, method)
+      } else {
+        prediction = mlr3misc::invoke(predict, self$model$model, data = newdata, type = self$predict_type, quantiles = private$.quantiles, .args = pv)
+
+        if (self$predict_type == "quantiles") {
+          assert_quantiles(self, quantile_response = TRUE)
+          quantiles = prediction$predictions
+          setattr(quantiles, "probs", private$.quantiles)
+          setattr(quantiles, "response", private$.quantile_response)
+          return(list(quantiles = quantiles))
+        }
+
+        list(response = prediction$predictions, se = prediction$se)
       }
-
-      list(response = prediction$predictions, se = prediction$se)
     },
 
     .hotstart = function(task) {
-      model = self$models
+      model = self$model$model
       model$num.trees = self$param_set$values$num.trees
-      model
+      list(model = model)
     },
 
     .extract_oob_error = function() {
-      self$model$prediction.error
+      self$model$model$prediction.error
     }
   )
 )
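
The two routines called via `.Call()` above are implemented in C and registered through the `useDynLib()` entries in NAMESPACE. The pure-R sketch below is only a rough illustration of the computation described in the roxygen details; the function names here are made up, and the exact variance conventions and return structure of the C code are assumptions (in particular, that the response is the mean of the leaf means).

```r
# nodes: integer matrix [n_obs x n_trees] of terminal node ids from
#        predict(..., type = "terminalNodes", predict.all = TRUE)
# y:     numeric vector of training responses

# Training step (roughly what c_ranger_mu_sigma does): per tree, the mean and
# thresholded variance of y in every leaf.
leaf_stats = function(nodes, y, sigma2_threshold = 1e-2) {
  lapply(seq_len(ncol(nodes)), function(tree) {
    mu = c(tapply(y, nodes[, tree], mean))
    sigma2 = pmax(c(tapply(y, nodes[, tree], var)), sigma2_threshold, na.rm = TRUE)
    list(mu = mu, sigma2 = sigma2)
  })
}

# Prediction step (roughly what c_ranger_var does): look up each test point's
# leaf in every tree, then aggregate the leaf means and variances.
predict_se = function(nodes, stats, method = "law_of_total_variance") {
  res = vapply(seq_len(nrow(nodes)), function(i) {
    mu = vapply(seq_along(stats), function(t) stats[[t]]$mu[[as.character(nodes[i, t])]], numeric(1))
    sigma2 = vapply(seq_along(stats), function(t) stats[[t]]$sigma2[[as.character(nodes[i, t])]], numeric(1))
    se = if (method == "ensemble_standard_deviation") sd(mu) else sqrt(mean(sigma2) + var(mu))
    c(mean(mu), se)
  }, numeric(2))
  list(response = res[1, ], se = res[2, ])
}
```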

R/bibentries.R

Lines changed: 12 additions & 1 deletion
@@ -108,5 +108,16 @@ bibentries = c( # nolint start
     number = "1",
     pages = "1--17",
     doi = "10.18637/jss.v077.i01"
-  )
+  ),
+  hutter_2015 = bibentry("inproceedings",
+    title = "Algorithm runtime prediction: methods and evaluation",
+    author = "Hutter, Frank and Xu, Lin and Hoos, Holger H. and Leyton-Brown, Kevin",
+    year = "2015",
+    publisher = "AAAI Press",
+    booktitle = "Proceedings of the 24th International Conference on Artificial Intelligence",
+    pages = "4197--4201",
+    series = "IJCAI'15",
+    doi = "10.5555/2832747.2832840"
+  )
+
 ) # nolint end

R/helpers_ranger.R

Lines changed: 6 additions & 13 deletions
@@ -35,25 +35,18 @@ convert_ratio = function(pv, target, ratio, n) {
   )
 }
 
-
-
-
-ranger_selected_features = function(self) {
-  if (is.null(self$model)) {
+ranger_selected_features = function(model, feature_names) {
+  if (is.null(model)) {
     stopf("No model stored")
   }
 
-  splitvars = ranger::treeInfo(object = self$model, tree = 1)$splitvarName
+  splitvars = ranger::treeInfo(object = model, tree = 1)$splitvarName
   i = 2
-  while (i <= self$model$num.trees &&
-    !all(self$state$feature_names %in% splitvars)) {
-    sv = ranger::treeInfo(object = self$model, tree = i)$splitvarName
+  while (i <= model$num.trees && !all(feature_names %in% splitvars)) {
+    sv = ranger::treeInfo(object = model, tree = i)$splitvarName
     splitvars = union(splitvars, sv)
     i = i + 1
   }
 
-  # order the names of the selected features in the same order as in the task
-  self$state$feature_names[self$state$feature_names %in% splitvars]
+  splitvars[!is.na(splitvars)]
 }
-
-

0 commit comments