Skip to content

Commit 874330b

Browse files
committed
More explicit tie breaking
1 parent 9f8a105 commit 874330b

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

R/mlPrediction.R

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -88,17 +88,17 @@ is.jaspMachineLearning <- function(x) {
8888
}
8989
.mlPredictionGetPredictions.kknn <- function(model, dataset) {
9090
if (inherits(model, "jaspClassification")) {
91-
hard <- as.character(kknn:::predict.train.kknn(model[["predictive"]], dataset))
9291
soft <- kknn:::predict.train.kknn(model[["predictive"]], dataset, type = "prob")
92+
hard <- colnames(soft)[max.col(soft, ties.method = "random")]
9393
return(list(hard, soft))
9494
} else if (inherits(model, "jaspRegression")) {
9595
hard <- as.numeric(kknn:::predict.train.kknn(model[["predictive"]], dataset))
9696
return(list(hard))
9797
}
9898
}
9999
.mlPredictionGetPredictions.lda <- function(model, dataset) {
100-
hard <- as.character(MASS:::predict.lda(model, newdata = dataset)$class)
101100
soft <- MASS:::predict.lda(model, newdata = dataset)$posterior
101+
hard <- colnames(soft)[max.col(soft, ties.method = "random")]
102102
return(list(hard, soft))
103103
}
104104
.mlPredictionGetPredictions.lm <- function(model, dataset) {
@@ -107,18 +107,18 @@ is.jaspMachineLearning <- function(x) {
107107
}
108108
.mlPredictionGetPredictions.gbm <- function(model, dataset) {
109109
if (inherits(model, "jaspClassification")) {
110-
soft <- gbm:::predict.gbm(model, newdata = dataset, n.trees = model[["n.trees"]], type = "response")
111-
hard <- as.character(colnames(soft)[apply(soft, 1, which.max)])
112-
return(list(hard, soft[, , 1]))
110+
soft <- gbm:::predict.gbm(model, newdata = dataset, n.trees = model[["n.trees"]], type = "response")[, , 1]
111+
hard <- colnames(soft)[max.col(soft, ties.method = "random")]
112+
return(list(hard, soft))
113113
} else if (inherits(model, "jaspRegression")) {
114114
hard <- as.numeric(gbm:::predict.gbm(model, newdata = dataset, n.trees = model[["n.trees"]], type = "response"))
115115
return(list(hard))
116116
}
117117
}
118118
.mlPredictionGetPredictions.randomForest <- function(model, dataset) {
119119
if (inherits(model, "jaspClassification")) {
120-
hard <- as.character(randomForest:::predict.randomForest(model, newdata = dataset))
121120
soft <- predict(model, newdata = dataset, type = "prob")
121+
hard <- colnames(soft)[max.col(soft, ties.method = "random")]
122122
return(list(hard, soft))
123123
} else if (inherits(model, "jaspRegression")) {
124124
hard <- as.numeric(randomForest:::predict.randomForest(model, newdata = dataset))
@@ -133,7 +133,7 @@ is.jaspMachineLearning <- function(x) {
133133
if (inherits(model, "jaspClassification")) {
134134
soft <- neuralnet:::predict.nn(model, newdata = dataset)
135135
colnames(soft) <- levels(factor(model[["data"]][[model[["jaspVars"]][["encoded"]]$target]]))
136-
hard <- colnames(soft)[apply(soft, 1, which.max)]
136+
hard <- colnames(soft)[max.col(soft, ties.method = "random")]
137137
return(list(hard, soft))
138138
} else if (inherits(model, "jaspRegression")) {
139139
hard <- as.numeric(neuralnet:::predict.nn(model, newdata = dataset))
@@ -144,7 +144,7 @@ is.jaspMachineLearning <- function(x) {
144144
if (inherits(model, "jaspClassification")) {
145145
soft <- rpart:::predict.rpart(model, newdata = dataset)
146146
colnames(soft) <- levels(factor(model[["data"]][[model[["jaspVars"]][["encoded"]]$target]]))
147-
hard <- colnames(soft)[apply(soft, 1, which.max)]
147+
hard <- colnames(soft)[max.col(soft, ties.method = "random")]
148148
return(list(hard, soft))
149149
} else if (inherits(model, "jaspRegression")) {
150150
hard <- as.numeric(rpart:::predict.rpart(model, newdata = dataset))
@@ -154,7 +154,7 @@ is.jaspMachineLearning <- function(x) {
154154
.mlPredictionGetPredictions.svm <- function(model, dataset) {
155155
if (inherits(model, "jaspClassification")) {
156156
soft <- attr(e1071:::predict.svm(model, newdata = dataset, probability = TRUE), "probabilities")
157-
hard <- as.character(e1071:::predict.svm(model, newdata = dataset))
157+
hard <- colnames(soft)[max.col(soft, ties.method = "random")]
158158
return(list(hard, soft))
159159
} else if (inherits(model, "jaspRegression")) {
160160
hard <- as.numeric(e1071:::predict.svm(model, newdata = dataset))
@@ -163,14 +163,14 @@ is.jaspMachineLearning <- function(x) {
163163
}
164164
.mlPredictionGetPredictions.naiveBayes <- function(model, dataset) {
165165
soft <- e1071:::predict.naiveBayes(model, newdata = dataset, type = "raw")
166-
hard <- as.character(e1071:::predict.naiveBayes(model, newdata = dataset, type = "class"))
166+
hard <- colnames(soft)[max.col(soft, ties.method = "random")]
167167
return(list(hard, soft))
168168
}
169169
.mlPredictionGetPredictions.glm <- function(model, dataset) {
170170
probs <- predict(model, newdata = dataset, type = "response")
171171
soft <- matrix(c(1 - probs, probs), ncol = 2)
172172
colnames(soft) <- levels(as.factor(model$model[[model[["jaspVars"]][["encoded"]]$target]]))
173-
hard <- colnames(soft)[apply(soft, 1, which.max)]
173+
hard <- colnames(soft)[max.col(soft, ties.method = "random")]
174174
return(list(hard, soft))
175175
}
176176
.mlPredictionGetPredictions.vglm <- function(model, dataset) {
@@ -183,7 +183,7 @@ is.jaspMachineLearning <- function(x) {
183183
soft[, ncategories] <- 1
184184
soft <- soft / rowSums(soft)
185185
colnames(soft) <- as.character(levels(as.factor(model$target)))
186-
hard <- colnames(soft)[apply(soft, 1, which.max)]
186+
hard <- colnames(soft)[max.col(soft, ties.method = "random")]
187187
return(list(hard, soft))
188188
}
189189

0 commit comments

Comments
 (0)