Skip to content

Commit 994410f

Browse files
authored
fix: remove unsupported predict parameters from xgboost (#356)
* fix: remove unsupported predict parameters from xgboost * ... * ... * ... * ...
1 parent 06eabf3 commit 994410f

28 files changed

+47
-19
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ import(paradox)
3434
importFrom(R6,R6Class)
3535
importFrom(mlr3,LearnerClassif)
3636
importFrom(mlr3,LearnerRegr)
37+
importFrom(mlr3,assert_quantiles)
3738
importFrom(mlr3,assert_validate)
3839
importFrom(mlr3,mlr_learners)
3940
importFrom(stats,predict)

R/LearnerClassifXgboost.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,10 +129,6 @@ LearnerClassifXgboost = R6Class("LearnerClassifXgboost",
129129
num_parallel_tree = p_int(1L, default = 1L, tags = c("train", "control")),
130130
objective = p_uty(default = "binary:logistic", tags = c("train", "predict", "control")),
131131
one_drop = p_lgl(default = FALSE, tags = "train", depends = quote(booster == "dart")),
132-
outputmargin = p_lgl(default = FALSE, tags = "predict"),
133-
predcontrib = p_lgl(default = FALSE, tags = "predict"),
134-
predinteraction = p_lgl(default = FALSE, tags = "predict"),
135-
predleaf = p_lgl(default = FALSE, tags = "predict"),
136132
print_every_n = p_int(1L, default = 1L, tags = "train", depends = quote(verbose == 1L)),
137133
process_type = p_fct(c("default", "update"), default = "default", tags = "train"),
138134
rate_drop = p_dbl(0, 1, default = 0, tags = "train", depends = quote(booster == "dart")),
@@ -222,6 +218,10 @@ LearnerClassifXgboost = R6Class("LearnerClassifXgboost",
222218
lvls = task$class_names
223219
nlvls = length(lvls)
224220

221+
if (isTRUE(pv$predcontrib) || isTRUE(pv$predinteraction) || isTRUE(pv$predleaf)) {
222+
warningf("Predicting contributions, interactions, or leaf values with $predict() is not supported. ")
223+
}
224+
225225
if (is.null(pv$objective)) {
226226
pv$objective = if (nlvls == 2L) "binary:logistic" else "multi:softprob"
227227
}

R/LearnerRegrXgboost.R

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,6 @@ LearnerRegrXgboost = R6Class("LearnerRegrXgboost",
109109
num_parallel_tree = p_int(1L, default = 1L, tags = "train"),
110110
objective = p_uty(default = "reg:squarederror", tags = c("train", "predict")),
111111
one_drop = p_lgl(default = FALSE, tags = "train", depends = quote(booster == "dart")),
112-
outputmargin = p_lgl(default = FALSE, tags = "predict"),
113-
predcontrib = p_lgl(default = FALSE, tags = "predict"),
114-
predinteraction = p_lgl(default = FALSE, tags = "predict"),
115-
predleaf = p_lgl(default = FALSE, tags = "predict"),
116112
print_every_n = p_int(1L, default = 1L, tags = "train", depends = quote(verbose == 1L)),
117113
process_type = p_fct(c("default", "update"), default = "default", tags = "train"),
118114
rate_drop = p_dbl(0, 1, default = 0, tags = "train", depends = quote(booster == "dart")),

R/zzz.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
#' @import mlr3misc
44
#' @import checkmate
55
#' @importFrom R6 R6Class
6-
#' @importFrom mlr3 mlr_learners LearnerClassif LearnerRegr assert_validate
6+
#' @importFrom mlr3 mlr_learners LearnerClassif LearnerRegr assert_validate assert_quantiles
77
#' @importFrom stats predict reformulate
88
#' @importFrom utils packageVersion
99
#'

inst/paramtest/test_paramtest_classif.xgboost.R

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,11 @@ test_that("predict classif.xgboost", {
6060
exclude = c(
6161
"object", # handled by mlr3
6262
"newdata", # handled by mlr3o
63-
"objective" # defined in xgboost::xgboost and already in param set
63+
"objective", # defined in xgboost::xgboost and already in param set
64+
"outputmargin", # not supported
65+
"predcontrib", # not supported
66+
"predinteraction", # not supported
67+
"predleaf" # not supported
6468
)
6569

6670
ParamTest = run_paramtest(learner, fun, exclude, tag = "predict")

inst/paramtest/test_paramtest_regr.xgboost.R

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,11 @@ test_that("predict regr.xgboost", {
6060
exclude = c(
6161
"object", # handled by mlr3
6262
"newdata", # handled by mlr3
63-
"objective" # defined in xgboost::xgboost and already in param set
63+
"objective", # defined in xgboost::xgboost and already in param set
64+
"outputmargin", # not supported
65+
"predcontrib", # not supported
66+
"predinteraction", # not supported
67+
"predleaf" # not supported
6468
)
6569

6670
ParamTest = run_paramtest(learner, fun, exclude, tag = "predict")

man-roxygen/note_xgboost.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,7 @@
22
#' To compute on GPUs, you first need to compile \CRANpkg{xgboost} yourself and link
33
#' against CUDA.
44
#' See \url{https://xgboost.readthedocs.io/en/stable/build.html#building-with-gpu-support}.
5+
#'
6+
#' The `outputmargin`, `predcontrib`, `predinteraction`, and `predleaf` parameters are not supported.
7+
#' You can still call e.g. `predict(learner$model, newdata = newdata, outputmargin = TRUE)` to get these predictions.
8+

man/mlr_learners_classif.cv_glmnet.Rd

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/mlr_learners_classif.glmnet.Rd

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/mlr_learners_classif.kknn.Rd

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)