Skip to content

Commit ac4d14e

Browse files
committed
Update mlClassificationLogistic.R
1 parent a83a279 commit ac4d14e

File tree

1 file changed

+18
-12
lines changed

1 file changed

+18
-12
lines changed

R/mlClassificationLogistic.R

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,8 @@ mlClassificationLogistic <- function(jaspResults, dataset, options, ...) {
8282
family = "binomial"
8383
trainingFit <- stats::glm(formula, data = trainingSet, family = family)
8484
# Use the specified model to make predictions for dataset
85-
testPredictions <- levels(trainingSet[[options[["target"]]]])[round(predict(trainingFit, newdata = testSet, type = "response"), 0) + 1]
86-
dataPredictions <- levels(trainingSet[[options[["target"]]]])[round(predict(trainingFit, newdata = dataset, type = "response"), 0) + 1]
85+
testPredictions <- .mlClassificationLogisticPredictions(trainingSet, options, predict(trainingFit, newdata = testSet, type = "response"))
86+
dataPredictions <- .mlClassificationLogisticPredictions(trainingSet, options, predict(trainingFit, newdata = dataset, type = "response"))
8787
} else {
8888
family <- "multinomial"
8989
trainingFit <- VGAM::vglm(formula, data = trainingSet, family = family)
@@ -116,17 +116,23 @@ mlClassificationLogistic <- function(jaspResults, dataset, options, ...) {
116116
return(result)
117117
}
118118

119-
.mlClassificationMultinomialPredictions <- function(trainingSet, options, predictions) {
120-
num_categories <- ncol(predictions) + 1
121-
probs <- matrix(0, nrow = nrow(predictions), ncol = num_categories)
122-
for (i in 1:(num_categories - 1)) {
123-
probs[, i] <- exp(predictions[, i])
119+
.mlClassificationLogisticPredictions <- function(trainingSet, options, probabilities) {
120+
categories <- levels(trainingSet[[options[["target"]]]])
121+
predicted_categories <- categories[round(probabilities, 0) + 1]
122+
return(predicted_categories)
123+
}
124+
125+
.mlClassificationMultinomialPredictions <- function(trainingSet, options, logodds) {
126+
ncategories <- ncol(logodds) + 1
127+
probabilities <- matrix(0, nrow = nrow(logodds), ncol = ncategories)
128+
for (i in seq_len(ncategories - 1)) {
129+
probabilities[, i] <- exp(logodds[, i])
124130
}
125-
probs[, num_categories] <- 1
126-
row_sums <- rowSums(probs)
127-
probs <- probs / row_sums
128-
predicted_category <- apply(probs, 1, which.max)
131+
probabilities[, ncategories] <- 1
132+
row_sums <- rowSums(probabilities)
133+
probabilities <- probabilities / row_sums
134+
predicted_columns <- apply(probabilities, 1, which.max)
129135
categories <- levels(trainingSet[[options[["target"]]]])
130-
predicted_categories <- categories[predicted_category]
136+
predicted_categories <- categories[predicted_columns]
131137
return(predicted_categories)
132138
}

0 commit comments

Comments
 (0)