Skip to content

Commit 587ecd4

Browse files
committed
Update
1 parent c298a2d commit 587ecd4

File tree

4 files changed

+21
-6
lines changed

4 files changed

+21
-6
lines changed

R/commonMachineLearningClassification.R

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@
152152
table$addColumnInfo(name = "smoothing", title = gettext("Smoothing"), type = "number")
153153
} else if (type == "logistic") {
154154
table$addColumnInfo(name = "family", title = gettext("Family"), type = "string")
155+
table$addColumnInfo(name = "link", title = gettext("Link"), type = "string")
155156
}
156157
# Add common columns
157158
table$addColumnInfo(name = "nTrain", title = gettext("n(Train)"), type = "integer")
@@ -324,8 +325,10 @@
324325
table$title <- gettext("Model Summary: Multinomial Regression Classification")
325326
}
326327
family <- classificationResult[["family"]]
328+
link <- classificationResult[["link"]]
327329
row <- data.frame(
328330
family = paste0(toupper(substr(family, 1, 1)), substr(family, 2, nchar(family))),
331+
link = paste0(toupper(substr(link, 1, 1)), substr(link, 2, nchar(link))),
329332
nTrain = nTrain,
330333
nTest = classificationResult[["ntest"]],
331334
testAcc = classificationResult[["testAcc"]]
@@ -585,11 +588,11 @@
585588
levels(predictions) <- unique(dataset[, options[["target"]]])
586589
} else if (type == "logistic") {
587590
if (classificationResult[["family"]] == "binomial") {
588-
fit <- glm(formula, data = dataset, family = "binomial")
591+
fit <- glm(formula, data = dataset, family = stats::binomial(link = "logit"))
589592
predictions <- as.factor(round(predict(fit, grid, type = "response"), 0))
590593
levels(predictions) <- unique(dataset[, options[["target"]]])
591594
} else {
592-
fit <- VGAM::vglm(formula, data = dataset, family = "multinomial")
595+
fit <- VGAM::vglm(formula, data = dataset, family = VGAM::multinomial())
593596
logodds <- predict(fit, newdata = grid)
594597
ncategories <- ncol(logodds) + 1
595598
probabilities <- matrix(0, nrow = nrow(logodds), ncol = ncategories)
@@ -743,7 +746,7 @@
743746
fit <- e1071::naiveBayes(formula = formula, data = typeData, laplace = options[["smoothingParameter"]])
744747
score <- max.col(predict(fit, test, type = "raw"))
745748
} else if (type == "logistic") {
746-
fit <- glm(formula, data = typeData, family = "binomial")
749+
fit <- glm(formula, data = typeData, family = stats::binomial(link = "logit"))
747750
score <- round(predict(fit, test, type = "response"), 0)
748751
}
749752
pred <- ROCR::prediction(score, actual.class)
@@ -1164,7 +1167,7 @@
11641167
}
11651168

11661169
.calcAUCScore.logisticClassification <- function(AUCformula, test, typeData, options, jaspResults, ...) {
1167-
fit <- glm(AUCformula, data = typeData, family = "binomial")
1170+
fit <- glm(AUCformula, data = typeData, family = stats::binomial(link = "logit"))
11681171
score <- round(predict(fit, test, type = "response"), 0)
11691172
return(score)
11701173
}

R/mlClassificationLogisticMultinomial.R

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,14 @@ mlClassificationLogisticMultinomial <- function(jaspResults, dataset, options, .
8888
}
8989
if (nlevels(trainingSet[[options[["target"]]]]) == 2) {
9090
family = "binomial"
91-
trainingFit <- glm(formula, data = trainingSet, family = stats::binomial(link = "logit"))
91+
linkFunction <- "logit"
92+
trainingFit <- glm(formula, data = trainingSet, family = stats::binomial(link = linkFunction))
9293
# Use the specified model to make predictions for dataset
9394
testPredictions <- .mlClassificationLogisticPredictions(trainingSet, options, predict(trainingFit, newdata = testSet, type = "response"))
9495
dataPredictions <- .mlClassificationLogisticPredictions(trainingSet, options, predict(trainingFit, newdata = dataset, type = "response"))
9596
} else {
9697
family <- "multinomial"
98+
linkFunction <- "logit"
9799
trainingFit <- VGAM::vglm(formula, data = trainingSet, family = VGAM::multinomial())
98100
# Use the specified model to make predictions for dataset
99101
testPredictions <- .mlClassificationMultinomialPredictions(trainingSet, options, predict(trainingFit, newdata = testSet))
@@ -103,6 +105,7 @@ mlClassificationLogisticMultinomial <- function(jaspResults, dataset, options, .
103105
result <- list()
104106
result[["formula"]] <- formula
105107
result[["family"]] <- family
108+
result[["link"]] <- linkFunction
106109
if (family == "binomial") {
107110
result[["model"]] <- trainingFit
108111
} else {

R/mlPrediction.R

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,9 @@ is.jaspMachineLearning <- function(x) {
365365
table$addColumnInfo(name = "mtry", title = gettext("Features per split"), type = "integer")
366366
} else if (inherits(model, "cv.glmnet")) {
367367
table$addColumnInfo(name = "lambda", title = "\u03BB", type = "number")
368+
} else if (inherits(model, "glm") || inherits(model, "vglm")) {
369+
table$addColumnInfo(name = "family", title = gettext("Family"), type = "string")
370+
table$addColumnInfo(name = "link", title = gettext("Link"), type = "string")
368371
}
369372
table$addColumnInfo(name = "ntrain", title = gettext("n(Train)"), type = "integer")
370373
table$addColumnInfo(name = "nnew", title = gettext("n(New)"), type = "integer")
@@ -383,6 +386,12 @@ is.jaspMachineLearning <- function(x) {
383386
row[["mtry"]] <- model[["mtry"]]
384387
} else if (inherits(model, "cv.glmnet")) {
385388
row[["lambda"]] <- model[["lambda.min"]]
389+
} else if (inherits(model, "glm")) {
390+
row[["family"]] <- gettext("Binomial")
391+
row[["link"]] <- gettext("Logit")
392+
} else if (inherits(model, "vglm")) {
393+
row[["family"]] <- gettext("Multinomial")
394+
row[["link"]] <- gettext("Logit")
386395
}
387396
if (length(presentVars) > 0) {
388397
row[["nnew"]] <- nrow(dataset)

tests/testthat/test-mlclassificationlogisticmultinomial.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ test_that("Class Proportions table results match", {
3939
test_that("Model Summary: Multinomial Regression Classification table results match", {
4040
table <- results[["results"]][["classificationTable"]][["data"]]
4141
jaspTools::expect_equal_tables(table,
42-
list("Multinomial", 30, 120, 1))
42+
list("Multinomial", "Logit", 30, 120, 1))
4343
})
4444

4545
test_that("Confusion Matrix table results match", {

0 commit comments

Comments
 (0)