@@ -75,6 +75,12 @@ is.jaspMachineLearning <- function(x) {
75
75
# S3 method: human-readable label for a naive Bayes classification model.
# The label depends only on the model class; `model` itself is not inspected.
.mlPredictionGetModelType.naiveBayes <- function(model) {
  gettext("Naive Bayes")
}
78
# S3 method: human-readable label for a logistic regression (glm) model.
# The label depends only on the model class; `model` itself is not inspected.
.mlPredictionGetModelType.glm <- function(model) {
  gettext("Logistic regression")
}
81
# S3 method: human-readable label for a multinomial regression (VGAM::vglm) model.
# The label depends only on the model class; `model` itself is not inspected.
.mlPredictionGetModelType.vglm <- function(model) {
  gettext("Multinomial regression")
}
78
84
79
85
# S3 method to make predictions using the model
80
86
.mlPredictionGetPredictions <- function (model , dataset ) {
@@ -135,6 +141,12 @@ is.jaspMachineLearning <- function(x) {
135
141
# S3 method: class predictions for a naive Bayes model on new data.
# Returns a character vector with one predicted class label per row of `dataset`.
.mlPredictionGetPredictions.naiveBayes <- function(model, dataset) {
  # predict.naiveBayes is not exported from e1071, hence the `:::` access.
  predictions <- e1071:::predict.naiveBayes(model, newdata = dataset, type = "class")
  as.character(predictions)
}
144
# S3 method: class predictions for a binomial logistic regression (glm) model.
#
# Replaces the previous `# TODO` stub, which returned NULL and therefore
# produced no predictions at all.
#
# Returns a character vector with one predicted class label per row of
# `dataset`, classifying at the conventional 0.5 probability threshold.
.mlPredictionGetPredictions.glm <- function(model, dataset) {
  # Predicted probability of the *second* response level (glm's success level).
  probs <- predict(model, newdata = dataset, type = "response")
  # Recover the response factor levels from the stored model frame (the
  # response is its first column); fall back to "0"/"1" when the response
  # was supplied as a numeric 0/1 vector rather than a factor.
  responseLevels <- levels(model[["model"]][[1L]])
  if (is.null(responseLevels)) {
    responseLevels <- c("0", "1")
  }
  as.character(responseLevels[(probs > 0.5) + 1L])
}
147
# S3 method: class predictions for a multinomial regression (VGAM::vglm) model.
#
# Replaces the previous `# TODO` stub, which returned NULL and therefore
# produced no predictions at all.
#
# Returns a character vector with one predicted class label per row of
# `dataset`.
.mlPredictionGetPredictions.vglm <- function(model, dataset) {
  # vglm objects are S4; predictvglm(type = "response") returns an n x K
  # matrix of per-class probabilities whose column names are the class labels.
  probs <- VGAM::predictvglm(model, newdata = dataset, type = "response")
  # Predicted class = the response level with the highest fitted probability.
  as.character(colnames(probs)[max.col(probs, ties.method = "first")])
}
138
150
139
151
# S3 method to make find out number of observations in training data
140
152
.mlPredictionGetTrainingN <- function (model ) {
@@ -170,6 +182,12 @@ is.jaspMachineLearning <- function(x) {
170
182
# S3 method: number of observations the naive Bayes model was trained on.
.mlPredictionGetTrainingN.naiveBayes <- function(model) {
  # The JASP-saved naiveBayes object carries its training data in `data`.
  trainingData <- model[["data"]]
  nrow(trainingData)
}
185
# S3 method: number of observations the logistic regression was trained on.
#
# Fix: `model$data` is only a data frame when glm() was called with a `data`
# argument; otherwise it is the calling *environment*, and nrow() on an
# environment returns NULL. Fall back to the stored model frame (`model$model`,
# kept by default since glm(model = TRUE)) in that case.
.mlPredictionGetTrainingN.glm <- function(model) {
  if (is.data.frame(model[["data"]])) {
    nrow(model[["data"]])
  } else {
    nrow(model[["model"]])
  }
}
188
# S3 method: number of observations the multinomial regression was trained on.
#
# Fix: vglm (VGAM) objects are S4, so slots must be accessed with `@`, not `$`
# ("$ operator is invalid for S4 objects" at runtime). The `x` slot holds the
# model matrix, with one row per training observation.
# NOTE(review): requires the model to have been fitted with x.arg = TRUE so
# that the `x` slot is populated — confirm against where the model is saved.
.mlPredictionGetTrainingN.vglm <- function(model) {
  nrow(model@x)
}
173
191
174
192
# S3 method to decode the model variables in the result object
175
193
# so that they can be matched to variables in the prediction analysis
@@ -229,6 +247,14 @@ is.jaspMachineLearning <- function(x) {
229
247
names(model [[" tables" ]]) <- decodeColNames(names(model [[" tables" ]]))
230
248
return (model )
231
249
}
250
# S3 method: decode the JASP-encoded variable names inside a glm fit so they
# can be matched to variables in the prediction analysis.
.decodeJaspMLobject.glm <- function(model) {
  # TODO: decoding is not implemented yet — the model is returned unchanged,
  # so encoded names will not match the prediction dataset's variables.
  # Presumably the names embedded in the fit (e.g. terms/formula, coefficient
  # and model-frame names) need decodeColNames(), mirroring the naiveBayes
  # method — confirm against the decodeColNames contract before implementing.
  return(model)
}
254
# S3 method: decode the JASP-encoded variable names inside a vglm fit so they
# can be matched to variables in the prediction analysis.
.decodeJaspMLobject.vglm <- function(model) {
  # TODO: decoding is not implemented yet — the model is returned unchanged,
  # so encoded names will not match the prediction dataset's variables.
  # NOTE(review): vglm objects are S4; any name decoding will need to touch
  # slots via `@` (e.g. the model matrix / terms slots), not list components —
  # verify against the VGAM class layout before implementing.
  return(model)
}
232
258
233
259
.mlPredictionReadModel <- function (options ) {
234
260
if (options [[" trainedModelFilePath" ]] != " " ) {
@@ -238,7 +264,7 @@ is.jaspMachineLearning <- function(x) {
238
264
if (! is.jaspMachineLearning(model )) {
239
265
jaspBase ::: .quitAnalysis(gettext(" Error: The trained model is not created in JASP." ))
240
266
}
241
- if (! (any(c(" kknn" , " lda" , " gbm" , " randomForest" , " cv.glmnet" , " nn" , " rpart" , " svm" , " lm" , " naiveBayes" ) %in% class(model )))) {
267
+ if (! (any(c(" kknn" , " lda" , " gbm" , " randomForest" , " cv.glmnet" , " nn" , " rpart" , " svm" , " lm" , " naiveBayes" , " glm " , " vglm " ) %in% class(model )))) {
242
268
jaspBase ::: .quitAnalysis(gettextf(" The trained model (type: %1$s) is currently not supported in JASP." , paste(class(model ), collapse = " , " )))
243
269
}
244
270
if (model [[" jaspVersion" ]] != .baseCitation ) {
@@ -326,6 +352,8 @@ is.jaspMachineLearning <- function(x) {
326
352
table $ addColumnInfo(name = " mtry" , title = gettext(" Features per split" ), type = " integer" )
327
353
} else if (inherits(model , " cv.glmnet" )) {
328
354
table $ addColumnInfo(name = " lambda" , title = " \u 03BB" , type = " number" )
355
+ } else if (inherits(model , " glm" ) || inherits(model , " vglm" )) {
356
+ table $ addColumnInfo(name = " family" , title = gettext(" Family" ), type = " string" )
329
357
}
330
358
table $ addColumnInfo(name = " ntrain" , title = gettext(" n(Train)" ), type = " integer" )
331
359
table $ addColumnInfo(name = " nnew" , title = gettext(" n(New)" ), type = " integer" )
@@ -344,6 +372,10 @@ is.jaspMachineLearning <- function(x) {
344
372
row [[" mtry" ]] <- model [[" mtry" ]]
345
373
} else if (inherits(model , " cv.glmnet" )) {
346
374
row [[" lambda" ]] <- model [[" lambda.min" ]]
375
+ } else if (inherits(model , " glm" )) {
376
+ row [[" family" ]] <- gettext(" binomial" )
377
+ } else if (inherits(model , " vglm" )) {
378
+ row [[" family" ]] <- gettext(" multinomial" )
347
379
}
348
380
if (length(presentVars ) > 0 ) {
349
381
row [[" nnew" ]] <- nrow(dataset )
0 commit comments