Skip to content

Commit 2928da0

Browse files
committed
Support different link functions
1 parent ddd7783 commit 2928da0

File tree

4 files changed

+29
-14
lines changed

4 files changed

+29
-14
lines changed

R/commonMachineLearningClassification.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
"noOfTrees", "maxTrees", "baggingFraction", "noOfPredictors", "numberOfPredictors", # Random forest
3434
"complexityParameter", "degree", "gamma", "cost", "tolerance", "epsilon", "maxCost", # Support vector machine
3535
"smoothingParameter", # Naive Bayes
36-
"intercept" # Logistic
36+
"intercept", "link" # Logistic
3737
)
3838
if (includeSaveOptions) {
3939
opt <- c(opt, "saveModel", "savePath")
@@ -588,7 +588,7 @@
588588
levels(predictions) <- unique(dataset[, options[["target"]]])
589589
} else if (type == "logistic") {
590590
if (classificationResult[["family"]] == "binomial") {
591-
fit <- glm(formula, data = dataset, family = stats::binomial(link = "logit"))
591+
fit <- stats::glm(formula, data = dataset, family = stats::binomial(link = options[["link"]]))
592592
predictions <- as.factor(round(predict(fit, grid, type = "response"), 0))
593593
levels(predictions) <- unique(dataset[, options[["target"]]])
594594
} else {
@@ -746,7 +746,7 @@
746746
fit <- e1071::naiveBayes(formula = formula, data = typeData, laplace = options[["smoothingParameter"]])
747747
score <- max.col(predict(fit, test, type = "raw"))
748748
} else if (type == "logistic") {
749-
fit <- glm(formula, data = typeData, family = stats::binomial(link = "logit"))
749+
fit <- stats::glm(formula, data = typeData, family = stats::binomial(link = options[["link"]]))
750750
score <- round(predict(fit, test, type = "response"), 0)
751751
}
752752
pred <- ROCR::prediction(score, actual.class)
@@ -1167,7 +1167,7 @@
11671167
}
11681168

11691169
.calcAUCScore.logisticClassification <- function(AUCformula, test, typeData, options, jaspResults, ...) {
1170-
fit <- glm(AUCformula, data = typeData, family = stats::binomial(link = "logit"))
1170+
fit <- stats::glm(AUCformula, data = typeData, family = stats::binomial(link = options[["link"]]))
11711171
score <- round(predict(fit, test, type = "response"), 0)
11721172
return(score)
11731173
}

R/mlClassificationLogisticMultinomial.R

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -88,18 +88,18 @@ mlClassificationLogisticMultinomial <- function(jaspResults, dataset, options, .
8888
}
8989
if (nlevels(trainingSet[[options[["target"]]]]) == 2) {
9090
family = "binomial"
91-
linkFunction <- "logit"
92-
trainingFit <- glm(formula, data = trainingSet, family = stats::binomial(link = linkFunction))
91+
linkFunction <- options[["link"]]
92+
trainingFit <- stats::glm(formula, data = trainingSet, family = stats::binomial(link = linkFunction))
9393
# Use the specified model to make predictions for dataset
94-
testPredictions <- .mlClassificationLogisticPredictions(trainingSet, options, predict(trainingFit, newdata = testSet, type = "response"))
95-
dataPredictions <- .mlClassificationLogisticPredictions(trainingSet, options, predict(trainingFit, newdata = dataset, type = "response"))
94+
testPredictions <- .mlClassificationLogisticPredictions(trainingSet, options, stats::predict(trainingFit, newdata = testSet, type = "response"))
95+
dataPredictions <- .mlClassificationLogisticPredictions(trainingSet, options, stats::predict(trainingFit, newdata = dataset, type = "response"))
9696
} else {
9797
family <- "multinomial"
9898
linkFunction <- "logit"
9999
trainingFit <- VGAM::vglm(formula, data = trainingSet, family = VGAM::multinomial())
100100
# Use the specified model to make predictions for dataset
101-
testPredictions <- .mlClassificationMultinomialPredictions(trainingSet, options, predict(trainingFit, newdata = testSet))
102-
dataPredictions <- .mlClassificationMultinomialPredictions(trainingSet, options, predict(trainingFit, newdata = dataset))
101+
testPredictions <- .mlClassificationMultinomialPredictions(trainingSet, options, VGAM::predict(trainingFit, newdata = testSet))
102+
dataPredictions <- .mlClassificationMultinomialPredictions(trainingSet, options, VGAM::predict(trainingFit, newdata = dataset))
103103
}
104104
# Create results object
105105
result <- list()
@@ -108,6 +108,7 @@ mlClassificationLogisticMultinomial <- function(jaspResults, dataset, options, .
108108
result[["link"]] <- linkFunction
109109
if (family == "binomial") {
110110
result[["model"]] <- trainingFit
111+
result[["model"]]$link <- result[["link"]]
111112
} else {
112113
model <- lapply(slotNames(trainingFit), function(x) slot(trainingFit, x))
113114
names(model) <- slotNames(trainingFit)
@@ -211,14 +212,14 @@ mlClassificationLogisticMultinomial <- function(jaspResults, dataset, options, .
211212
table[["lower"]] <- coefs[, "lower"]
212213
table[["upper"]] <- coefs[, "upper"]
213214
}
214-
if (options[["formula"]]) { # TODO FOR MULTINOMIAL
215+
if (options[["formula"]]) {
215216
if (classificationResult[["family"]] == "binomial") {
216217
one_cat <- levels(factor(classificationResult[["train"]][[options[["target"]]]]))[2]
217218
if (options[["intercept"]]) {
218-
regform <- paste0("logit(p<sub>", options[["target"]], " = ", one_cat, "</sub>) = ", round(as.numeric(coefs[, 1])[1], 3))
219+
regform <- paste0(options[["link"]], "(p<sub>", options[["target"]], " = ", one_cat, "</sub>) = ", round(as.numeric(coefs[, 1])[1], 3))
219220
start <- 2
220221
} else {
221-
regform <- paste0("logit(p<sub>", options[["target"]], " = ", one_cat, "</sub>) = ")
222+
regform <- paste0(options[["link"]], "(p<sub>", options[["target"]], " = ", one_cat, "</sub>) = ")
222223
start <- 1
223224
}
224225
for (i in start:nrow(coefs)) {

R/mlPrediction.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,7 @@ is.jaspMachineLearning <- function(x) {
388388
row[["lambda"]] <- model[["lambda.min"]]
389389
} else if (inherits(model, "glm")) {
390390
row[["family"]] <- gettext("Binomial")
391-
row[["link"]] <- gettext("Logit")
391+
row[["link"]] <- paste0(toupper(substr(model[["link"]], 1, 1)), substr(model[["link"]], 2, nchar(model[["link"]])))
392392
} else if (inherits(model, "vglm")) {
393393
row[["family"]] <- gettext("Multinomial")
394394
row[["link"]] <- gettext("Logit")

inst/qml/mlClassificationLogisticMultinomial.qml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,20 @@ Form
6565
{
6666
title: qsTr("Algorithmic Settings")
6767

68+
DropDown
69+
{
70+
name: "link"
71+
indexDefaultValue: 0
72+
label: qsTr("Link function (for binary classification)")
73+
values:
74+
[
75+
{ label: qsTr("Logit"), value: "logit"},
76+
{ label: qsTr("Probit"), value: "probit"},
77+
{ label: qsTr("Cauchit"), value: "cauchit"},
78+
{ label: qsTr("C log-log"), value: "cloglog"},
79+
{ label: qsTr("Log"), value: "log"}
80+
]
81+
}
6882
REGU.Intercept { }
6983
UI.ScaleVariables { }
7084
UI.SetSeed { }

0 commit comments

Comments
 (0)