Skip to content

Commit 280fb34

Browse files
committed
primary iters
1 parent 02a82af commit 280fb34

File tree

1 file changed

+15
-24
lines changed

1 file changed

+15
-24
lines changed

tests/testthat/test_Measure.R

Lines changed: 15 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -275,38 +275,29 @@ test_that("measure weights", {
275275

276276
})
277277

278-
279-
280-
281278
test_that("primary iters are respected", {
282279
task = tsk("sonar")
283-
resampling = rsmp("cv")$instantiate(task)
284-
train_sets = map(1:10, function(i) resampling$train_set(i))
285-
test_sets = map(1:10, function(i) resampling$train_set(i))
286-
r1 = rsmp("custom")$instantiate(task, train_sets = train_sets, test_sets = test_sets)
287-
get_private(r1, ".primary_iters") = 1:2
288-
r2 = rsmp("custom")$instantiate(task, train_sets = train_sets[1:2], test_sets = test_sets[1:2])
289-
r3 = rsmp("custom")$instantiate(task, train_sets = train_sets, test_sets = test_sets)
290-
291280
learner = lrn("classif.rpart", predict_type = "prob")
281+
resampling = rsmp("cv", folds = 10)
282+
resampling$instantiate(task)
283+
get_private(resampling, ".primary_iters") = 1:2
292284

293-
rr1 = resample(task, learner, r1, store_models = TRUE)
294-
rr2 = resample(task, learner, r2, store_models = TRUE)
295-
rr3 = resample(task, learner, r3, store_models = TRUE)
296-
285+
rr = resample(task, learner, resampling)
297286
m = msr("classif.acc")
298287
m$average = "macro"
299-
expect_equal(rr1$aggregate(), rr2$aggregate())
300-
m$average = "micro"
301-
expect_equal(rr1$aggregate(), rr2$aggregate())
288+
scores = rr$score(m)$classif.acc
289+
290+
# macro aggregation
291+
expect_equal(unname(rr$aggregate(m)), mean(scores[1:2]))
292+
expect_true(unname(rr$aggregate(m)) != mean(scores))
293+
294+
# micro aggregation
295+
pred_micro = do.call(c, rr$predictions()[1:2])
296+
scores = pred_micro$score(m)
297+
expect_equal(unname(m$score(pred_micro)), unname(scores))
302298

303299
jaccard = msr("sim.jaccard")
304-
expect_error(rr1$aggregate(jaccard), "primary_iters")
305-
expect_no_error(rr2$aggregate(jaccard))
306-
jaccard$properties = c(jaccard$properties, "primary_iters")
307-
x1 = rr1$aggregate(jaccard)
308-
x2 = rr3$aggregate(jaccard)
309-
expect_equal(x1, x2)
300+
expect_error(rr$aggregate(jaccard), "primary_iters")
310301
})
311302

312303
test_that("no predict_sets required (#1094)", {

0 commit comments

Comments
 (0)