Skip to content

Commit bde825d

Browse files
committed
Also fix logistic
1 parent d6b86c1 commit bde825d

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

R/commonMachineLearningClassification.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -585,8 +585,8 @@
585585
} else if (type == "logistic") {
586586
if (classificationResult[["family"]] == "binomial") {
587587
fit <- stats::glm(formula, data = dataset, family = stats::binomial(link = options[["link"]]))
588-
predictions <- as.factor(round(predict(fit, grid, type = "response"), 0))
589-
levels(predictions) <- unique(dataset[, options[["target"]]])
588+
probabilities <- predict(fit, grid, type = "response")
589+
predictions <- levels(dataset[, options[["target"]]])[round(probabilities, 0) + 1]
590590
} else {
591591
fit <- VGAM::vglm(formula, data = dataset, family = VGAM::multinomial())
592592
logodds <- predict(fit, newdata = grid)

0 commit comments

Comments
 (0)