Skip to content

Commit 2e24f7f

Browse files
authored
fix: weights_measure and stratum (#1406)
* fix: weights_measure and stratum * ...
1 parent 361a37c commit 2e24f7f

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

R/as_prediction_data.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ as_prediction_data.list = function(x, task, row_ids = task$row_ids, check = TRUE
4545
}
4646

4747
if ("weights_measure" %chin% task$properties) {
48-
x$weights = task$weights_measure[list(row_ids), "weight"][[1L]]
48+
x$weights = task$weights_measure[list(row_id = row_ids), on = "row_id", "weight"][[1L]]
4949
}
5050

5151
task = if (task$task_type == "unsupervised") train_task else task

tests/testthat/test_Task.R

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1020,3 +1020,13 @@ test_that("materialize_view works with duplicates", {
10201020
task2$materialize_view()
10211021
expect_equal(task$data(), task2$data())
10221022
})
1023+
1024+
test_that("weights_measure + stratum works during resampling (#1405)", {
1025+
data = cbind(datasets::iris, data.frame(w = rep(c(1, 10, 100), each = 50)))
1026+
# 150 rows works, but 151 rows fails
1027+
data = data[c(seq(150), 1), ]
1028+
task = TaskClassif$new("iris_weights_measure", as_data_backend(data, target = "Species"), target = "Species")
1029+
task$set_col_roles("w", "weights_measure")
1030+
task$set_col_roles("Species", roles = c("target", "stratum"))
1031+
expect_resample_result(resample(task, lrn("classif.featureless"), rsmp("cv", folds = 3)))
1032+
})

0 commit comments

Comments
 (0)