-
-
Notifications
You must be signed in to change notification settings - Fork 4
Open
Description
Moving here from slds-lmu/paper_2023_survival_benchmark#9
The underlying issue is the (exceedingly rare?) case where a learner has a p_uty that's a numeric vector of length > 1; in this example, the survival SVM in mlr3extralearners, which has a param gamma.mu = c(x, y) that is tuned by creating proxy p_dbls and using a trafo to "reassemble" the param passed down to the learner.
Minimal reprex for unnesting
# Minimal table mimicking one archive row: two scalar proxy columns plus an
# `x_domain` list column whose single entry holds the reassembled tuple
# parameter `gamma.mu` (a numeric vector of length 2).
xdt <- data.table::data.table(
  gamma = 2,
  mu = 1,
  x_domain = list(
    list(gamma.mu = c(2, 1))
  )
)
# Unnesting `x_domain` is the step this issue reports as failing when the
# nested value is a vector of length > 1.
mlr3misc::unnest(xdt, "x_domain")

# Reprex for learner with problematic param
library(mlr3)
library(paradox)
library(mlr3misc)
library(mlr3tuning)
# Debug regression learner used to reproduce the issue: besides a plain
# numeric parameter `x` it declares `gamma.mu`, an untyped (`p_uty`) parameter
# that is tagged as required for training.
LearnerRegrDebugMulti = R6::R6Class(
  "LearnerRegrDebugMulti",
  inherit = LearnerRegr,
  public = list(
    #' @description
    #' Constructs a new instance of this [R6][R6::R6Class] class.
    initialize = function() {
      learner_params = ps(
        x = p_dbl(0, 1, tags = "train"),
        # declared the same way as in surv.svm
        gamma.mu = p_uty(tags = c("train", "required"))
      )
      super$initialize(
        id = "regr.debugmulti",
        param_set = learner_params,
        feature_types = c("logical", "integer", "numeric", "character", "factor", "ordered"),
        predict_types = c("response"),
        properties = "missings",
        label = "Debug Learner for Regression",
        man = "mlr3::mlr_learners_regr.debugmulti"
      )
    }
  ),
  private = list(
    # Builds the "model": summary statistics of the target plus the process id.
    .train = function(task) {
      # The training parameter values are fetched but not used afterwards.
      train_values = self$param_set$get_values(tags = "train")
      target = task$truth()
      fitted = list(
        response = mean(target),
        se = sd(target),
        pid = Sys.getpid()
      )
      set_class(fitted, "regr.debug_model")
    },
    # Repeats each stored model entry once per row of `task`.
    .predict = function(task) {
      n_obs = task$nrow
      # The prediction parameter values are fetched but not used afterwards.
      predict_values = self$param_set$get_values(tags = "predict")
      predict_type = "response"
      slot_names = mlr_reflections$learner_predict_types[["regr"]][[predict_type]]
      prediction = named_list(slot_names)
      for (slot in names(prediction)) {
        prediction[[slot]] = rep.int(self$model[[slot]], n_obs)
      }
      prediction
    }
  )
)
# Register the debug learner under the key "regr.debugmulti" and create an
# instance with a placeholder value for the tuple parameter.
mlr_learners$add("regr.debugmulti", function() LearnerRegrDebugMulti$new())
lrn_base = lrn("regr.debugmulti", gamma.mu = c(0, 0))

# The learner has the tuple param gamma.mu = c(x, y). Its two components are
# tuned separately as proxy parameters and merged back into one vector here.
reassemble_gamma_mu = function(x, param_set) {
  x$gamma.mu = c(x$gamma, x$mu)
  x[c("gamma", "mu")] = NULL
  x
}

search_space = ps(
  gamma = p_dbl(0, 1),
  mu = p_dbl(0, 1),
  .extra_trafo = reassemble_gamma_mu
)

instance = ti(
  task = tsk("mtcars"),
  learner = lrn_base,
  search_space = search_space,
  resampling = rsmp("holdout"),
  terminator = trm("evals", n_evals = 3)
)

grid_tuner = tnr("grid_search")
archive = grid_tuner$optimize(instance)
#> INFO [13:23:15.137] [bbotk] Starting to optimize 2 parameter(s) with '<OptimizerBatchGridSearch>' and '<TerminatorEvals> [n_evals=3, k=0]'
# [...truncated]

as.data.table(instance$archive)
#> Error: Tables have different number of rows (x: 3, y: 6)

# Because of this step
mlr3misc::unnest(archive, "x_domain")
#> Error: Tables have different number of rows (x: 1, y: 2)

# Created on 2024-09-19 with reprex v2.1.1
Metadata
Metadata
Assignees
Labels
No labels