Skip to content

Commit 5c3599e

Browse files
committed
Make roc plot work
1 parent 50a04d3 commit 5c3599e

File tree

4 files changed

+55
-24
lines changed

4 files changed

+55
-24
lines changed

R/commonMachineLearningClassification.R

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -721,6 +721,9 @@
721721
} else if (type == "naivebayes") {
722722
fit <- e1071::naiveBayes(formula = formula, data = typeData, laplace = options[["smoothingParameter"]])
723723
score <- max.col(predict(fit, test, type = "raw"))
724+
} else if (type == "logistic") {
725+
fit <- glm(formula, data = typeData, family = "binomial")
726+
score <- round(predict(fit, test, type = "response"), 0)
724727
}
725728
pred <- ROCR::prediction(score, actual.class)
726729
nbperf <- ROCR::performance(pred, "tpr", "fpr")
@@ -1144,18 +1147,3 @@
11441147
score <- round(predict(fit, test, type = "response"), 0)
11451148
return(score)
11461149
}
1147-
1148-
.calcAUCScore.multinomialClassification <- function(AUCformula, test, typeData, options, jaspResults, ...) {
1149-
fit <- VGAM::vglm(AUCformula, data = typeData, family = "multinomial")
1150-
logodds <- as.data.frame(predict(fit, test))
1151-
ncategories <- ncol(logodds) + 1
1152-
probabilities <- matrix(0, nrow = nrow(logodds), ncol = ncategories)
1153-
for (i in seq_len(ncategories - 1)) {
1154-
probabilities[, i] <- exp(logodds[, i])
1155-
}
1156-
probabilities[, ncategories] <- 1
1157-
row_sums <- rowSums(probabilities)
1158-
probabilities <- probabilities / row_sums
1159-
score <- apply(probabilities, 1, which.max)
1160-
return(score)
1161-
}

R/mlClassificationLogisticMultinomial.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@ mlClassificationLogisticMultinomial <- function(jaspResults, dataset, options, .
5353

5454
.mlClassificationLogisticTableCoef(options, jaspResults, ready, position = 8)
5555

56-
# # Create the ROC curve
57-
# .mlClassificationPlotRoc(dataset, options, jaspResults, ready, position = 10, type = "logistic") # position + 1 for regression equation
56+
# Create the ROC curve
57+
.mlClassificationPlotRoc(dataset, options, jaspResults, ready, position = 10, type = "logistic") # position + 1 for regression equation
5858

5959
# Create the Andrews curves
6060
.mlClassificationPlotAndrews(dataset, options, jaspResults, ready, position = 11)
@@ -106,7 +106,7 @@ mlClassificationLogisticMultinomial <- function(jaspResults, dataset, options, .
106106
result[["model"]] <- trainingFit
107107
result[["confTable"]] <- table("Pred" = testPredictions, "Real" = testSet[, options[["target"]]])
108108
result[["testAcc"]] <- sum(diag(prop.table(result[["confTable"]])))
109-
result[["auc"]] <- .classificationCalcAUC(testSet, trainingSet, options, if (family == "binomial") "logisticClassification" else "multinomialClassification")
109+
result[["auc"]] <- .classificationCalcAUC(testSet, trainingSet, options, "logisticClassification")
110110
result[["ntrain"]] <- nrow(trainingSet)
111111
result[["ntest"]] <- nrow(testSet)
112112
result[["testReal"]] <- testSet[, options[["target"]]]
Lines changed: 42 additions & 0 deletions
Loading

tests/testthat/test-mlclassificationlogisticmultinomial.R

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ test_that("Class Proportions table results match", {
3636
"virginica", 0.4, 0.316666666666667))
3737
})
3838

39-
test_that("Model Summary: Logistic / Multinomial Regression table results match", {
39+
test_that("Model Summary: Multinomial Regression Classification table results match", {
4040
table <- results[["results"]][["classificationTable"]][["data"]]
4141
jaspTools::expect_equal_tables(table,
4242
list("multinomial", 30, 120, 1))
@@ -58,9 +58,10 @@ test_that("Data Split plot matches", {
5858
test_that("Model Performance Metrics table results match", {
5959
table <- results[["results"]][["validationMeasures"]][["data"]]
6060
jaspTools::expect_equal_tables(table,
61-
list(1, "", 1, 0, 0, 0, 0, "setosa", 1, 1, 1, 1, 0.333333333333333,
62-
10, 1, "<unicode>", 1, 1, 0, 0, 0, 0, "versicolor", 1, 1, 1,
63-
1, 0.266666666666667, 8, 1, "<unicode>", 1, 1, 0, 0, 0, 0, "virginica",
64-
1, 1, 1, 1, 0.4, 12, 1, "<unicode>", 1, 1, 0, 0, 0, 0, "Average / Total",
65-
1, 1, 1, 1, 1, 30, 1, "<unicode>"))
61+
list(1, 1, 1, 0, 0, 0, 0, "setosa", 1, 1, 1, 1, 0.333333333333333,
62+
10, 1, "<unicode>", 1, 0.613636363636364, 1, 0, 0, 0, 0, "versicolor",
63+
1, 1, 1, 1, 0.266666666666667, 8, 1, "<unicode>", 1, 1, 1, 0,
64+
0, 0, 0, "virginica", 1, 1, 1, 1, 0.4, 12, 1, "<unicode>", 1,
65+
0.871212121212121, 1, 0, 0, 0, 0, "Average / Total", 1, 1, 1,
66+
1, 1, 30, 1, "<unicode>"))
6667
})

0 commit comments

Comments
 (0)