Commit a83a279

First implementation
Parent: d3a5ebd

7 files changed, 305 additions and 5 deletions

DESCRIPTION

Lines changed: 2 additions & 1 deletion
@@ -41,7 +41,8 @@ Imports:
     rpart (>= 4.1.16),
     ROCR,
     Rtsne,
-    signal
+    signal,
+    VGAM
 Suggests:
     testthat
 Remotes:

NAMESPACE

Lines changed: 1 addition & 0 deletions
@@ -51,6 +51,7 @@ export(mlClassificationBoosting)
 export(mlClassificationDecisionTree)
 export(mlClassificationKnn)
 export(mlClassificationLda)
+export(mlClassificationLogistic)
 export(mlClassificationNaiveBayes)
 export(mlClassificationNeuralNetwork)
 export(mlClassificationRandomForest)

R/commonMachineLearningClassification.R

Lines changed: 16 additions & 4 deletions
@@ -62,7 +62,7 @@
   if (type == "lda" || type == "randomForest" || type == "boosting") {
     # Require at least 2 features
     ready <- length(options[["predictors"]][options[["predictors"]] != ""]) >= 2 && options[["target"]] != ""
-  } else if (type == "knn" || type == "neuralnet" || type == "rpart" || type == "svm" || type == "naivebayes") {
+  } else if (type == "knn" || type == "neuralnet" || type == "rpart" || type == "svm" || type == "naivebayes" || type == "logistic") {
     # Require at least 1 features
     ready <- length(options[["predictors"]][options[["predictors"]] != ""]) >= 1 && options[["target"]] != ""
   }
@@ -93,7 +93,8 @@
       "neuralnet" = .neuralnetClassification(dataset, options, jaspResults),
       "rpart" = .decisionTreeClassification(dataset, options, jaspResults),
       "svm" = .svmClassification(dataset, options, jaspResults),
-      "naivebayes" = .naiveBayesClassification(dataset, options, jaspResults)
+      "naivebayes" = .naiveBayesClassification(dataset, options, jaspResults),
+      "logistic" = .logisticRegressionClassification(dataset, options, jaspResults)
     )
   })
   if (isTryError(p)) { # Fail gracefully
@@ -116,7 +117,8 @@
     "neuralnet" = gettext("Neural Network Classification"),
     "rpart" = gettext("Decision Tree Classification"),
     "svm" = gettext("Support Vector Machine Classification"),
-    "naivebayes" = gettext("Naive Bayes Classification")
+    "naivebayes" = gettext("Naive Bayes Classification"),
+    "logistic" = gettext("Logistic / Multinomial Regression")
   )
   tableTitle <- gettextf("Model Summary: %1$s", title)
   table <- createJaspTable(tableTitle)
@@ -147,6 +149,8 @@
     table$addColumnInfo(name = "vectors", title = gettext("Support Vectors"), type = "integer")
   } else if (type == "naivebayes") {
     table$addColumnInfo(name = "smoothing", title = gettext("Smoothing"), type = "number")
+  } else if (type == "logistic") {
+    table$addColumnInfo(name = "family", title = gettext("Family"), type = "string")
   }
   # Add common columns
   table$addColumnInfo(name = "nTrain", title = gettext("n(Train)"), type = "integer")
@@ -164,7 +168,7 @@
   }
   # If no analysis is run, specify the required variables in a footnote
   if (!ready) {
-    table$addFootnote(gettextf("Please provide a target variable and at least %i feature variable(s).", if (type == "knn" || type == "neuralnet" || type == "rpart" || type == "svm") 1L else 2L))
+    table$addFootnote(gettextf("Please provide a target variable and at least %i feature variable(s).", if (type == "knn" || type == "neuralnet" || type == "rpart" || type == "svm" || type == "logistic") 1L else 2L))
   }
   if (options[["savePath"]] != "") {
     validNames <- (length(grep(" ", decodeColNames(colnames(dataset)))) == 0) && (length(grep("_", decodeColNames(colnames(dataset)))) == 0)
@@ -312,6 +316,14 @@
       testAcc = classificationResult[["testAcc"]]
     )
     table$addRows(row)
+  } else if (type == "logistic") {
+    row <- data.frame(
+      family = classificationResult[["family"]],
+      nTrain = nTrain,
+      nTest = classificationResult[["ntest"]],
+      testAcc = classificationResult[["testAcc"]]
+    )
+    table$addRows(row)
   }
   # Save the applied model if requested
   if (options[["saveModel"]] && options[["savePath"]] != "") {

R/mlClassificationLogistic.R

Lines changed: 132 additions & 0 deletions
@@ -0,0 +1,132 @@
+#
+# Copyright (C) 2013-2021 University of Amsterdam
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 2 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+#
+
+mlClassificationLogistic <- function(jaspResults, dataset, options, ...) {
+
+  # Preparatory work
+  dataset <- .mlClassificationReadData(dataset, options)
+  .mlClassificationErrorHandling(dataset, options, type = "logistic")
+
+  # Check if analysis is ready to run
+  ready <- .mlClassificationReady(options, type = "logistic")
+
+  # Compute results and create the model summary table
+  .mlClassificationTableSummary(dataset, options, jaspResults, ready, position = 1, type = "logistic")
+
+  # If the user wants to add the classes to the data set
+  .mlClassificationAddPredictionsToData(dataset, options, jaspResults, ready)
+
+  # Add test set indicator to data
+  .mlAddTestIndicatorToData(options, jaspResults, ready, purpose = "classification")
+
+  # Create the data split plot
+  .mlPlotDataSplit(dataset, options, jaspResults, ready, position = 2, purpose = "classification", type = "logistic")
+
+  # Create the confusion table
+  .mlClassificationTableConfusion(dataset, options, jaspResults, ready, position = 3)
+
+  # Create the class proportions table
+  .mlClassificationTableProportions(dataset, options, jaspResults, ready, position = 4)
+
+  # Create the validation measures table
+  .mlClassificationTableMetrics(dataset, options, jaspResults, ready, position = 5)
+
+  # # Create the variable importance table
+  # .mlTableFeatureImportance(options, jaspResults, ready, position = 6, purpose = "classification")
+
+  # # Create the shap table
+  # .mlTableShap(dataset, options, jaspResults, ready, position = 7, purpose = "classification")
+
+  # # Create the ROC curve
+  # .mlClassificationPlotRoc(dataset, options, jaspResults, ready, position = 8, type = "logistic")
+
+  # Create the Andrews curves
+  .mlClassificationPlotAndrews(dataset, options, jaspResults, ready, position = 9)
+
+  # # Decision boundaries
+  # .mlClassificationPlotBoundaries(dataset, options, jaspResults, ready, position = 10, type = "logistic")
+}
+
+.logisticRegressionClassification <- function(dataset, options, jaspResults, ready) {
+  # Import model formula from jaspResults
+  formula <- jaspResults[["formula"]]$object
+  # Split the data into training and test sets
+  if (options[["holdoutData"]] == "testSetIndicator" && options[["testSetIndicatorVariable"]] != "") {
+    # Select observations according to a user-specified indicator (included when indicator = 1)
+    trainingIndex <- which(dataset[, options[["testSetIndicatorVariable"]]] == 0)
+  } else {
+    # Sample a percentage of the total data set
+    trainingIndex <- sample.int(nrow(dataset), size = ceiling((1 - options[["testDataManual"]]) * nrow(dataset)))
+  }
+  trainingSet <- dataset[trainingIndex, ]
+  # Create the generated test set indicator
+  testIndicatorColumn <- rep(1, nrow(dataset))
+  testIndicatorColumn[trainingIndex] <- 0
+  # Just create a train and a test set (no optimization)
+  testSet <- dataset[-trainingIndex, ]
+  if (nlevels(trainingSet[[options[["target"]]]]) == 2) {
+    family = "binomial"
+    trainingFit <- stats::glm(formula, data = trainingSet, family = family)
+    # Use the specified model to make predictions for dataset
+    testPredictions <- levels(trainingSet[[options[["target"]]]])[round(predict(trainingFit, newdata = testSet, type = "response"), 0) + 1]
+    dataPredictions <- levels(trainingSet[[options[["target"]]]])[round(predict(trainingFit, newdata = dataset, type = "response"), 0) + 1]
+  } else {
+    family <- "multinomial"
+    trainingFit <- VGAM::vglm(formula, data = trainingSet, family = family)
+    # Use the specified model to make predictions for dataset
+    testPredictions <- .mlClassificationMultinomialPredictions(trainingSet, options, predict(trainingFit, newdata = testSet))
+    dataPredictions <- .mlClassificationMultinomialPredictions(trainingSet, options, predict(trainingFit, newdata = dataset))
+  }
+  # Create results object
+  result <- list()
+  result[["formula"]] <- formula
+  result[["family"]] <- family
+  result[["model"]] <- trainingFit
+  result[["confTable"]] <- table("Pred" = testPredictions, "Real" = testSet[, options[["target"]]])
+  result[["testAcc"]] <- sum(diag(prop.table(result[["confTable"]])))
+  # result[["auc"]] <- .classificationCalcAUC(testSet, trainingSet, options, "logisticClassification")
+  result[["ntrain"]] <- nrow(trainingSet)
+  result[["ntest"]] <- nrow(testSet)
+  result[["testReal"]] <- testSet[, options[["target"]]]
+  result[["testPred"]] <- testPredictions
+  result[["train"]] <- trainingSet
+  result[["test"]] <- testSet
+  result[["testIndicatorColumn"]] <- testIndicatorColumn
+  result[["classes"]] <- dataPredictions
+  # result[["explainer"]] <- DALEX::explain(result[["model"]], type = "classification", data = result[["train"]], y = result[["train"]][, options[["target"]]], predict_function = function(model, data) predict(model, newdata = data, type = "raw"))
+  # if (nlevels(result[["testReal"]]) == 2) {
+  #   result[["explainer_fi"]] <- DALEX::explain(result[["model"]], type = "classification", data = result[["train"]], y = as.numeric(result[["train"]][, options[["target"]]]) - 1, predict_function = function(model, data) predict(model, newdata = data, type = "class"))
+  # } else {
+  #   result[["explainer_fi"]] <- DALEX::explain(result[["model"]], type = "multiclass", data = result[["train"]], y = result[["train"]][, options[["target"]]], predict_function = function(model, data) predict(model, newdata = data, type = "raw"))
+  # }
+  return(result)
+}
+
+.mlClassificationMultinomialPredictions <- function(trainingSet, options, predictions) {
+  num_categories <- ncol(predictions) + 1
+  probs <- matrix(0, nrow = nrow(predictions), ncol = num_categories)
+  for (i in 1:(num_categories - 1)) {
+    probs[, i] <- exp(predictions[, i])
+  }
+  probs[, num_categories] <- 1
+  row_sums <- rowSums(probs)
+  probs <- probs / row_sums
+  predicted_category <- apply(probs, 1, which.max)
+  categories <- levels(trainingSet[[options[["target"]]]])
+  predicted_categories <- categories[predicted_category]
+  return(predicted_categories)
+}
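For context: predict() on a VGAM::vglm multinomial fit returns, by default, the log-odds of each class relative to the last level of the target (the reference category). The new .mlClassificationMultinomialPredictions helper back-transforms these to probabilities and picks the most probable level. A minimal standalone sketch of that conversion, not part of the commit, assuming VGAM is installed and using the built-in iris data with an illustrative sepal-only formula:

    # Mirror the helper's log-odds -> class-label conversion on a small vglm fit
    library(VGAM)
    fit <- VGAM::vglm(Species ~ Sepal.Length + Sepal.Width, data = iris, family = VGAM::multinomial())
    logOdds <- predict(fit, newdata = iris)       # log-odds of each class vs. the last level (virginica)
    probs <- cbind(exp(logOdds), 1)               # back-transform; the reference level contributes exp(0) = 1
    probs <- probs / rowSums(probs)               # normalise each row to probabilities
    predicted <- levels(iris$Species)[apply(probs, 1, which.max)]
    table(Pred = predicted, Real = iris$Species)  # same shape as result[["confTable"]]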

inst/Description.qml

Lines changed: 6 additions & 0 deletions
@@ -98,6 +98,12 @@ Description
     func: "mlClassificationLda"
   }
   Analysis
+  {
+    menu: qsTr("Logistic / Multinomial")
+    title: qsTr("Logistic / Multinomial Classification")
+    func: "mlClassificationLogistic"
+  }
+  Analysis
   {
     menu: qsTr("Naive Bayes")
     title: qsTr("Naive Bayes Classification")
inst/qml/mlClassificationLogistic.qml

Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
+//
+// Copyright (C) 2013-2021 University of Amsterdam
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as
+// published by the Free Software Foundation, either version 3 of the
+// License, or (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public
+// License along with this program. If not, see
+// <http://www.gnu.org/licenses/>.
+//
+
+import QtQuick 2.8
+import QtQuick.Layouts 1.3
+import JASP.Controls 1.0
+import JASP.Widgets 1.0
+
+import "./common/ui" as UI
+import "./common/tables" as TAB
+import "./common/figures" as FIG
+
+Form
+{
+  info: qsTr("Logistic regression.")
+
+  UI.VariablesFormClassification { id: vars }
+
+  Group
+  {
+    title: qsTr("Tables")
+
+    TAB.ConfusionMatrix { }
+    TAB.ClassProportions { }
+    TAB.ModelPerformance { }
+    TAB.FeatureImportance { }
+    TAB.ExplainPredictions { }
+  }
+
+  Group
+  {
+    title: qsTr("Plots")
+
+    FIG.DataSplit { }
+    FIG.RocCurve { }
+    FIG.AndrewsCurve { }
+    FIG.DecisionBoundary { }
+  }
+
+  UI.ExportResults { enabled: vars.predictorCount > 0 && vars.targetCount > 0 }
+  UI.DataSplit { trainingValidationSplit: false }
+
+  Section
+  {
+    title: qsTr("Training Parameters")
+
+    Group
+    {
+      title: qsTr("Algorithmic Settings")
+
+      UI.ScaleVariables { }
+      UI.SetSeed { }
+    }
+
+    RadioButtonGroup
+    {
+      name: "modelOptimization"
+      visible: false
+
+      RadioButton
+      {
+        name: "manual"
+        checked: true
+      }
+    }
+  }
+}
Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
+context("Machine Learning Logistic Regression Classification")
+
+# Test fixed model #############################################################
+options <- initMlOptions("mlClassificationLogistic")
+options$addIndicator <- FALSE
+options$addPredictions <- FALSE
+options$classProportionsTable <- TRUE
+options$holdoutData <- "holdoutManual"
+options$modelOptimization <- "manual"
+options$modelValid <- "validationManual"
+options$predictionsColumn <- ""
+options$predictors <- c("Sepal.Length", "Sepal.Width", "Petal.Length", "Petal.Width")
+options$predictors.types <- rep("scale", 4)
+options$saveModel <- FALSE
+options$savePath <- ""
+options$setSeed <- TRUE
+options$target <- "Species"
+options$target.types <- "nominal"
+options$testDataManual <- 0.2
+options$testIndicatorColumn <- ""
+options$testSetIndicatorVariable <- ""
+options$validationDataManual <- 0.2
+options$validationMeasures <- TRUE
+options$tableShap <- TRUE
+options$fromIndex <- 1
+options$toIndex <- 5
+options$featureImportanceTable <- TRUE
+set.seed(1)
+results <- jaspTools::runAnalysis("mlClassificationLogistic", "iris.csv", options)
+
+test_that("Class Proportions table results match", {
+  table <- results[["results"]][["classProportionsTable"]][["data"]]
+  jaspTools::expect_equal_tables(table,
+    list(0.333333333333333, "setosa", 0.333333333333333, 0.333333333333333,
+      0.333333333333333, "versicolor", 0.266666666666667, 0.35, 0.333333333333333,
+      "virginica", 0.4, 0.316666666666667))
+})
+
+test_that("Model Summary: Logistic / Multinomial Regression table results match", {
+  table <- results[["results"]][["classificationTable"]][["data"]]
+  jaspTools::expect_equal_tables(table,
+    list("multinomial", 30, 120, 1))
+})
+
+test_that("Confusion Matrix table results match", {
+  table <- results[["results"]][["confusionTable"]][["data"]]
+  jaspTools::expect_equal_tables(table,
+    list("Observed", "setosa", 10, 0, 0, "", "versicolor", 0, 8, 0, "",
+      "virginica", 0, 0, 12))
+})
+
+test_that("Data Split plot matches", {
+  plotName <- results[["results"]][["plotDataSplit"]][["data"]]
+  testPlot <- results[["state"]][["figures"]][[plotName]][["obj"]]
+  jaspTools::expect_equal_plots(testPlot, "data-split")
+})
+
+test_that("Model Performance Metrics table results match", {
+  table <- results[["results"]][["validationMeasures"]][["data"]]
+  jaspTools::expect_equal_tables(table,
+    list(1, "", 1, 0, 0, 0, 0, "setosa", 1, 1, 1, 1, 0.333333333333333,
+      10, 1, "<unicode>", 1, 1, 0, 0, 0, 0, "versicolor", 1, 1, 1,
+      1, 0.266666666666667, 8, 1, "<unicode>", 1, 1, 0, 0, 0, 0, "virginica",
+      1, 1, 1, 1, 0.4, 12, 1, "<unicode>", 1, 1, 0, 0, 0, 0, "Average / Total",
+      1, 1, 1, 1, 1, 30, 1, "<unicode>"))
+})
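The fixture above uses the three-level Species target, so the analysis takes the multinomial branch (hence "multinomial" in the model summary expectation). With a two-level target, .logisticRegressionClassification instead fits a binomial stats::glm and maps fitted probabilities to the two target levels by rounding. A minimal sketch of that branch outside the JASP wrappers, not part of the commit, using an illustrative two-class subset of iris:

    # Two-class branch: binomial glm, predictions obtained by rounding the fitted probability
    twoClass <- droplevels(subset(iris, Species != "setosa"))
    fit <- stats::glm(Species ~ Sepal.Length + Sepal.Width, data = twoClass, family = "binomial")
    p <- predict(fit, newdata = twoClass, type = "response")  # P(second level)
    predicted <- levels(twoClass$Species)[round(p) + 1]       # 0 -> first level, 1 -> second level
    table(Pred = predicted, Real = twoClass$Species)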
