Skip to content

Commit 82a9b8e

Browse files
committed
Prediction for logit regression
1 parent 09b2b45 commit 82a9b8e

File tree

2 files changed

+5
-10
lines changed

2 files changed

+5
-10
lines changed

R/mlClassificationLogisticMultinomial.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ mlClassificationLogisticMultinomial <- function(jaspResults, dataset, options, .
236236

237237
.mlClassificationLogisticPredictions <- function(trainingSet, options, probabilities) {
238238
categories <- levels(trainingSet[[options[["target"]]]])
239-
predicted_categories <- categories[round(probabilities, 0) + 1]
239+
predicted_categories <- factor(categories[round(probabilities, 0) + 1], levels = levels(trainingSet[[options[["target"]]]]))
240240
return(predicted_categories)
241241
}
242242

@@ -251,6 +251,6 @@ mlClassificationLogisticMultinomial <- function(jaspResults, dataset, options, .
251251
probabilities <- probabilities / row_sums
252252
predicted_columns <- apply(probabilities, 1, which.max)
253253
categories <- levels(trainingSet[[options[["target"]]]])
254-
predicted_categories <- categories[predicted_columns]
254+
predicted_categories <- factor(categories[predicted_columns], levels = levels(trainingSet[[options[["target"]]]]))
255255
return(predicted_categories)
256256
}

R/mlPrediction.R

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ is.jaspMachineLearning <- function(x) {
142142
as.character(e1071:::predict.naiveBayes(model, newdata = dataset, type = "class"))
143143
}
144144
.mlPredictionGetPredictions.glm <- function(model, dataset) {
145-
# TODO
145+
as.character(levels(as.factor(model$model[, 1]))[round(predict(model, newdata = dataset, type = "response"), 0) + 1])
146146
}
147147
.mlPredictionGetPredictions.vglm <- function(model, dataset) {
148148
# TODO
@@ -248,7 +248,8 @@ is.jaspMachineLearning <- function(x) {
248248
return(model)
249249
}
250250
.decodeJaspMLobject.glm <- function(model) {
251-
# TODO
251+
formula <- formula(paste(decodeColNames(as.character(model$terms)[2]), "~", paste0(decodeColNames(strsplit(as.character(model$terms)[3], split = " + ", fixed = TRUE)[[1]]), collapse = " + ")))
252+
model$terms <- stats::terms(formula)
252253
return(model)
253254
}
254255
.decodeJaspMLobject.vglm <- function(model) {
@@ -352,8 +353,6 @@ is.jaspMachineLearning <- function(x) {
352353
table$addColumnInfo(name = "mtry", title = gettext("Features per split"), type = "integer")
353354
} else if (inherits(model, "cv.glmnet")) {
354355
table$addColumnInfo(name = "lambda", title = "\u03BB", type = "number")
355-
} else if (inherits(model, "glm") || inherits(model, "vglm")) {
356-
table$addColumnInfo(name = "family", title = gettext("Family"), type = "string")
357356
}
358357
table$addColumnInfo(name = "ntrain", title = gettext("n(Train)"), type = "integer")
359358
table$addColumnInfo(name = "nnew", title = gettext("n(New)"), type = "integer")
@@ -372,10 +371,6 @@ is.jaspMachineLearning <- function(x) {
372371
row[["mtry"]] <- model[["mtry"]]
373372
} else if (inherits(model, "cv.glmnet")) {
374373
row[["lambda"]] <- model[["lambda.min"]]
375-
} else if (inherits(model, "glm")) {
376-
row[["family"]] <- gettext("binomial")
377-
} else if (inherits(model, "vglm")) {
378-
row[["family"]] <- gettext("multinomial")
379374
}
380375
if (length(presentVars) > 0) {
381376
row[["nnew"]] <- nrow(dataset)

0 commit comments

Comments
 (0)