Skip to content

Commit ef552ae

Browse files
committed
Make decision boundary work
1 parent 5c3599e commit ef552ae

File tree

2 files changed

+22
-2
lines changed

2 files changed

+22
-2
lines changed

R/commonMachineLearningClassification.R

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,26 @@
582582
fit <- e1071::naiveBayes(formula, data = dataset, laplace = options[["smoothingParameter"]])
583583
predictions <- as.factor(max.col(predict(fit, newdata = grid, type = "raw")))
584584
levels(predictions) <- unique(dataset[, options[["target"]]])
585+
} else if (type == "logistic") {
586+
if (classificationResult[["family"]] == "binomial") {
587+
fit <- glm(formula, data = dataset, family = "binomial")
588+
predictions <- as.factor(round(predict(fit, grid, type = "response"), 0))
589+
levels(predictions) <- unique(dataset[, options[["target"]]])
590+
} else {
591+
fit <- VGAM::vglm(formula, data = dataset, family = "multinomial")
592+
logodds <- predict(fit, newdata = grid)
593+
ncategories <- ncol(logodds) + 1
594+
probabilities <- matrix(0, nrow = nrow(logodds), ncol = ncategories)
595+
for (i in seq_len(ncategories - 1)) {
596+
probabilities[, i] <- exp(logodds[, i])
597+
}
598+
probabilities[, ncategories] <- 1
599+
row_sums <- rowSums(probabilities)
600+
probabilities <- probabilities / row_sums
601+
predicted_columns <- apply(probabilities, 1, which.max)
602+
categories <- levels(dataset[[options[["target"]]]])
603+
predictions <- as.factor(categories[predicted_columns])
604+
}
585605
}
586606
shapes <- rep(21, nrow(dataset))
587607
if (type == "svm") {

R/mlClassificationLogisticMultinomial.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ mlClassificationLogisticMultinomial <- function(jaspResults, dataset, options, .
5959
# Create the Andrews curves
6060
.mlClassificationPlotAndrews(dataset, options, jaspResults, ready, position = 11)
6161

62-
# # Decision boundaries
63-
# .mlClassificationPlotBoundaries(dataset, options, jaspResults, ready, position = 12, type = "logistic")
62+
# Decision boundaries
63+
.mlClassificationPlotBoundaries(dataset, options, jaspResults, ready, position = 12, type = "logistic")
6464
}
6565

6666
.logisticMultinomialClassification <- function(dataset, options, jaspResults, ready) {

0 commit comments

Comments
 (0)