Skip to content

Commit 9f8a105

Browse files
committed
Class probabilities in prediction analysis
1 parent d0bdbe4 commit 9f8a105

File tree

3 files changed

+90
-33
lines changed

3 files changed

+90
-33
lines changed

R/mlPrediction.R

Lines changed: 79 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -88,74 +88,103 @@ is.jaspMachineLearning <- function(x) {
8888
}
8989
# Predictions from a k-nearest-neighbours model.
#
# model:   fitted JASP model; the kknn fit is read from model[["predictive"]].
# dataset: data to predict on.
#
# Returns a list. Element 1 holds the hard predictions (character for
# classification, numeric for regression). For classification, element 2 holds
# the output of the `type = "prob"` prediction (presumably one column per
# class -- TODO confirm against kknn). Returns NULL when the model carries
# neither the "jaspClassification" nor the "jaspRegression" class.
.mlPredictionGetPredictions.kknn <- function(model, dataset) {
  if (inherits(model, "jaspClassification")) {
    labels <- as.character(kknn:::predict.train.kknn(model[["predictive"]], dataset))
    probabilities <- kknn:::predict.train.kknn(model[["predictive"]], dataset, type = "prob")
    return(list(labels, probabilities))
  }
  if (inherits(model, "jaspRegression")) {
    values <- as.numeric(kknn:::predict.train.kknn(model[["predictive"]], dataset))
    return(list(values))
  }
}
9699
# Predictions from a linear discriminant analysis model.
#
# model:   fitted lda object.
# dataset: data to predict on.
#
# Returns a list: element 1 is the predicted class as character, element 2 is
# the `posterior` component of the lda prediction.
.mlPredictionGetPredictions.lda <- function(model, dataset) {
  # Predict once and reuse the result: class and posterior come from the same
  # call, so there is no need to run the discriminant functions twice.
  prediction <- MASS:::predict.lda(model, newdata = dataset)
  hard <- as.character(prediction$class)
  soft <- prediction$posterior
  return(list(hard, soft))
}
99104
# Predictions from a linear regression model.
#
# model:   fitted model that predict() can dispatch on.
# dataset: data to predict on.
#
# Returns a one-element list holding the numeric predictions.
.mlPredictionGetPredictions.lm <- function(model, dataset) {
  fittedValues <- predict(model, newdata = dataset)
  list(as.numeric(fittedValues))
}
102108
# Predictions from a boosting (gbm) model.
#
# model:   fitted gbm object; the number of trees is read from
#          model[["n.trees"]].
# dataset: data to predict on.
#
# Returns a list. Element 1 holds the hard predictions (character for
# classification, numeric for regression). For classification, element 2 is a
# matrix of the `type = "response"` predictions with the same column names as
# the prediction array. Returns NULL when the model carries neither the
# "jaspClassification" nor the "jaspRegression" class.
.mlPredictionGetPredictions.gbm <- function(model, dataset) {
  if (inherits(model, "jaspClassification")) {
    # The response is indexed with three subscripts (assumed to be
    # rows x classes x tree settings -- TODO confirm against gbm).
    response <- gbm:::predict.gbm(model, newdata = dataset, n.trees = model[["n.trees"]], type = "response")
    # Take the first slice but keep it a matrix: a plain `[, , 1]` collapses to
    # a vector when `dataset` has a single row, which loses the class column
    # names.
    firstSlice <- response[, , 1, drop = FALSE]
    soft <- matrix(
      firstSlice,
      nrow = dim(firstSlice)[1],
      ncol = dim(firstSlice)[2],
      dimnames = dimnames(firstSlice)[1:2]
    )
    hard <- as.character(colnames(soft)[apply(soft, 1, which.max)])
    return(list(hard, soft))
  } else if (inherits(model, "jaspRegression")) {
    hard <- as.numeric(gbm:::predict.gbm(model, newdata = dataset, n.trees = model[["n.trees"]], type = "response"))
    return(list(hard))
  }
}
110118
# Predictions from a random forest model.
#
# model:   fitted randomForest object.
# dataset: data to predict on.
#
# Returns a list. Element 1 holds the hard predictions (character for
# classification, numeric for regression). For classification, element 2 holds
# the `type = "prob"` predictions. Returns NULL when the model carries neither
# the "jaspClassification" nor the "jaspRegression" class.
.mlPredictionGetPredictions.randomForest <- function(model, dataset) {
  if (inherits(model, "jaspClassification")) {
    hard <- as.character(randomForest:::predict.randomForest(model, newdata = dataset))
    # Call the method explicitly, as for the hard predictions, instead of
    # relying on S3 dispatch through the extra JASP classes on the model.
    soft <- randomForest:::predict.randomForest(model, newdata = dataset, type = "prob")
    return(list(hard, soft))
  } else if (inherits(model, "jaspRegression")) {
    hard <- as.numeric(randomForest:::predict.randomForest(model, newdata = dataset))
    return(list(hard))
  }
}
117128
# Predictions from a cross-validated regularized regression (cv.glmnet) model.
#
# model:   fitted cv.glmnet object.
# dataset: data to predict on; converted with data.matrix() before prediction.
#
# Returns a one-element list holding the numeric predictions.
.mlPredictionGetPredictions.cv.glmnet <- function(model, dataset) {
  designMatrix <- data.matrix(dataset)
  fittedValues <- glmnet:::predict.cv.glmnet(model, newx = designMatrix)
  list(as.numeric(fittedValues))
}
120132
# Predictions from a neural network model.
#
# model:   fitted nn object; for classification the training data are read
#          from model[["data"]] and the target column name from
#          model[["jaspVars"]][["encoded"]]$target.
# dataset: data to predict on.
#
# Returns a list. Element 1 holds the hard predictions (class label with the
# largest network output for classification, numeric for regression). For
# classification, element 2 is the network output with columns named after
# the levels of the training target. Returns NULL when the model carries
# neither the "jaspClassification" nor the "jaspRegression" class.
.mlPredictionGetPredictions.nn <- function(model, dataset) {
  if (inherits(model, "jaspClassification")) {
    soft <- neuralnet:::predict.nn(model, newdata = dataset)
    targetName <- model[["jaspVars"]][["encoded"]]$target
    colnames(soft) <- levels(factor(model[["data"]][[targetName]]))
    # Per row, pick the first column holding the maximum.
    winners <- apply(soft, 1, which.max)
    return(list(colnames(soft)[winners], soft))
  }
  if (inherits(model, "jaspRegression")) {
    return(list(as.numeric(neuralnet:::predict.nn(model, newdata = dataset))))
  }
}
127143
# Predictions from a decision tree (rpart) model.
#
# model:   fitted rpart object; for classification the training data are read
#          from model[["data"]] and the target column name from
#          model[["jaspVars"]][["encoded"]]$target.
# dataset: data to predict on.
#
# Returns a list. Element 1 holds the hard predictions (class label with the
# largest predicted value for classification, numeric for regression). For
# classification, element 2 is the prediction matrix with columns named after
# the levels of the training target. Returns NULL when the model carries
# neither the "jaspClassification" nor the "jaspRegression" class.
.mlPredictionGetPredictions.rpart <- function(model, dataset) {
  if (inherits(model, "jaspClassification")) {
    soft <- rpart:::predict.rpart(model, newdata = dataset)
    targetName <- model[["jaspVars"]][["encoded"]]$target
    colnames(soft) <- levels(factor(model[["data"]][[targetName]]))
    # Per row, pick the first column holding the maximum.
    winners <- apply(soft, 1, which.max)
    return(list(colnames(soft)[winners], soft))
  }
  if (inherits(model, "jaspRegression")) {
    return(list(as.numeric(rpart:::predict.rpart(model, newdata = dataset))))
  }
}
134154
# Predictions from a support vector machine model.
#
# model:   fitted svm object.
# dataset: data to predict on.
#
# Returns a list. Element 1 holds the hard predictions (character for
# classification, numeric for regression). For classification, element 2 is
# the "probabilities" attribute of a prediction made with
# `probability = TRUE`. Returns NULL when the model carries neither the
# "jaspClassification" nor the "jaspRegression" class.
.mlPredictionGetPredictions.svm <- function(model, dataset) {
  if (inherits(model, "jaspClassification")) {
    withProbabilities <- e1071:::predict.svm(model, newdata = dataset, probability = TRUE)
    soft <- attr(withProbabilities, "probabilities")
    # The labels deliberately come from a separate, plain prediction call.
    labels <- as.character(e1071:::predict.svm(model, newdata = dataset))
    return(list(labels, soft))
  }
  if (inherits(model, "jaspRegression")) {
    return(list(as.numeric(e1071:::predict.svm(model, newdata = dataset))))
  }
}
141164
# Predictions from a naive Bayes classifier.
#
# model:   fitted naiveBayes object.
# dataset: data to predict on.
#
# Returns a list: element 1 is the `type = "class"` prediction as character,
# element 2 is the `type = "raw"` prediction.
.mlPredictionGetPredictions.naiveBayes <- function(model, dataset) {
  rawScores <- e1071:::predict.naiveBayes(model, newdata = dataset, type = "raw")
  labels <- e1071:::predict.naiveBayes(model, newdata = dataset, type = "class")
  list(as.character(labels), rawScores)
}
144169
# Predictions from a two-class glm (logistic regression) classifier.
#
# model:   fitted glm object; the target column name is read from
#          model[["jaspVars"]][["encoded"]]$target and its levels from the
#          model frame stored in model$model.
# dataset: data to predict on.
#
# Returns a list: element 1 is the predicted class label (character), element
# 2 is a two-column probability matrix with columns named after the target
# levels. The first column is 1 - p and the second is p, where p is the
# `type = "response"` prediction.
.mlPredictionGetPredictions.glm <- function(model, dataset) {
  pSecond <- predict(model, newdata = dataset, type = "response")
  pFirst <- 1 - pSecond
  soft <- matrix(c(pFirst, pSecond), ncol = 2)
  targetName <- model[["jaspVars"]][["encoded"]]$target
  colnames(soft) <- levels(as.factor(model$model[[targetName]]))
  # Per row, pick the first column holding the maximum (ties go to column 1).
  winners <- apply(soft, 1, which.max)
  list(colnames(soft)[winners], soft)
}
147176
# Predictions from a multinomial (vglm) classifier.
#
# model:   list-like object; the fitted model is read from
#          model[["original"]] and the training target from model$target.
# dataset: data to predict on.
#
# The prediction of the fitted model is treated as a matrix of log-odds with
# one column per non-reference category; the last category is the reference
# and gets a log-odds of 0. Probabilities are the softmax of these values.
#
# Returns a list: element 1 is the predicted class label (character), element
# 2 is the probability matrix with columns named after the levels of the
# training target.
.mlPredictionGetPredictions.vglm <- function(model, dataset) {
  logodds <- predict(model[["original"]], newdata = dataset)
  ncategories <- ncol(logodds) + 1
  # Linear predictors for every category; the reference column stays at 0.
  eta <- matrix(0, nrow = nrow(logodds), ncol = ncategories)
  eta[, seq_len(ncategories - 1)] <- logodds
  # Subtract the row maximum before exponentiating. This leaves the softmax
  # unchanged but stops exp() overflowing to Inf (and the probabilities
  # turning into NaN) for large log-odds. The vector is recycled down the
  # columns, so every row gets its own maximum subtracted.
  eta <- eta - apply(eta, 1, max)
  soft <- exp(eta)
  soft <- soft / rowSums(soft)
  colnames(soft) <- as.character(levels(as.factor(model$target)))
  hard <- colnames(soft)[apply(soft, 1, which.max)]
  return(list(hard, soft))
}
160189

161190
# S3 method to find out the number of observations in the training data
@@ -372,7 +401,7 @@ is.jaspMachineLearning <- function(x) {
372401
if (!ready) {
373402
return()
374403
}
375-
predictions <- .mlPredictionsState(model, dataset, options, jaspResults, ready)
404+
predictions <- .mlPredictionsState(model, dataset, options, jaspResults, ready)[[1]]
376405
indexes <- options[["fromIndex"]]:options[["toIndex"]]
377406
selection <- predictions[indexes]
378407
cols <- list(row = indexes, pred = selection)
@@ -397,12 +426,29 @@ is.jaspMachineLearning <- function(x) {
397426
}
398427

399428
# Export the predictions, and optionally the predicted class probabilities, as
# new columns.
#
# model:       trained model; its "jaspClassification" / "jaspRegression" class
#              decides the column type and whether probabilities are exported.
# dataset:     data the predictions are made for; its row names are used as
#              positions in the exported columns.
# options:     analysis options; reads "addPredictions", "predictionsColumn"
#              and "addProbabilities".
# jaspResults: results container that the new columns are stored in.
# ready:       logical; nothing is done unless TRUE.
#
# Called for its side effects on `jaspResults`.
.mlPredictionsAddPredictions <- function(model, dataset, options, jaspResults, ready) {
  if (options[["addPredictions"]] && options[["predictionsColumn"]] != "" && ready) {
    predictions <- .mlPredictionsState(model, dataset, options, jaspResults, ready)
    rowIndices <- as.numeric(rownames(dataset))
    # Place values at the positions given by the row names of `dataset`;
    # positions that are not present in `dataset` stay NA.
    alignToRows <- function(values) {
      column <- rep(NA, max(rowIndices))
      column[rowIndices] <- values
      column
    }
    # Add hard predictions for regression and classification
    if (is.null(jaspResults[["predictionsColumn"]])) {
      predictionsColumn <- alignToRows(predictions[[1]])
      jaspResults[["predictionsColumn"]] <- createJaspColumn(columnName = options[["predictionsColumn"]])
      jaspResults[["predictionsColumn"]]$dependOn(options = c("predictionsColumn", "predictors", "trainedModelFilePath", "scaleVariables", "addPredictions"))
      if (inherits(model, "jaspClassification")) jaspResults[["predictionsColumn"]]$setNominal(predictionsColumn)
      if (inherits(model, "jaspRegression")) jaspResults[["predictionsColumn"]]$setScale(predictionsColumn)
    }
    # Add predicted probabilities for classification only. isTRUE() keeps the
    # condition valid when the option is absent (NULL).
    if (inherits(model, "jaspClassification") && isTRUE(options[["addProbabilities"]])) {
      classNames <- colnames(predictions[[2]])
      for (i in seq_along(classNames)) {
        colName <- paste0(decodeColNames(options[["predictionsColumn"]]), "_", classNames[i])
        # Skip only the columns that already exist, so that the remaining
        # classes still get their column.
        if (!is.null(jaspResults[[colName]])) {
          next
        }
        jaspResults[[colName]] <- createJaspColumn(columnName = colName)
        jaspResults[[colName]]$dependOn(options = c("predictionsColumn", "predictors", "trainedModelFilePath", "scaleVariables", "addPredictions", "addProbabilities"))
        # Align with the row names exactly like the hard predictions, so the
        # probabilities end up next to the rows they belong to.
        jaspResults[[colName]]$setScale(alignToRows(predictions[[2]][, i]))
      }
    }
  }
}

inst/qml/common/ui/ExportResults.qml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ Group
2424
{
2525
property alias enabled: exportSection.enabled
2626
property alias showSave: saveGroup.visible
27+
property bool showProbs: false
2728

2829
id: exportSection
2930
title: qsTr("Export Results")
@@ -45,6 +46,15 @@ Group
4546
enabled: addPredictions.checked
4647
info: qsTr("The column name for the predicted values.")
4748
}
49+
50+
CheckBox
51+
{
52+
id: probabilities
53+
name: "addProbabilities"
54+
text: qsTr("Add probabilities (classification only)")
55+
visible: showProbs
56+
info: qsTr("In classification analyses, also add the predicted probabilities for each class to the data.")
57+
}
4858
}
4959

5060
Group

inst/qml/mlPrediction.qml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,5 +116,6 @@ Form
116116
UI.ExportResults {
117117
enabled: predictors.count > 1
118118
showSave: false
119+
showProbs: true
119120
}
120121
}

0 commit comments

Comments
 (0)