Skip to content

Commit 09b2b45

Browse files
committed
Start with prediction
1 parent ef552ae commit 09b2b45

File tree

2 files changed

+43
-3
lines changed

2 files changed

+43
-3
lines changed

R/mlClassificationLogisticMultinomial.R

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,15 @@ mlClassificationLogisticMultinomial <- function(jaspResults, dataset, options, .
103103
result <- list()
104104
result[["formula"]] <- formula
105105
result[["family"]] <- family
106-
result[["model"]] <- trainingFit
106+
if (family == "binomial") {
107+
result[["model"]] <- trainingFit
108+
} else {
109+
model <- lapply(slotNames(trainingFit), function(x) slot(trainingFit, x))
110+
names(model) <- slotNames(trainingFit)
111+
model[["original"]] <- trainingFit
112+
class(model) <- "vglm"
113+
result[["model"]] <- model
114+
}
107115
result[["confTable"]] <- table("Pred" = testPredictions, "Real" = testSet[, options[["target"]]])
108116
result[["testAcc"]] <- sum(diag(prop.table(result[["confTable"]])))
109117
result[["auc"]] <- .classificationCalcAUC(testSet, trainingSet, options, "logisticClassification")
@@ -171,7 +179,7 @@ mlClassificationLogisticMultinomial <- function(jaspResults, dataset, options, .
171179
}
172180
rownames(coefs) <- vars
173181
} else {
174-
coefs <- cbind(model@coefficients, confint(model, level = options[["coefTableConfIntLevel"]]))
182+
coefs <- cbind(model$coefficients, confint(model[["original"]], level = options[["coefTableConfIntLevel"]]))
175183
colnames(coefs) <- c("est", "lower", "upper")
176184
vars <- rownames(coefs)
177185
for (i in seq_along(vars)) {

R/mlPrediction.R

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,12 @@ is.jaspMachineLearning <- function(x) {
7575
.mlPredictionGetModelType.naiveBayes <- function(model) {
7676
gettext("Naive Bayes")
7777
}
78+
.mlPredictionGetModelType.glm <- function(model) {
79+
gettext("Logistic regression")
80+
}
81+
.mlPredictionGetModelType.vglm <- function(model) {
82+
gettext("Multinomial regression")
83+
}
7884

7985
# S3 method to make predictions using the model
8086
.mlPredictionGetPredictions <- function(model, dataset) {
@@ -135,6 +141,12 @@ is.jaspMachineLearning <- function(x) {
135141
.mlPredictionGetPredictions.naiveBayes <- function(model, dataset) {
136142
as.character(e1071:::predict.naiveBayes(model, newdata = dataset, type = "class"))
137143
}
144+
.mlPredictionGetPredictions.glm <- function(model, dataset) {
145+
# TODO
146+
}
147+
.mlPredictionGetPredictions.vglm <- function(model, dataset) {
148+
# TODO
149+
}
138150

139151
# S3 method to make find out number of observations in training data
140152
.mlPredictionGetTrainingN <- function(model) {
@@ -170,6 +182,12 @@ is.jaspMachineLearning <- function(x) {
170182
.mlPredictionGetTrainingN.naiveBayes <- function(model) {
171183
nrow(model[["data"]])
172184
}
185+
.mlPredictionGetTrainingN.glm <- function(model) {
186+
nrow(model[["data"]])
187+
}
188+
.mlPredictionGetTrainingN.vglm <- function(model) {
189+
nrow(model$x)
190+
}
173191

174192
# S3 method to decode the model variables in the result object
175193
# so that they can be matched to variables in the prediction analysis
@@ -229,6 +247,14 @@ is.jaspMachineLearning <- function(x) {
229247
names(model[["tables"]]) <- decodeColNames(names(model[["tables"]]))
230248
return(model)
231249
}
250+
.decodeJaspMLobject.glm <- function(model) {
251+
# TODO
252+
return(model)
253+
}
254+
.decodeJaspMLobject.vglm <- function(model) {
255+
# TODO
256+
return(model)
257+
}
232258

233259
.mlPredictionReadModel <- function(options) {
234260
if (options[["trainedModelFilePath"]] != "") {
@@ -238,7 +264,7 @@ is.jaspMachineLearning <- function(x) {
238264
if (!is.jaspMachineLearning(model)) {
239265
jaspBase:::.quitAnalysis(gettext("Error: The trained model is not created in JASP."))
240266
}
241-
if (!(any(c("kknn", "lda", "gbm", "randomForest", "cv.glmnet", "nn", "rpart", "svm", "lm", "naiveBayes") %in% class(model)))) {
267+
if (!(any(c("kknn", "lda", "gbm", "randomForest", "cv.glmnet", "nn", "rpart", "svm", "lm", "naiveBayes", "glm", "vglm") %in% class(model)))) {
242268
jaspBase:::.quitAnalysis(gettextf("The trained model (type: %1$s) is currently not supported in JASP.", paste(class(model), collapse = ", ")))
243269
}
244270
if (model[["jaspVersion"]] != .baseCitation) {
@@ -326,6 +352,8 @@ is.jaspMachineLearning <- function(x) {
326352
table$addColumnInfo(name = "mtry", title = gettext("Features per split"), type = "integer")
327353
} else if (inherits(model, "cv.glmnet")) {
328354
table$addColumnInfo(name = "lambda", title = "\u03BB", type = "number")
355+
} else if (inherits(model, "glm") || inherits(model, "vglm")) {
356+
table$addColumnInfo(name = "family", title = gettext("Family"), type = "string")
329357
}
330358
table$addColumnInfo(name = "ntrain", title = gettext("n(Train)"), type = "integer")
331359
table$addColumnInfo(name = "nnew", title = gettext("n(New)"), type = "integer")
@@ -344,6 +372,10 @@ is.jaspMachineLearning <- function(x) {
344372
row[["mtry"]] <- model[["mtry"]]
345373
} else if (inherits(model, "cv.glmnet")) {
346374
row[["lambda"]] <- model[["lambda.min"]]
375+
} else if (inherits(model, "glm")) {
376+
row[["family"]] <- gettext("binomial")
377+
} else if (inherits(model, "vglm")) {
378+
row[["family"]] <- gettext("multinomial")
347379
}
348380
if (length(presentVars) > 0) {
349381
row[["nnew"]] <- nrow(dataset)

0 commit comments

Comments
 (0)