Skip to content

Commit bedb629

Browse files
authored
Merge pull request #264 from mlr-org/nnet
feat: allow formula as argument for nnet learner
2 parents 856d1d0 + efe7559 commit bedb629

File tree

7 files changed

+43
-13
lines changed

7 files changed

+43
-13
lines changed

NEWS.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
# mlr3learners 0.5.5
1+
# mlr3learners 0.5.6-9000
2+
3+
* Added formula argument to `nnet` learner and support feature type `"integer"`
4+
5+
# mlr3learners 0.5.6
26

37
- Enable new early stopping mechanism for xgboost.
48
- Improved documentation.

R/LearnerClassifNnet.R

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
#' - Adjusted default: 3L.
1818
#' - Reason for change: no default in `nnet()`.
1919
#'
20+
#' @section Custom mlr3 parameters:
21+
#' - `formula`: if not provided, the formula is set to `task$formula()`.
22+
#'
2023
#' @references
2124
#' `r format_bib("ripley_1996")`
2225
#'
@@ -46,14 +49,15 @@ LearnerClassifNnet = R6Class("LearnerClassifNnet",
4649
size = p_int(0L, default = 3L, tags = "train"),
4750
skip = p_lgl(default = FALSE, tags = "train"),
4851
subset = p_uty(tags = "train"),
49-
trace = p_lgl(default = TRUE, tags = "train")
52+
trace = p_lgl(default = TRUE, tags = "train"),
53+
formula = p_uty(tags = "train")
5054
)
5155
ps$values = list(size = 3L)
5256

5357
super$initialize(
5458
id = "classif.nnet",
5559
packages = c("mlr3learners", "nnet"),
56-
feature_types = c("numeric", "factor", "ordered"),
60+
feature_types = c("numeric", "factor", "ordered", "integer"),
5761
predict_types = c("prob", "response"),
5862
param_set = ps,
5963
properties = c("twoclass", "multiclass", "weights"),
@@ -68,9 +72,11 @@ LearnerClassifNnet = R6Class("LearnerClassifNnet",
6872
if ("weights" %in% task$properties) {
6973
pv = insert_named(pv, list(weights = task$weights$weight))
7074
}
71-
f = task$formula()
75+
if (is.null(pv$formula)) {
76+
pv$formula = task$formula()
77+
}
7278
data = task$data()
73-
invoke(nnet::nnet.formula, formula = f, data = data, .args = pv)
79+
invoke(nnet::nnet.formula, data = data, .args = pv)
7480
},
7581

7682
.predict = function(task) {

R/LearnerRegrNnet.R

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
#' - Adjusted default: 3L.
1818
#' - Reason for change: no default in `nnet()`.
1919
#'
20+
#' @section Custom mlr3 parameters:
21+
#' - `formula`: if not provided, the formula is set to `task$formula()`.
22+
#'
2023
#' @references
2124
#' `r format_bib("ripley_1996")`
2225
#'
@@ -46,14 +49,15 @@ LearnerRegrNnet = R6Class("LearnerRegrNnet",
4649
size = p_int(0L, default = 3L, tags = "train"),
4750
skip = p_lgl(default = FALSE, tags = "train"),
4851
subset = p_uty(tags = "train"),
49-
trace = p_lgl(default = TRUE, tags = "train")
52+
trace = p_lgl(default = TRUE, tags = "train"),
53+
formula = p_uty(tags = "train")
5054
)
5155
ps$values = list(size = 3L)
5256

5357
super$initialize(
5458
id = "regr.nnet",
5559
packages = c("mlr3learners", "nnet"),
56-
feature_types = c("numeric", "factor", "ordered"),
60+
feature_types = c("numeric", "factor", "ordered", "integer"),
5761
predict_types = c("response"),
5862
param_set = ps,
5963
properties = c("weights"),
@@ -68,10 +72,12 @@ LearnerRegrNnet = R6Class("LearnerRegrNnet",
6872
if ("weights" %in% task$properties) {
6973
pv = insert_named(pv, list(weights = task$weights$weight))
7074
}
71-
f = task$formula()
75+
if (is.null(pv$formula)) {
76+
pv$formula = task$formula()
77+
}
7278
data = task$data()
7379
# force linout = TRUE for regression
74-
invoke(nnet::nnet.formula, formula = f, data = data, linout = TRUE, .args = pv)
80+
invoke(nnet::nnet.formula, data = data, linout = TRUE, .args = pv)
7581
},
7682

7783
.predict = function(task) {

inst/paramtest/test_paramtest_classif.nnet.R

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ test_that("classif.nnet", {
77
"x", # handled via mlr3
88
"y", # handled via mlr3
99
"weights", # handled via mlr3
10-
"formula", # handled via mlr3
1110
"data", # handled via mlr3
1211
"entropy", # automatically set to TRUE if two-class task
1312
"softmax", # automatically set to TRUE if multi-class task

inst/paramtest/test_paramtest_regr.nnet.R

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ test_that("regr.nnet", {
77
"x", # handled via mlr3
88
"y", # handled via mlr3
99
"weights", # handled via mlr3
10-
"formula", # handled via mlr3
1110
"data", # handled via mlr3
1211
"linout", # automatically set to TRUE, since it's the regression learner
1312
"entropy", # mutually exclusive with linout

man/mlr_learners_classif.nnet.Rd

Lines changed: 9 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/mlr_learners_regr.nnet.Rd

Lines changed: 9 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)