Skip to content

Commit d9a3a59

Browse files
committed
Predictions for multinomial
1 parent 82a9b8e commit d9a3a59

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

R/mlClassificationLogisticMultinomial.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ mlClassificationLogisticMultinomial <- function(jaspResults, dataset, options, .
109109
model <- lapply(slotNames(trainingFit), function(x) slot(trainingFit, x))
110110
names(model) <- slotNames(trainingFit)
111111
model[["original"]] <- trainingFit
112+
model[["target"]] <- trainingSet[[options[["target"]]]]
112113
class(model) <- "vglm"
113114
result[["model"]] <- model
114115
}

R/mlPrediction.R

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,18 @@ is.jaspMachineLearning <- function(x) {
145145
as.character(levels(as.factor(model$model[, 1]))[round(predict(model, newdata = dataset, type = "response"), 0) + 1])
146146
}
147147
.mlPredictionGetPredictions.vglm <- function(model, dataset) {
148-
# TODO
148+
model[["original"]]@terms$terms <- model[["terms"]]
149+
logodds <- predict(model[["original"]], newdata = dataset)
150+
ncategories <- ncol(logodds) + 1
151+
probabilities <- matrix(0, nrow = nrow(logodds), ncol = ncategories)
152+
for (i in seq_len(ncategories - 1)) {
153+
probabilities[, i] <- exp(logodds[, i])
154+
}
155+
probabilities[, ncategories] <- 1
156+
row_sums <- rowSums(probabilities)
157+
probabilities <- probabilities / row_sums
158+
predicted_columns <- apply(probabilities, 1, which.max)
159+
as.character(levels(as.factor(model$target))[predicted_columns])
149160
}
150161

151162
# S3 method to make find out number of observations in training data
@@ -186,7 +197,7 @@ is.jaspMachineLearning <- function(x) {
186197
nrow(model[["data"]])
187198
}
188199
.mlPredictionGetTrainingN.vglm <- function(model) {
189-
nrow(model$x)
200+
nrow(model[["x"]])
190201
}
191202

192203
# S3 method to decode the model variables in the result object
@@ -253,7 +264,8 @@ is.jaspMachineLearning <- function(x) {
253264
return(model)
254265
}
255266
.decodeJaspMLobject.vglm <- function(model) {
256-
# TODO
267+
formula <- formula(paste(decodeColNames(strsplit(as.character(model$terms), " ")[[1]][1]), "~", paste0(decodeColNames(strsplit(strsplit(as.character(model$terms), split = " ~ ")[[1]][2], split = " + ", fixed = TRUE)[[1]]), collapse = " + ")))
268+
model$terms <- stats::terms(formula)
257269
return(model)
258270
}
259271

0 commit comments

Comments
 (0)