Skip to content

Commit e32f163

Browse files
committed
Merge branch 'main' into release
2 parents 280fb34 + 4ef052a commit e32f163

11 files changed

+28
-24
lines changed

R/HotstartStack.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ HotstartStack = R6Class("HotstartStack",
138138
hotstart_id = learner$param_set$ids(tags = "hotstart")
139139

140140
set(self$stack, j = "cost", value = NA_real_)
141-
cost = self$stack[list(.task_hash, .learner_hash), "cost" := map_dbl(get("start_learner"), function(l) calculate_cost(l, learner, hotstart_id)) , on = c("task_hash", "learner_hash")
141+
cost = self$stack[list(.task_hash, .learner_hash), "cost" := map_dbl(get("start_learner"), function(l) calculate_cost(l, learner, hotstart_id)), on = c("task_hash", "learner_hash")
142142
][, get("cost")]
143143
self$stack[, "cost" := NULL]
144144
cost

R/Learner.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ Learner = R6Class("Learner",
275275

276276
pred_typs = replace(self$predict_types, self$predict_types == self$predict_type, paste0("[", self$predict_type, "]"))
277277
encapsulation = self$encapsulation[[1L]]
278-
fallback = if (encapsulation != 'none') class(self$fallback)[[1L]] else "-"
278+
fallback = if (encapsulation != "none") class(self$fallback)[[1L]] else "-"
279279

280280
cat_cli({
281281
cli_li("Predict Types: {pred_typs}")

R/LearnerClassifDebug.R

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -196,14 +196,11 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif,
196196
}
197197

198198
model = list(
199-
response = as.character(sample(task$truth(), 1L, prob = private$.get_weights(task))),
200-
pid = Sys.getpid(),
201-
id = UUIDgenerate(),
202-
random_number = sample(100000, 1),
203-
iter = if (isTRUE(pv$early_stopping))
204-
sample(pv$iter %??% 1L, 1L)
205-
else
206-
pv$iter %??% 1L
199+
response = as.character(sample(task$truth(), 1L, prob = private$.get_weights(task))),
200+
pid = Sys.getpid(),
201+
id = UUIDgenerate(),
202+
random_number = sample(100000, 1),
203+
iter = if (isTRUE(pv$early_stopping)) sample(pv$iter %??% 1L, 1L) else pv$iter %??% 1L
207204
)
208205

209206
if (!is.null(valid_truth)) {

R/LearnerRegrDebug.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ LearnerRegrDebug = R6Class("LearnerRegrDebug", inherit = LearnerRegr,
6464
#' @return Named `numeric()`.
6565
importance = function() {
6666
if (is.null(self$model)) {
67-
error_input("No model stored")
67+
error_input("No model stored")
6868
}
6969
fns = self$state$feature_names
7070
set_names(rep(0, length(fns)), fns)

R/MeasureRegrPinball.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828
MeasureRegrPinball = R6Class("MeasureRegrPinball",
2929
inherit = MeasureRegr,
3030
public = list(
31-
#' @description
32-
#' Creates a new instance of this [R6][R6::R6Class] class.
31+
#' @description
32+
#' Creates a new instance of this [R6][R6::R6Class] class.
3333
initialize = function(alpha = 0.5) {
3434
param_set = ps(alpha = p_dbl(lower = 0, upper = 1))
3535
param_set$set_values(alpha = alpha)

R/MeasureRegrRQR.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ MeasureRegrRQR = R6Class("MeasureRQR",
7878
truth = prediction$truth,
7979
response = prediction$data$quantiles[, which(probs == alpha)],
8080
alpha = alpha
81-
)
81+
)
8282
)
8383

8484
denominator = sum(
@@ -89,8 +89,8 @@ MeasureRegrRQR = R6Class("MeasureRQR",
8989
)
9090
)
9191

92-
1 - (numerator / denominator)
93-
}
92+
1 - (numerator / denominator)
93+
}
9494
)
9595
)
9696

R/PredictionClassif.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ PredictionClassif = R6Class("PredictionClassif", inherit = Prediction,
104104
weights = NULL,
105105
check = TRUE,
106106
extra = NULL
107-
) {
107+
) {
108108

109109
pdata = new_prediction_data(
110110
list(row_ids = row_ids, truth = truth, response = response, prob = prob, weights = weights, extra = extra),

R/PredictionDataRegr.R

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,11 +110,18 @@ c.PredictionDataRegr = function(..., keep_duplicates = TRUE) { # nolint
110110
error_input("Cannot rbind predictions: Some predictions have extra data, others do not")
111111
}
112112

113-
elems = c("row_ids", "truth", intersect(predict_types[[1L]], c("response", "se")), if ("weights" %chin% names(dots[[1L]])) "weights")
113+
nn = names(dots[[1L]])
114+
elems = c("row_ids", "truth", intersect(predict_types[[1L]], c("response", "se")), if ("weights" %chin% nn) "weights")
114115
tab = map_dtr(dots, function(x) x[elems], .fill = FALSE)
115-
quantiles = do.call(rbind, map(dots, "quantiles"))
116116

117-
extra = if ("extra" %chin% names(dots[[1L]])) {
117+
quantiles = if ("quantiles" %chin% nn) {
118+
quantiles = map(dots, "quantiles")
119+
attrs = attributes(quantiles[[1L]])
120+
quantiles = do.call(rbind, quantiles)
121+
setattr(quantiles, "probs", attrs$props)
122+
setattr(quantiles, "response", attrs$response)
123+
}
124+
extra = if ("extra" %chin% nn) {
118125
rbindlist(map(dots, "extra"), fill = TRUE, use.names = TRUE)
119126
}
120127

R/PredictionRegr.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ PredictionRegr = R6Class("PredictionRegr", inherit = Prediction,
6565
weights = NULL,
6666
check = TRUE,
6767
extra = NULL
68-
) {
68+
) {
6969
pdata = new_prediction_data(
7070
list(row_ids = row_ids, truth = truth, response = response, se = se, quantiles = quantiles, distr = distr, weights = weights, extra = extra),
7171
task_type = "regr"

R/as_result_data.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ as_result_data = function(
5050
learner_states = NULL,
5151
data_extra = NULL,
5252
store_backends = TRUE
53-
) {
53+
) {
5454
assert_task(task)
5555
assert_learners(learners, task = task)
5656
assert_resampling(resampling, instantiated = TRUE)

0 commit comments

Comments
 (0)