Skip to content

Commit fb42caf

Browse files
authored
Logistic / Multinomial regression (#369)
* First implementation * Update mlClassificationLogistic.R * Add intercept option * Add coefficients table for logistic * Update * Change file names * Update mlClassificationLogisticMultinomial.R * Update coef table * Make AUC work * Make roc plot work * Make decision boundary work * Start with prediction * Prediction for logit regression * Predictions for multinomial * Formula for multinomial * Update table content * explainers * Update mlClassificationLogisticMultinomial.R * Update test-mlclassificationlogisticmultinomial.R * Add info text * Update * Update mlClassificationLogisticMultinomial.R * coef table for multinomial * Add test for logistic * Support different link functions * Change decode method * Complementary log log * Fix little bug in regularized linear regression where intercept was incorrectly shown in equation * Ensure first "+" is not shown if there is no intercept * Fix bug for single variable logistic without intercept
1 parent d3a5ebd commit fb42caf

13 files changed

+807
-11
lines changed

DESCRIPTION

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ Imports:
4141
rpart (>= 4.1.16),
4242
ROCR,
4343
Rtsne,
44-
signal
44+
signal,
45+
VGAM
4546
Suggests:
4647
testthat
4748
Remotes:

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ export(mlClassificationBoosting)
5151
export(mlClassificationDecisionTree)
5252
export(mlClassificationKnn)
5353
export(mlClassificationLda)
54+
export(mlClassificationLogisticMultinomial)
5455
export(mlClassificationNaiveBayes)
5556
export(mlClassificationNeuralNetwork)
5657
export(mlClassificationRandomForest)

R/commonMachineLearningClassification.R

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@
3232
"mutationMethod", "survivalMethod", "elitismProportion", "candidates", # Neural network
3333
"noOfTrees", "maxTrees", "baggingFraction", "noOfPredictors", "numberOfPredictors", # Random forest
3434
"complexityParameter", "degree", "gamma", "cost", "tolerance", "epsilon", "maxCost", # Support vector machine
35-
"smoothingParameter" # Naive Bayes
35+
"smoothingParameter", # Naive Bayes
36+
"intercept", "link" # Logistic
3637
)
3738
if (includeSaveOptions) {
3839
opt <- c(opt, "saveModel", "savePath")
@@ -62,7 +63,7 @@
6263
if (type == "lda" || type == "randomForest" || type == "boosting") {
6364
# Require at least 2 features
6465
ready <- length(options[["predictors"]][options[["predictors"]] != ""]) >= 2 && options[["target"]] != ""
65-
} else if (type == "knn" || type == "neuralnet" || type == "rpart" || type == "svm" || type == "naivebayes") {
66+
} else if (type == "knn" || type == "neuralnet" || type == "rpart" || type == "svm" || type == "naivebayes" || type == "logistic") {
6667
# Require at least 1 features
6768
ready <- length(options[["predictors"]][options[["predictors"]] != ""]) >= 1 && options[["target"]] != ""
6869
}
@@ -93,7 +94,8 @@
9394
"neuralnet" = .neuralnetClassification(dataset, options, jaspResults),
9495
"rpart" = .decisionTreeClassification(dataset, options, jaspResults),
9596
"svm" = .svmClassification(dataset, options, jaspResults),
96-
"naivebayes" = .naiveBayesClassification(dataset, options, jaspResults)
97+
"naivebayes" = .naiveBayesClassification(dataset, options, jaspResults),
98+
"logistic" = .logisticMultinomialClassification(dataset, options, jaspResults)
9799
)
98100
})
99101
if (isTryError(p)) { # Fail gracefully
@@ -116,7 +118,8 @@
116118
"neuralnet" = gettext("Neural Network Classification"),
117119
"rpart" = gettext("Decision Tree Classification"),
118120
"svm" = gettext("Support Vector Machine Classification"),
119-
"naivebayes" = gettext("Naive Bayes Classification")
121+
"naivebayes" = gettext("Naive Bayes Classification"),
122+
"logistic" = gettext("Logistic / Multinomial Regression Classification")
120123
)
121124
tableTitle <- gettextf("Model Summary: %1$s", title)
122125
table <- createJaspTable(tableTitle)
@@ -147,6 +150,9 @@
147150
table$addColumnInfo(name = "vectors", title = gettext("Support Vectors"), type = "integer")
148151
} else if (type == "naivebayes") {
149152
table$addColumnInfo(name = "smoothing", title = gettext("Smoothing"), type = "number")
153+
} else if (type == "logistic") {
154+
table$addColumnInfo(name = "family", title = gettext("Family"), type = "string")
155+
table$addColumnInfo(name = "link", title = gettext("Link"), type = "string")
150156
}
151157
# Add common columns
152158
table$addColumnInfo(name = "nTrain", title = gettext("n(Train)"), type = "integer")
@@ -164,7 +170,7 @@
164170
}
165171
# If no analysis is run, specify the required variables in a footnote
166172
if (!ready) {
167-
table$addFootnote(gettextf("Please provide a target variable and at least %i feature variable(s).", if (type == "knn" || type == "neuralnet" || type == "rpart" || type == "svm") 1L else 2L))
173+
table$addFootnote(gettextf("Please provide a target variable and at least %i feature variable(s).", if (type == "knn" || type == "neuralnet" || type == "rpart" || type == "svm" || type == "logistic") 1L else 2L))
168174
}
169175
if (options[["savePath"]] != "") {
170176
validNames <- (length(grep(" ", decodeColNames(colnames(dataset)))) == 0) && (length(grep("_", decodeColNames(colnames(dataset)))) == 0)
@@ -312,6 +318,22 @@
312318
testAcc = classificationResult[["testAcc"]]
313319
)
314320
table$addRows(row)
321+
} else if (type == "logistic") {
322+
if (classificationResult[["family"]] == "binomial") {
323+
table$title <- gettext("Model Summary: Logistic Regression Classification")
324+
} else {
325+
table$title <- gettext("Model Summary: Multinomial Regression Classification")
326+
}
327+
family <- classificationResult[["family"]]
328+
link <- classificationResult[["link"]]
329+
row <- data.frame(
330+
family = paste0(toupper(substr(family, 1, 1)), substr(family, 2, nchar(family))),
331+
link = paste0(toupper(substr(link, 1, 1)), substr(link, 2, nchar(link))),
332+
nTrain = nTrain,
333+
nTest = classificationResult[["ntest"]],
334+
testAcc = classificationResult[["testAcc"]]
335+
)
336+
table$addRows(row)
315337
}
316338
# Save the applied model if requested
317339
if (options[["saveModel"]] && options[["savePath"]] != "") {
@@ -564,6 +586,26 @@
564586
fit <- e1071::naiveBayes(formula, data = dataset, laplace = options[["smoothingParameter"]])
565587
predictions <- as.factor(max.col(predict(fit, newdata = grid, type = "raw")))
566588
levels(predictions) <- unique(dataset[, options[["target"]]])
589+
} else if (type == "logistic") {
590+
if (classificationResult[["family"]] == "binomial") {
591+
fit <- stats::glm(formula, data = dataset, family = stats::binomial(link = options[["link"]]))
592+
predictions <- as.factor(round(predict(fit, grid, type = "response"), 0))
593+
levels(predictions) <- unique(dataset[, options[["target"]]])
594+
} else {
595+
fit <- VGAM::vglm(formula, data = dataset, family = VGAM::multinomial())
596+
logodds <- predict(fit, newdata = grid)
597+
ncategories <- ncol(logodds) + 1
598+
probabilities <- matrix(0, nrow = nrow(logodds), ncol = ncategories)
599+
for (i in seq_len(ncategories - 1)) {
600+
probabilities[, i] <- exp(logodds[, i])
601+
}
602+
probabilities[, ncategories] <- 1
603+
row_sums <- rowSums(probabilities)
604+
probabilities <- probabilities / row_sums
605+
predicted_columns <- apply(probabilities, 1, which.max)
606+
categories <- levels(dataset[[options[["target"]]]])
607+
predictions <- as.factor(categories[predicted_columns])
608+
}
567609
}
568610
shapes <- rep(21, nrow(dataset))
569611
if (type == "svm") {
@@ -703,6 +745,9 @@
703745
} else if (type == "naivebayes") {
704746
fit <- e1071::naiveBayes(formula = formula, data = typeData, laplace = options[["smoothingParameter"]])
705747
score <- max.col(predict(fit, test, type = "raw"))
748+
} else if (type == "logistic") {
749+
fit <- stats::glm(formula, data = typeData, family = stats::binomial(link = options[["link"]]))
750+
score <- round(predict(fit, test, type = "response"), 0)
706751
}
707752
pred <- ROCR::prediction(score, actual.class)
708753
nbperf <- ROCR::performance(pred, "tpr", "fpr")
@@ -1120,3 +1165,9 @@
11201165
score <- max.col(predict(fit, test, type = "raw"))
11211166
return(score)
11221167
}
1168+
1169+
# S3-style scorer for the logistic classifier.
#
# Fits a binomial GLM of `AUCformula` on `typeData`, using the link function
# named in options[["link"]], and predicts on `test` at the response scale.
# The predicted values are rounded to zero decimals before being returned.
#
# Arguments:
#   AUCformula  - model formula passed straight to stats::glm().
#   test        - data used as `newdata` for the prediction.
#   typeData    - data the model is fitted on.
#   options     - list; only options[["link"]] is read here.
#   jaspResults - accepted for signature compatibility; not used in this body.
#   ...         - accepted and ignored.
#
# Returns: the rounded predictions for the rows of `test`.
#
# NOTE(review): the result is rounded to hard 0/1 values instead of being left
# as probabilities - confirm that callers really want class labels here.
.calcAUCScore.logisticClassification <- function(AUCformula, test, typeData, options, jaspResults, ...) {
  linkFamily <- stats::binomial(link = options[["link"]])
  model <- stats::glm(AUCformula, data = typeData, family = linkFamily)
  responseScale <- predict(model, test, type = "response")
  return(round(responseScale, 0))
}

0 commit comments

Comments
 (0)