Skip to content

Commit 7c86ce2

Browse files
committed
Add intercept option
1 parent ac4d14e commit 7c86ce2

File tree

4 files changed

+11
-2
lines changed

4 files changed

+11
-2
lines changed

R/commonMachineLearningClassification.R

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@
3232
"mutationMethod", "survivalMethod", "elitismProportion", "candidates", # Neural network
3333
"noOfTrees", "maxTrees", "baggingFraction", "noOfPredictors", "numberOfPredictors", # Random forest
3434
"complexityParameter", "degree", "gamma", "cost", "tolerance", "epsilon", "maxCost", # Support vector machine
35-
"smoothingParameter" # Naive Bayes
35+
"smoothingParameter", # Naive Bayes
36+
"intercept" # Logistic
3637
)
3738
if (includeSaveOptions) {
3839
opt <- c(opt, "saveModel", "savePath")

R/mlClassificationLogistic.R

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,12 @@ mlClassificationLogistic <- function(jaspResults, dataset, options, ...) {
7878
testIndicatorColumn[trainingIndex] <- 0
7979
# Just create a train and a test set (no optimization)
8080
testSet <- dataset[-trainingIndex, ]
81+
# Create the formula
82+
if (options[["intercept"]]) {
83+
formula <- formula(paste(options[["target"]], "~ 1 + ", paste(options[["predictors"]], collapse = " + ")))
84+
} else {
85+
formula <- formula(paste(options[["target"]], "~ 0 + ", paste(options[["predictors"]], collapse = " + ")))
86+
}
8187
if (nlevels(trainingSet[[options[["target"]]]]) == 2) {
8288
family = "binomial"
8389
trainingFit <- stats::glm(formula, data = trainingSet, family = family)

inst/qml/mlClassificationLogistic.qml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import JASP.Widgets 1.0
2424
import "./common/ui" as UI
2525
import "./common/tables" as TAB
2626
import "./common/figures" as FIG
27+
import "./common/analyses/regularized" as REGU
2728

2829
Form
2930
{
@@ -63,6 +64,7 @@ Form
6364
{
6465
title: qsTr("Algorithmic Settings")
6566

67+
REGU.Intercept { }
6668
UI.ScaleVariables { }
6769
UI.SetSeed { }
6870
}

tests/testthat/helper-ml.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ mlOptions <- function(analysis) {
2727
files <- c(files, list.files(testthat::test_path("..", "..", "inst", "qml", "common", "analyses", "randomforest"), full.names = TRUE))
2828
} else if (analysis %in% c("mlClassificationSvm", "mlRegressionSvm")) {
2929
files <- c(files, list.files(testthat::test_path("..", "..", "inst", "qml", "common", "analyses", "svm"), full.names = TRUE))
30-
} else if (analysis %in% c("mlRegressionLinear", "mlRegressionRegularized")) {
30+
} else if (analysis %in% c("mlClassificationLogistic", "mlRegressionLinear", "mlRegressionRegularized")) {
3131
files <- c(files, list.files(testthat::test_path("..", "..", "inst", "qml", "common", "analyses", "regularized"), full.names = TRUE))
3232
}
3333
options <- lapply(files, jaspTools:::readQML) |>

0 commit comments

Comments
 (0)