Skip to content

Commit 0170eac

Browse files
committed
Merge branch 'feat/inner_valid' of github.com:mlr-org/mlr3pipelines into feat/inner_valid
2 parents 5cda404 + 2e12c95 commit 0170eac

File tree

8 files changed

+34
-51
lines changed

8 files changed

+34
-51
lines changed

NAMESPACE

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ S3method(as_pipeop,Filter)
1111
S3method(as_pipeop,Learner)
1212
S3method(as_pipeop,PipeOp)
1313
S3method(as_pipeop,default)
14-
S3method(disable_internal_tuning,GraphLearner)
1514
S3method(marshal_model,Multiplicity)
1615
S3method(marshal_model,graph_learner_model)
1716
S3method(marshal_model,pipeop_impute_learner_state)

R/GraphLearner.R

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,13 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
272272
if (!some_pipeops_validate) {
273273
lg$warn("GraphLearner '%s' specifies a validation set, but none of its Learners use it.", self$id)
274274
}
275+
} else {
276+
# otherwise the pipeops will preprocess this unnecessarily
277+
if (!is.null(task$internal_valid_task)) {
278+
prev_itv = task$internal_valid_task
279+
on.exit({task$internal_valid_task = prev_itv}, add = TRUE)
280+
task$internal_valid_task = NULL
281+
}
275282
}
276283

277284
on.exit({self$graph$state = NULL})
@@ -417,21 +424,6 @@ set_validate.GraphLearner = function(learner, validate, ids = NULL, args = list(
417424
invisible(learner)
418425
}
419426

420-
421-
#' @export
422-
disable_internal_tuning.GraphLearner = function(learner, ids, ...) {
423-
pvs = learner$param_set$values
424-
on.exit({learner$param_set$values = pvs}, add = TRUE)
425-
if (length(ids)) {
426-
walk(learner_wrapping_pipeops(learner), function(po) {
427-
disable_internal_tuning(po$learner, ids = po$param_set$ids()[sprintf("%s.%s", po$id, po$param_set$ids()) %in% ids])
428-
})
429-
}
430-
on.exit()
431-
invisible(learner)
432-
}
433-
434-
435427
#' @export
436428
marshal_model.graph_learner_model = function(model, inplace = FALSE, ...) {
437429
xm = map(.x = model, .f = marshal_model, inplace = inplace, ...)

R/pipeline_branch.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,10 +91,9 @@ pipeline_branch = function(graphs, prefix_branchops = "", prefix_paths = FALSE)
9191
pmap(list(
9292
src_id = branch_id, dst_id = gin$op.id,
9393
src_channel = branch_chan, dst_channel = gin$channel.name),
94-
graph$add_edge)
94+
graph$add_edge)
9595
})
9696
graph
9797
}
9898

9999
mlr_graphs$add("branch", pipeline_branch)
100-

inst/testthat/helper_functions.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,7 @@ expect_datapreproc_pipeop_class = function(poclass, constargs = list(), task,
437437
expect_true(task$nrow >= 5)
438438

439439
# overlap between use and test rows
440-
tasktrain$divide(tasktrain$row_roles$use[seq(n_use - 2, n_use)], remove = FALSE)
440+
tasktrain$divide(ids = tasktrain$row_roles$use[seq(n_use - 2, n_use)], remove = FALSE)
441441
tasktrain$row_roles$use = tasktrain$row_roles$use[seq(1, n_use - 2)]
442442

443443
taskpredict = tasktrain$clone(deep = TRUE)

man/mlr_learners_avg.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/mlr_pipeops_tunethreshold.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_GraphLearner.R

Lines changed: 22 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ context("GraphLearner")
22

33
test_that("basic graphlearner tests", {
44
skip_if_not_installed("rpart")
5-
skip_on_cran() # takes too long
5+
skip_on_cran() # takes too long
66
task = mlr_tasks$get("iris")
77

88
lrn = mlr_learners$get("classif.rpart")
@@ -40,8 +40,8 @@ test_that("basic graphlearner tests", {
4040
expect_true(run_experiment(task, glrn)$ok)
4141
glrn2$train(task)
4242
glrn2_clone$state = glrn2$state
43-
# glrn2_clone$state$log = glrn2_clone$state$log$clone(deep = TRUE) # FIXME: this can go when mlr-org/mlr3#343 is fixed
44-
# glrn2_clone$state$model$classif.rpart$log = glrn2_clone$state$model$classif.rpart$log$clone(deep = TRUE) # FIXME: this can go when mlr-org/mlr3#343 is fixed
43+
# glrn2_clone$state$log = glrn2_clone$state$log$clone(deep = TRUE) # FIXME: this can go when mlr-org/mlr3#343 is fixed
44+
# glrn2_clone$state$model$classif.rpart$log = glrn2_clone$state$model$classif.rpart$log$clone(deep = TRUE) # FIXME: this can go when mlr-org/mlr3#343 is fixed
4545
expect_deep_clone(glrn2_clone, glrn2$clone(deep = TRUE))
4646
expect_prediction_classif({
4747
graphpred2 = glrn2$predict(task)
@@ -109,7 +109,7 @@ test_that("GraphLearner clone_graph FALSE", {
109109
# check that the GraphLearner predicts what we expect
110110
expect_true(isTRUE(all.equal(gl$predict(tsk("iris")), expected_prediction)))
111111

112-
expect_false(gr1$is_trained) # predicting with GraphLearner resets Graph state
112+
expect_false(gr1$is_trained) # predicting with GraphLearner resets Graph state
113113

114114
expect_identical(gl$graph, gr1)
115115

@@ -177,7 +177,7 @@ test_that("graphlearner parameters behave as they should", {
177177

178178
test_that("graphlearner type inference", {
179179
skip_if_not_installed("rpart")
180-
skip_on_cran() # takes too long
180+
skip_on_cran() # takes too long
181181
# default: classif
182182
lrn = GraphLearner$new(mlr_pipeops$get("nop"))
183183
expect_equal(lrn$task_type, "classif")
@@ -246,15 +246,15 @@ test_that("graphlearner type inference", {
246246

247247
test_that("graphlearner type inference - branched", {
248248
skip_if_not_installed("rpart")
249-
skip_on_cran() # takes too long
249+
skip_on_cran() # takes too long
250250

251251
# default: classif
252252

253253
lrn = GraphLearner$new(gunion(list(
254-
mlr_pipeops$get(id = "l1", "learner", lrn("classif.rpart")),
255-
po("nop") %>>% mlr_pipeops$get(id = "l2", "learner", lrn("classif.rpart"))
254+
mlr_pipeops$get(id = "l1", "learner", lrn("classif.rpart")),
255+
po("nop") %>>% mlr_pipeops$get(id = "l2", "learner", lrn("classif.rpart"))
256256

257-
)) %>>%
257+
)) %>>%
258258
po("classifavg") %>>%
259259
po(id = "n2", "nop"))
260260
expect_equal(lrn$task_type, "classif")
@@ -281,9 +281,9 @@ test_that("graphlearner type inference - branched", {
281281

282282
# inference when multiple input, but one is a Task
283283
lrn = GraphLearner$new(gunion(list(
284-
mlr_pipeops$get(id = "l1", "learner", lrn("regr.rpart")),
285-
po("nop") %>>% mlr_pipeops$get(id = "l2", "learner", lrn("regr.rpart"))
286-
)) %>>%
284+
mlr_pipeops$get(id = "l1", "learner", lrn("regr.rpart")),
285+
po("nop") %>>% mlr_pipeops$get(id = "l2", "learner", lrn("regr.rpart"))
286+
)) %>>%
287287
po("regravg") %>>%
288288
po(id = "n2", "nop"))
289289
expect_equal(lrn$task_type, "regr")
@@ -311,7 +311,7 @@ test_that("graphlearner type inference - branched", {
311311

312312
test_that("graphlearner predict type inference", {
313313
skip_if_not_installed("rpart")
314-
skip_on_cran() # takes too long
314+
skip_on_cran() # takes too long
315315
# Getter:
316316

317317
# Classification
@@ -403,7 +403,9 @@ test_that("graphlearner predict type inference", {
403403
expect_equal(lrn$graph$pipeops[[lrr$id]]$predict_type, "prob")
404404

405405
# Errors:
406-
expect_error({lrrp = po(lrn("classif.featureless", predict_type = "se"))})
406+
expect_error({
407+
lrrp = po(lrn("classif.featureless", predict_type = "se"))
408+
})
407409
})
408410

409411

@@ -439,7 +441,6 @@ test_that("GraphLearner model", {
439441

440442
expect_equal(lr$graph_model$pipeops$classif.rpart$learner_model$importance(), imp)
441443

442-
443444
})
444445

445446
test_that("predict() function for Graph", {
@@ -468,7 +469,6 @@ test_that("predict() function for Graph", {
468469
p1$response
469470
)
470471

471-
472472
})
473473

474474
test_that("base_learner() works", {
@@ -558,20 +558,20 @@ test_that("GraphLearner hashes", {
558558
expect_string(all.equal(po("copy", 2)$hash, po("copy", 3)$hash), "mismatch")
559559

560560

561-
lr1 <- lrn("classif.rpart")
562-
lr2 <- lrn("classif.rpart", fallback = lrn("classif.rpart"))
561+
lr1 = lrn("classif.rpart")
562+
lr2 = lrn("classif.rpart", fallback = lrn("classif.rpart"))
563563

564564
expect_string(all.equal(lr1$hash, lr2$hash), "mismatch")
565565
expect_string(all.equal(lr1$phash, lr2$phash), "mismatch")
566566

567-
lr1 <- as_learner(as_pipeop(lr1))
568-
lr2 <- as_learner(as_pipeop(lr2))
567+
lr1 = as_learner(as_pipeop(lr1))
568+
lr2 = as_learner(as_pipeop(lr2))
569569

570570
expect_string(all.equal(lr1$hash, lr2$hash), "mismatch")
571571
expect_string(all.equal(lr1$phash, lr2$phash), "mismatch")
572572

573-
lr1 <- as_learner(as_pipeop(lr1))
574-
lr2 <- as_learner(as_pipeop(lr2))
573+
lr1 = as_learner(as_pipeop(lr1))
574+
lr2 = as_learner(as_pipeop(lr2))
575575

576576
expect_string(all.equal(lr1$hash, lr2$hash), "mismatch")
577577
expect_string(all.equal(lr1$phash, lr2$phash), "mismatch")
@@ -625,13 +625,6 @@ test_that("internal_tuned_values", {
625625
expect_equal(names(glrn2$internal_tuned_values), "classif.debug.iter")
626626
})
627627

628-
test_that("disable_internal_tuning", {
629-
glrn = as_learner(as_pipeop(lrn("classif.debug", iter = 100, early_stopping = TRUE)))
630-
disable_internal_tuning(glrn, "classif.debug.iter")
631-
expect_false(glrn$graph$pipeops$classif.debug$param_set$values$early_stopping)
632-
expect_error(disable_internal_tuning(glrn, "classif.debug.abc"), "subset of")
633-
})
634-
635628
test_that("set_validate", {
636629
glrn = as_learner(as_pipeop(lrn("classif.debug", validate = 0.3)))
637630
set_validate(glrn, "test")

tests/testthat/test_pipeop_impute.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,7 @@ test_that("More tests for Integers", {
406406
test_that("impute, test rows and affect_columns", {
407407
po_impute = po("imputeconstant", affect_columns = selector_name("insulin"), constant = 2)
408408
task = tsk("pima")
409-
task$divide(1:30)
409+
task$divide(ids = 1:30)
410410
outtrain = po_impute$train(list(task))[[1L]]
411411
outpredict = po_impute$predict(list(task$internal_valid_task))[[1L]]
412412
expect_true(isTRUE(all.equal(outtrain$internal_valid_task$data(), outpredict$data())))

0 commit comments

Comments
 (0)