Skip to content

Commit 93d3d3e

Browse files
PhilippProlarskotthoff
authored andcommitted
Ranger case weights (#2418)
* add case weights and repair minprop * solve case.weights problem
1 parent bc7f986 commit 93d3d3e

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

R/RLearner_classif_ranger.R

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ makeRLearner.classif.ranger = function() {
3636
}
3737

3838
#' @export
39-
trainLearner.classif.ranger = function(.learner, .task, .subset, .weights = NULL, mtry, mtry.perc, min.node.size, ...) {
39+
trainLearner.classif.ranger = function(.learner, .task, .subset, .weights = NULL, mtry, mtry.perc, min.node.size, case.weights, ...) {
4040
tn = getTaskTargetNames(.task)
4141
if (missing(mtry)) {
4242
if (missing(mtry.perc)) {
@@ -52,8 +52,11 @@ trainLearner.classif.ranger = function(.learner, .task, .subset, .weights = NULL
5252
min.node.size = 1
5353
}
5454
}
55+
if (missing(case.weights)) {
56+
case.weights = .weights
57+
}
5558
ranger::ranger(formula = NULL, dependent.variable = tn, data = getTaskData(.task, .subset),
56-
probability = (.learner$predict.type == "prob"), case.weights = .weights, mtry = mtry, min.node.size = min.node.size, ...)
59+
probability = (.learner$predict.type == "prob"), case.weights = case.weights, mtry = mtry, min.node.size = min.node.size, ...)
5760
}
5861

5962
#' @export

R/RLearner_regr_ranger.R

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ makeRLearner.regr.ranger = function() {
3838
}
3939

4040
#' @export
41-
trainLearner.regr.ranger = function(.learner, .task, .subset, .weights = NULL, keep.inbag = NULL, mtry, mtry.perc, ...) {
41+
trainLearner.regr.ranger = function(.learner, .task, .subset, .weights = NULL, keep.inbag = NULL, mtry, mtry.perc, case.weights, ...) {
4242
tn = getTaskTargetNames(.task)
4343
if (missing(mtry)) {
4444
if (missing(mtry.perc)) {
@@ -47,10 +47,13 @@ trainLearner.regr.ranger = function(.learner, .task, .subset, .weights = NULL, k
4747
mtry = max(1, floor(mtry.perc * getTaskNFeats(.task)))
4848
}
4949
}
50+
if (missing(case.weights)) {
51+
case.weights = .weights
52+
}
5053
keep.inbag = if (is.null(keep.inbag)) FALSE else keep.inbag
5154
keep.inbag = if (.learner$predict.type == "se") TRUE else keep.inbag
5255
ranger::ranger(formula = NULL, dependent.variable = tn, data = getTaskData(.task, .subset),
53-
case.weights = .weights, keep.inbag = keep.inbag, mtry = mtry, ...)
56+
case.weights = case.weights, keep.inbag = keep.inbag, mtry = mtry, ...)
5457
}
5558

5659
#' @export

0 commit comments

Comments
 (0)