Skip to content

Commit d6b86c1

Browse files
committed
Fix wrong order in decision boundary plot
1 parent 46f1746 commit d6b86c1

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

R/commonMachineLearningClassification.R

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -564,13 +564,13 @@
564564
act.fct = jaspResults[["actfct"]]$object,
565565
linear.output = FALSE
566566
)
567-
predictions <- as.factor(max.col(predict(fit, newdata = grid)))
568-
levels(predictions) <- unique(dataset[, options[["target"]]])
567+
probabilities <- predict(fit, newdata = grid)
568+
predictions <- levels(dataset[, options[["target"]]])[apply(probabilities, 1, which.max)]
569569
} else if (type == "rpart") {
570570
classificationResult <- jaspResults[["classificationResult"]]$object
571571
fit <- rpart::rpart(formula, data = dataset, method = "class", control = rpart::rpart.control(minsplit = options[["minObservationsForSplit"]], minbucket = options[["minObservationsInNode"]], maxdepth = options[["interactionDepth"]], cp = classificationResult[["penalty"]]))
572-
predictions <- as.factor(max.col(predict(fit, newdata = grid)))
573-
levels(predictions) <- unique(dataset[, options[["target"]]])
572+
probabilities <- predict(fit, newdata = grid)
573+
predictions <- colnames(probabilities)[apply(probabilities, 1, which.max)]
574574
} else if (type == "svm") {
575575
classificationResult <- jaspResults[["classificationResult"]]$object
576576
fit <- e1071::svm(formula,
@@ -580,8 +580,8 @@
580580
predictions <- predict(fit, newdata = grid)
581581
} else if (type == "naivebayes") {
582582
fit <- e1071::naiveBayes(formula, data = dataset, laplace = options[["smoothingParameter"]])
583-
predictions <- as.factor(max.col(predict(fit, newdata = grid, type = "raw")))
584-
levels(predictions) <- unique(dataset[, options[["target"]]])
583+
probabilities <- predict(fit, newdata = grid, type = "raw")
584+
predictions <- colnames(probabilities)[apply(probabilities, 1, which.max)]
585585
} else if (type == "logistic") {
586586
if (classificationResult[["family"]] == "binomial") {
587587
fit <- stats::glm(formula, data = dataset, family = stats::binomial(link = options[["link"]]))

0 commit comments

Comments
 (0)