Skip to content

Commit 50a04d3

Browse files
committed
Make AUC work
1 parent 4dc4a38 commit 50a04d3

File tree

2 files changed

+23
-2
lines changed

2 files changed

+23
-2
lines changed

R/commonMachineLearningClassification.R

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1138,3 +1138,24 @@
11381138
score <- max.col(predict(fit, test, type = "raw"))
11391139
return(score)
11401140
}
1141+
1142+
.calcAUCScore.logisticClassification <- function(AUCformula, test, typeData, options, jaspResults, ...) {
1143+
fit <- glm(AUCformula, data = typeData, family = "binomial")
1144+
score <- round(predict(fit, test, type = "response"), 0)
1145+
return(score)
1146+
}
1147+
1148+
.calcAUCScore.multinomialClassification <- function(AUCformula, test, typeData, options, jaspResults, ...) {
1149+
fit <- VGAM::vglm(AUCformula, data = typeData, family = "multinomial")
1150+
logodds <- as.data.frame(predict(fit, test))
1151+
ncategories <- ncol(logodds) + 1
1152+
probabilities <- matrix(0, nrow = nrow(logodds), ncol = ncategories)
1153+
for (i in seq_len(ncategories - 1)) {
1154+
probabilities[, i] <- exp(logodds[, i])
1155+
}
1156+
probabilities[, ncategories] <- 1
1157+
row_sums <- rowSums(probabilities)
1158+
probabilities <- probabilities / row_sums
1159+
score <- apply(probabilities, 1, which.max)
1160+
return(score)
1161+
}

R/mlClassificationLogisticMultinomial.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ mlClassificationLogisticMultinomial <- function(jaspResults, dataset, options, .
106106
result[["model"]] <- trainingFit
107107
result[["confTable"]] <- table("Pred" = testPredictions, "Real" = testSet[, options[["target"]]])
108108
result[["testAcc"]] <- sum(diag(prop.table(result[["confTable"]])))
109-
# result[["auc"]] <- .classificationCalcAUC(testSet, trainingSet, options, "logisticClassification")
109+
result[["auc"]] <- .classificationCalcAUC(testSet, trainingSet, options, if (family == "binomial") "logisticClassification" else "multinomialClassification")
110110
result[["ntrain"]] <- nrow(trainingSet)
111111
result[["ntest"]] <- nrow(testSet)
112112
result[["testReal"]] <- testSet[, options[["target"]]]
@@ -207,7 +207,7 @@ mlClassificationLogisticMultinomial <- function(jaspResults, dataset, options, .
207207
table[["lower"]] <- coefs[, "lower"]
208208
table[["upper"]] <- coefs[, "upper"]
209209
}
210-
if (options[["formula"]]) {
210+
if (options[["formula"]]) { # TODO FOR MULTINOMIAL
211211
one_cat <- levels(factor(classificationResult[["train"]][[options[["target"]]]]))[2]
212212
if (options[["intercept"]]) {
213213
regform <- paste0("logit(p<sub>", options[["target"]], " = ", one_cat, "</sub>) = ", round(as.numeric(coefs[, 1])[1], 3))

0 commit comments

Comments
 (0)