44# '
55# ' @description
66# ' Random regression forest.
7- # ' Calls [ranger::ranger()] from package \CRANpkg{ranger}.
7+ # ' Calls `ranger()` from package \CRANpkg{ranger}.
8+ # '
9+ # ' @details
10+ # ' Additionally to the uncertainty estimation methods provided by the ranger package, the learner provides a ensemble standard deviation and law of total variance uncertainty estimation.
11+ # ' Both methods compute the empirical mean and variance of the training data points that fall into the predicted leaf nodes.
12+ # ' The ensemble standard deviation method calculates the standard deviation of the mean of the leaf nodes.
13+ # ' The law of total variance method calculates the mean of the variance of the leaf nodes plus the variance of the means of the leaf nodes.
14+ # ' Formulas for the ensemble standard deviation and law of total variance method are given in Hutter et al. (2015).
15+ # '
16+ # ' For these 2 methods, the parameter `sigma2.threshold` can be used to set a threshold for the variance of the leaf nodes,
17+ # ' this is a minimal value for the variance of the leaf nodes, if the variance is below this threshold, it is set to this value (as described in the paper).
18+ # ' Default is 1e-2.
819# '
920# ' @inheritSection mlr_learners_classif.ranger Custom mlr3 parameters
1021# ' @inheritSection mlr_learners_classif.ranger Initial parameter values
1324# ' @template learner
1425# '
1526# ' @references
16- # ' `r format_bib("wright_2017", "breiman_2001")`
27+ # ' `r format_bib("wright_2017", "breiman_2001", "hutter_2015" )`
1728# '
1829# ' @export
1930# ' @template seealso_learner
@@ -50,15 +61,16 @@ LearnerRegrRanger = R6Class("LearnerRegrRanger",
5061 sample.fraction = p_dbl(0L , 1L , tags = " train" ),
5162 save.memory = p_lgl(default = FALSE , tags = " train" ),
5263 scale.permutation.importance = p_lgl(default = FALSE , tags = " train" , depends = quote(importance == " permutation" )),
53- se.method = p_fct(c(" jack" , " infjack" ), default = " infjack" , tags = " predict" ), # FIXME: only works if predict_type == "se". How to set dependency?
64+ se.method = p_fct(c(" jack" , " infjack" , " ensemble_standard_deviation" , " law_of_total_variance" ), default = " infjack" , tags = " predict" ),
65+ sigma2.threshold = p_dbl(default = 1e-2 , tags = " train" ),
5466 seed = p_int(default = NULL , special_vals = list (NULL ), tags = c(" train" , " predict" )),
5567 split.select.weights = p_uty(default = NULL , tags = " train" ),
5668 splitrule = p_fct(c(" variance" , " extratrees" , " maxstat" , " beta" , " poisson" ), default = " variance" , tags = " train" ),
5769 verbose = p_lgl(default = TRUE , tags = c(" train" , " predict" )),
5870 write.forest = p_lgl(default = TRUE , tags = " train" )
5971 )
6072
61- ps $ set_values(num.threads = 1L )
73+ ps $ set_values(num.threads = 1L , sigma2.threshold = 1e-2 )
6274
6375 super $ initialize(
6476 id = " regr.ranger" ,
@@ -79,14 +91,14 @@ LearnerRegrRanger = R6Class("LearnerRegrRanger",
7991 # '
8092 # ' @return Named `numeric()`.
8193 importance = function () {
82- if (is.null(self $ model )) {
94+ if (is.null(self $ model $ model )) {
8395 stopf(" No model stored" )
8496 }
85- if (self $ model $ importance.mode == " none" ) {
97+ if (self $ model $ model $ importance.mode == " none" ) {
8698 stopf(" No importance stored" )
8799 }
88100
89- sort(self $ model $ variable.importance , decreasing = TRUE )
101+ sort(self $ model $ model $ variable.importance , decreasing = TRUE )
90102 },
91103
92104 # ' @description
@@ -98,8 +110,8 @@ LearnerRegrRanger = R6Class("LearnerRegrRanger",
98110 return (self $ state $ oob_error )
99111 }
100112
101- if (! is.null(self $ model )) {
102- return (self $ model $ prediction.error )
113+ if (! is.null(self $ model $ model )) {
114+ return (self $ model $ model $ prediction.error )
103115 }
104116
105117 stopf(" No model stored" )
@@ -110,14 +122,17 @@ LearnerRegrRanger = R6Class("LearnerRegrRanger",
110122 # '
111123 # ' @return `character()`.
112124 selected_features = function () {
113- ranger_selected_features(self )
125+ ranger_selected_features(self $ model $ model , self $ state $ feature_names )
114126 }
115127 ),
116128
117129 private = list (
118130 .train = function (task ) {
119131 pv = self $ param_set $ get_values(tags = " train" )
120132 pv = convert_ratio(pv , " mtry" , " mtry.ratio" , length(task $ feature_names ))
133+ pv $ se.method = NULL
134+ sigma2_threshold = pv $ sigma2.threshold
135+ pv $ sigma2.threshold = NULL
121136 pv $ case.weights = get_weights(task , private )
122137
123138 if (self $ predict_type == " se" ) {
@@ -127,43 +142,56 @@ LearnerRegrRanger = R6Class("LearnerRegrRanger",
127142 if (self $ predict_type == " quantiles" ) {
128143 pv $ quantreg = TRUE # nolint
129144 }
130-
131- invoke(ranger :: ranger ,
145+ data = task $ data()
146+ model = invoke(ranger :: ranger ,
132147 dependent.variable.name = task $ target_names ,
133- data = task $ data() ,
148+ data = data ,
134149 .args = pv
135150 )
151+
152+ if (isTRUE(self $ param_set $ values $ se.method %in% c(" ensemble_standard_deviation" , " law_of_total_variance" ))) {
153+ # num.threads is the only thing from the param set we want to pass here and not set manually
154+ prediction_nodes = mlr3misc :: invoke(predict , model , data = data , type = " terminalNodes" , predict.all = TRUE , num.threads = pv $ num.threads )
155+ storage.mode(prediction_nodes $ predictions ) = " integer"
156+ mu_sigma = .Call(" c_ranger_mu_sigma" , prediction_nodes $ predictions , task $ truth(), sigma2_threshold )
157+ list (model = model , mu_sigma = mu_sigma )
158+ } else {
159+ list (model = model )
160+ }
136161 },
137162
138163 .predict = function (task ) {
139164 pv = self $ param_set $ get_values(tags = " predict" )
140165 newdata = ordered_features(task , self )
141166
142- prediction = invoke(predict , self $ model ,
143- data = newdata ,
144- type = self $ predict_type ,
145- quantiles = private $ .quantiles ,
146- .args = pv )
147-
148- if (self $ predict_type == " quantiles" ) {
149- assert_quantiles(self , quantile_response = TRUE )
150- quantiles = prediction $ predictions
151- setattr(quantiles , " probs" , private $ .quantiles )
152- setattr(quantiles , " response" , private $ .quantile_response )
153- return (list (quantiles = quantiles ))
167+ if (isTRUE(pv $ se.method %in% c(" ensemble_standard_deviation" , " law_of_total_variance" ))) {
168+ prediction_nodes = mlr3misc :: invoke(predict , self $ model $ model , data = newdata , type = " terminalNodes" , .args = pv [setdiff(names(pv ), " se.method" )], predict.all = TRUE )
169+ storage.mode(prediction_nodes $ predictions ) = " integer"
170+ method = if (pv $ se.method == " ensemble_standard_deviation" ) 0 else 1
171+ .Call(" c_ranger_var" , prediction_nodes $ predictions , self $ model $ mu_sigma , method )
172+ } else {
173+ prediction = mlr3misc :: invoke(predict , self $ model $ model , data = newdata , type = self $ predict_type , quantiles = private $ .quantiles , .args = pv )
174+
175+ if (self $ predict_type == " quantiles" ) {
176+ assert_quantiles(self , quantile_response = TRUE )
177+ quantiles = prediction $ predictions
178+ setattr(quantiles , " probs" , private $ .quantiles )
179+ setattr(quantiles , " response" , private $ .quantile_response )
180+ return (list (quantiles = quantiles ))
181+ }
182+
183+ list (response = prediction $ predictions , se = prediction $ se )
154184 }
155-
156- list (response = prediction $ predictions , se = prediction $ se )
157185 },
158186
159187 .hotstart = function (task ) {
160- model = self $ models
188+ model = self $ model $ model
161189 model $ num.trees = self $ param_set $ values $ num.trees
162- model
190+ list ( model = model )
163191 },
164192
165193 .extract_oob_error = function () {
166- self $ model $ prediction.error
194+ self $ model $ model $ prediction.error
167195 }
168196 )
169197)
0 commit comments