Skip to content

Commit 1818d85

Browse files
committed
Add coefficients table for logistic
1 parent 7c86ce2 commit 1818d85

File tree

2 files changed

+68
-4
lines changed

2 files changed

+68
-4
lines changed

R/mlClassificationLogistic.R

Lines changed: 67 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,16 @@ mlClassificationLogistic <- function(jaspResults, dataset, options, ...) {
5151
# # Create the shap table
5252
# .mlTableShap(dataset, options, jaspResults, ready, position = 7, purpose = "classification")
5353

54+
.mlClassificationLogisticTableCoef(options, jaspResults, ready, position = 8)
55+
5456
# # Create the ROC curve
55-
# .mlClassificationPlotRoc(dataset, options, jaspResults, ready, position = 8, type = "logistic")
57+
# .mlClassificationPlotRoc(dataset, options, jaspResults, ready, position = 10, type = "logistic") # position + 1 for regression equation
5658

5759
# Create the Andrews curves
58-
.mlClassificationPlotAndrews(dataset, options, jaspResults, ready, position = 9)
60+
.mlClassificationPlotAndrews(dataset, options, jaspResults, ready, position = 11)
5961

6062
# # Decision boundaries
61-
# .mlClassificationPlotBoundaries(dataset, options, jaspResults, ready, position = 10, type = "logistic")
63+
# .mlClassificationPlotBoundaries(dataset, options, jaspResults, ready, position = 12, type = "logistic")
6264
}
6365

6466
.logisticRegressionClassification <- function(dataset, options, jaspResults, ready) {
@@ -86,7 +88,7 @@ mlClassificationLogistic <- function(jaspResults, dataset, options, ...) {
8688
}
8789
if (nlevels(trainingSet[[options[["target"]]]]) == 2) {
8890
family = "binomial"
89-
trainingFit <- stats::glm(formula, data = trainingSet, family = family)
91+
trainingFit <- glm(formula, data = trainingSet, family = family)
9092
# Use the specified model to make predictions for dataset
9193
testPredictions <- .mlClassificationLogisticPredictions(trainingSet, options, predict(trainingFit, newdata = testSet, type = "response"))
9294
dataPredictions <- .mlClassificationLogisticPredictions(trainingSet, options, predict(trainingFit, newdata = dataset, type = "response"))
@@ -122,6 +124,67 @@ mlClassificationLogistic <- function(jaspResults, dataset, options, ...) {
122124
return(result)
123125
}
124126

127+
.mlClassificationLogisticTableCoef <- function(options, jaspResults, ready, position) {
128+
if (!is.null(jaspResults[["coefTable"]]) || !options[["coefTable"]]) {
129+
return()
130+
}
131+
table <- createJaspTable(gettext("Regression Coefficients"))
132+
table$position <- position
133+
table$dependOn(options = c("coefTable", "coefTableConfInt", "coefTableConfIntLevel", "formula", .mlClassificationDependencies()))
134+
table$addColumnInfo(name = "var", title = "", type = "string")
135+
table$addColumnInfo(name = "coefs", title = gettextf("Coefficient (%s)", "\u03B2"), type = "number")
136+
table$addColumnInfo(name = "se", title = gettext("Standard Error"), type = "number")
137+
table$addColumnInfo(name = "t", title = gettext("t"), type = "number")
138+
table$addColumnInfo(name = "p", title = gettext("p"), type = "pvalue")
139+
if (options[["coefTableConfInt"]]) {
140+
overtitle <- gettextf("%1$s%% Confidence interval", round(options[["coefTableConfIntLevel"]] * 100, 3))
141+
table$addColumnInfo(name = "lower", title = gettext("Lower"), type = "number", overtitle = overtitle)
142+
table$addColumnInfo(name = "upper", title = gettext("Upper"), type = "number", overtitle = overtitle)
143+
}
144+
if (options[["scaleVariables"]]) {
145+
table$addFootnote(gettext("The regression coefficients for numeric features are standardized."))
146+
} else {
147+
table$addFootnote(gettext("The regression coefficients are unstandardized."))
148+
}
149+
jaspResults[["coefTable"]] <- table
150+
if (!ready) {
151+
if (options[["target"]] == "" && length(unlist(options[["predictors"]])) > 0) {
152+
table[["var"]] <- c(if (options[["intercept"]]) "(Intercept)" else NULL, options[["predictors"]])
153+
}
154+
return()
155+
}
156+
classificationResult <- jaspResults[["classificationResult"]]$object
157+
model <- classificationResult[["model"]]
158+
coefs <- summary(model)$coefficients
159+
conf_int <- confint(model, level = options[["coefTableConfIntLevel"]])
160+
coefs <- cbind(coefs, lower = conf_int[, 1], upper = conf_int[, 2])
161+
table[["var"]] <- rownames(coefs)
162+
table[["coefs"]] <- as.numeric(coefs[, 1])
163+
table[["se"]] <- as.numeric(coefs[, 2])
164+
table[["t"]] <- as.numeric(coefs[, 3])
165+
table[["p"]] <- as.numeric(coefs[, 4])
166+
if (options[["coefTableConfInt"]]) {
167+
table[["lower"]] <- coefs[, "lower"]
168+
table[["upper"]] <- coefs[, "upper"]
169+
}
170+
if (options[["formula"]]) {
171+
if (options[["intercept"]]) {
172+
regform <- paste0("logit(", options[["target"]], ") = ", round(as.numeric(coefs[, 1])[1], 3))
173+
start <- 2
174+
} else {
175+
regform <- paste0("logit(", options[["target"]], ") = ")
176+
start <- 1
177+
}
178+
for (i in start:nrow(coefs)) {
179+
regform <- paste0(regform, if (round(as.numeric(coefs[, 1])[i], 3) < 0) " - " else " + ", abs(round(as.numeric(coefs[, 1])[i], 3)), " x ", rownames(coefs)[i])
180+
}
181+
formula <- createJaspHtml(gettextf("<b>Regression equation:</b>\n%1$s", regform), "p")
182+
formula$position <- position + 1
183+
formula$dependOn(options = c("coefTable", "formula"), optionsFromObject = jaspResults[["classificationResult"]])
184+
jaspResults[["regressionFormula"]] <- formula
185+
}
186+
}
187+
125188
.mlClassificationLogisticPredictions <- function(trainingSet, options, probabilities) {
126189
categories <- levels(trainingSet[[options[["target"]]]])
127190
predicted_categories <- categories[round(probabilities, 0) + 1]

inst/qml/mlClassificationLogistic.qml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ Form
4141
TAB.ModelPerformance { }
4242
TAB.FeatureImportance { }
4343
TAB.ExplainPredictions { }
44+
REGU.CoefficientTable { confint: true }
4445
}
4546

4647
Group

0 commit comments

Comments
 (0)