Skip to content

Commit 4dc4a38

Browse files
committed
Update coef table
1 parent 775b948 commit 4dc4a38

File tree

3 files changed

+50
-10
lines changed

3 files changed

+50
-10
lines changed

R/mlClassificationLogisticMultinomial.R

Lines changed: 48 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ mlClassificationLogisticMultinomial <- function(jaspResults, dataset, options, .
139139
if (options[["coefTableConfInt"]]) {
140140
overtitle <- gettextf("%1$s%% Confidence interval", round(options[["coefTableConfIntLevel"]] * 100, 3))
141141
table$addColumnInfo(name = "lower", title = gettext("Lower"), type = "number", overtitle = overtitle)
142-
table$addColumnInfo(name = "upper", title = gettext("Upper"), type = "number", overtitle = overtitle)
142+
table$addColumnInfo(name = "upper", title = gettext("Upper"), type = "number", overtitle = overtitle)
143143
}
144144
if (options[["scaleVariables"]]) {
145145
table$addFootnote(gettext("The regression coefficients for numeric features are standardized."))
@@ -155,14 +155,54 @@ mlClassificationLogisticMultinomial <- function(jaspResults, dataset, options, .
155155
}
156156
classificationResult <- jaspResults[["classificationResult"]]$object
157157
model <- classificationResult[["model"]]
158-
coefs <- summary(model)$coefficients
159-
conf_int <- confint(model, level = options[["coefTableConfIntLevel"]])
160-
coefs <- cbind(coefs, lower = conf_int[, 1], upper = conf_int[, 2])
158+
if (classificationResult[["family"]] == "binomial") {
159+
coefs <- summary(model)$coefficients
160+
conf_int <- confint(model, level = options[["coefTableConfIntLevel"]])
161+
coefs <- cbind(coefs, lower = conf_int[, 1], upper = conf_int[, 2])
162+
colnames(coefs) <- c("est", "se", "t", "p", "lower", "upper")
163+
vars <- rownames(coefs)
164+
for (i in seq_along(vars)) {
165+
if (!(vars[i] %in% options[["predictors"]]) && vars[i] != "(Intercept)") {
166+
for (j in options[["predictors"]]) {
167+
vars[i] <- gsub(pattern = j, replacement = paste0(j, " ("), x = vars[i])
168+
}
169+
vars[i] <- paste0(vars[i], ")")
170+
}
171+
}
172+
rownames(coefs) <- vars
173+
} else {
174+
coefs <- cbind(model@coefficients, confint(model, level = options[["coefTableConfIntLevel"]]))
175+
colnames(coefs) <- c("est", "lower", "upper")
176+
vars <- rownames(coefs)
177+
for (i in seq_along(vars)) {
178+
for (j in c("(Intercept)", options[["predictors"]])) {
179+
if (!grepl(j, vars[i])) {
180+
next
181+
}
182+
splitvar <- strsplit(vars[i], split = ":")[[1]]
183+
if (grepl(paste0(j, "[A-Za-z]+:"), vars[i])) {
184+
repl_part1 <- paste0(gsub(pattern = j, replacement = paste0(j, " ("), x = splitvar[1]), ")")
185+
} else {
186+
repl_part1 <- j
187+
}
188+
repl_part2 <- levels(factor(classificationResult[["train"]][[options[["target"]]]]))[as.numeric(splitvar[2])]
189+
vars[i] <- paste0(repl_part1, " : ", repl_part2)
190+
}
191+
}
192+
rownames(coefs) <- vars
193+
}
161194
table[["var"]] <- rownames(coefs)
162-
table[["coefs"]] <- as.numeric(coefs[, 1])
163-
table[["se"]] <- as.numeric(coefs[, 2])
164-
table[["t"]] <- as.numeric(coefs[, 3])
165-
table[["p"]] <- as.numeric(coefs[, 4])
195+
table[["coefs"]] <- as.numeric(coefs[, "est"])
196+
if (classificationResult[["family"]] == "binomial") {
197+
table[["se"]] <- as.numeric(coefs[, "se"])
198+
table[["t"]] <- as.numeric(coefs[, "t"])
199+
table[["p"]] <- as.numeric(coefs[, "p"])
200+
} else {
201+
table[["se"]] <- rep(".", nrow(coefs))
202+
table[["t"]] <- rep(".", nrow(coefs))
203+
table[["p"]] <- rep(".", nrow(coefs))
204+
table$addFootnote(gettext("Standard errors, t-values and p-values are not available for multinomial regression coefficients."))
205+
}
166206
if (options[["coefTableConfInt"]]) {
167207
table[["lower"]] <- coefs[, "lower"]
168208
table[["upper"]] <- coefs[, "upper"]

tests/testthat/helper-ml.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ mlOptions <- function(analysis) {
2727
files <- c(files, list.files(testthat::test_path("..", "..", "inst", "qml", "common", "analyses", "randomforest"), full.names = TRUE))
2828
} else if (analysis %in% c("mlClassificationSvm", "mlRegressionSvm")) {
2929
files <- c(files, list.files(testthat::test_path("..", "..", "inst", "qml", "common", "analyses", "svm"), full.names = TRUE))
30-
} else if (analysis %in% c("mlClassificationLogistic", "mlRegressionLinear", "mlRegressionRegularized")) {
30+
} else if (analysis %in% c("mlClassificationLogisticMultinomial", "mlRegressionLinear", "mlRegressionRegularized")) {
3131
files <- c(files, list.files(testthat::test_path("..", "..", "inst", "qml", "common", "analyses", "regularized"), full.names = TRUE))
3232
}
3333
options <- lapply(files, jaspTools:::readQML) |>

tests/testthat/test-mlclassificationlogisticmultinomial.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
context("Machine Learning Logistic Regression Classification")
1+
context("Machine Learning Logistic / Multinomial Regression Classification")
22

33
# Test fixed model #############################################################
44
options <- initMlOptions("mlClassificationLogisticMultinomial")

0 commit comments

Comments
 (0)