@@ -88,17 +88,17 @@ is.jaspMachineLearning <- function(x) {
8888}
8989.mlPredictionGetPredictions.kknn  <-  function (model , dataset ) {
9090  if  (inherits(model , " jaspClassification"  )) {
91-     hard  <-  as.character(kknn ::: predict.train.kknn(model [[" predictive"  ]], dataset ))
9291    soft  <-  kknn ::: predict.train.kknn(model [[" predictive"  ]], dataset , type  =  " prob"  )
92+     hard  <-  colnames(soft )[max.col(soft , ties.method  =  " random"  )]
9393    return (list (hard , soft ))
9494  } else  if  (inherits(model , " jaspRegression"  )) {
9595    hard  <-  as.numeric(kknn ::: predict.train.kknn(model [[" predictive"  ]], dataset ))
9696    return (list (hard ))
9797  }
9898}
9999.mlPredictionGetPredictions.lda  <-  function (model , dataset ) {
100-   hard  <-  as.character(MASS ::: predict.lda(model , newdata  =  dataset )$ class )
101100  soft  <-  MASS ::: predict.lda(model , newdata  =  dataset )$ posterior 
101+   hard  <-  colnames(soft )[max.col(soft , ties.method  =  " random"  )]
102102  return (list (hard , soft ))
103103}
104104.mlPredictionGetPredictions.lm  <-  function (model , dataset ) {
@@ -107,18 +107,18 @@ is.jaspMachineLearning <- function(x) {
107107}
108108.mlPredictionGetPredictions.gbm  <-  function (model , dataset ) {
109109  if  (inherits(model , " jaspClassification"  )) {
110-     soft  <-  gbm ::: predict.gbm(model , newdata  =  dataset , n.trees  =  model [[" n.trees"  ]], type  =  " response"  )
111-     hard  <-  as.character( colnames(soft )[apply (soft , 1 ,  which.max )]) 
112-     return (list (hard , soft [, ,  1 ] ))
110+     soft  <-  gbm ::: predict.gbm(model , newdata  =  dataset , n.trees  =  model [[" n.trees"  ]], type  =  " response"  )[, ,  1 ] 
111+     hard  <-  colnames(soft )[max.col (soft , ties.method   =   " random " )] 
112+     return (list (hard , soft ))
113113  } else  if  (inherits(model , " jaspRegression"  )) {
114114    hard  <-  as.numeric(gbm ::: predict.gbm(model , newdata  =  dataset , n.trees  =  model [[" n.trees"  ]], type  =  " response"  ))
115115    return (list (hard ))
116116  }
117117}
118118.mlPredictionGetPredictions.randomForest  <-  function (model , dataset ) {
119119  if  (inherits(model , " jaspClassification"  )) {
120-     hard  <-  as.character(randomForest ::: predict.randomForest(model , newdata  =  dataset ))
121120    soft  <-  predict(model , newdata  =  dataset , type  =  " prob"  )
121+     hard  <-  colnames(soft )[max.col(soft , ties.method  =  " random"  )]
122122    return (list (hard , soft ))
123123  } else  if  (inherits(model , " jaspRegression"  )) {
124124    hard  <-  as.numeric(randomForest ::: predict.randomForest(model , newdata  =  dataset ))
@@ -133,7 +133,7 @@ is.jaspMachineLearning <- function(x) {
133133  if  (inherits(model , " jaspClassification"  )) {
134134    soft  <-  neuralnet ::: predict.nn(model , newdata  =  dataset )
135135    colnames(soft ) <-  levels(factor (model [[" data"  ]][[model [[" jaspVars"  ]][[" encoded"  ]]$ target ]]))
136-     hard  <-  colnames(soft )[apply (soft , 1 ,  which.max )]
136+     hard  <-  colnames(soft )[max.col (soft , ties.method   =   " random "  )]
137137    return (list (hard , soft ))
138138  } else  if  (inherits(model , " jaspRegression"  )) {
139139    hard  <-  as.numeric(neuralnet ::: predict.nn(model , newdata  =  dataset ))
@@ -144,7 +144,7 @@ is.jaspMachineLearning <- function(x) {
144144  if  (inherits(model , " jaspClassification"  )) {
145145    soft  <-  rpart ::: predict.rpart(model , newdata  =  dataset )
146146    colnames(soft ) <-  levels(factor (model [[" data"  ]][[model [[" jaspVars"  ]][[" encoded"  ]]$ target ]]))
147-     hard  <-  colnames(soft )[apply (soft , 1 ,  which.max )]
147+     hard  <-  colnames(soft )[max.col (soft , ties.method   =   " random "  )]
148148    return (list (hard , soft ))
149149  } else  if  (inherits(model , " jaspRegression"  )) {
150150    hard  <-  as.numeric(rpart ::: predict.rpart(model , newdata  =  dataset ))
@@ -154,7 +154,7 @@ is.jaspMachineLearning <- function(x) {
154154.mlPredictionGetPredictions.svm  <-  function (model , dataset ) {
155155  if  (inherits(model , " jaspClassification"  )) {
156156    soft  <-  attr(e1071 ::: predict.svm(model , newdata  =  dataset , probability  =  TRUE ), " probabilities"  )
157-     hard  <-  as.character( e1071 ::: predict.svm( model ,  newdata  =  dataset )) 
157+     hard  <-  colnames( soft )[max.col( soft ,  ties.method  =  " random " )] 
158158    return (list (hard , soft ))
159159  } else  if  (inherits(model , " jaspRegression"  )) {
160160    hard  <-  as.numeric(e1071 ::: predict.svm(model , newdata  =  dataset ))
@@ -163,14 +163,14 @@ is.jaspMachineLearning <- function(x) {
163163}
164164.mlPredictionGetPredictions.naiveBayes  <-  function (model , dataset ) {
165165  soft  <-  e1071 ::: predict.naiveBayes(model , newdata  =  dataset , type  =  " raw"  )
166-   hard  <-  as.character( e1071 ::: predict.naiveBayes( model ,  newdata  =  dataset ,  type   =   " class " )) 
166+   hard  <-  colnames( soft )[max.col( soft ,  ties.method  =  " random " )] 
167167  return (list (hard , soft ))
168168}
169169.mlPredictionGetPredictions.glm  <-  function (model , dataset ) {
170170  probs  <-  predict(model , newdata  =  dataset , type  =  " response"  )
171171  soft  <-  matrix (c(1  -  probs , probs ), ncol  =  2 )
172172  colnames(soft ) <-  levels(as.factor(model $ model [[model [[" jaspVars"  ]][[" encoded"  ]]$ target ]]))
173-   hard  <-  colnames(soft )[apply (soft , 1 ,  which.max )]
173+   hard  <-  colnames(soft )[max.col (soft , ties.method   =   " random "  )]
174174  return (list (hard , soft ))
175175}
176176.mlPredictionGetPredictions.vglm  <-  function (model , dataset ) {
@@ -183,7 +183,7 @@ is.jaspMachineLearning <- function(x) {
183183  soft [, ncategories ] <-  1 
184184  soft  <-  soft  /  rowSums(soft )
185185  colnames(soft ) <-  as.character(levels(as.factor(model $ target )))
186-   hard  <-  colnames(soft )[apply (soft , 1 ,  which.max )]
186+   hard  <-  colnames(soft )[max.col (soft , ties.method   =   " random "  )]
187187  return (list (hard , soft ))
188188}
189189
0 commit comments