diff --git a/DESCRIPTION b/DESCRIPTION index 0512eca1..b56e2b4b 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -41,7 +41,8 @@ Imports: rpart (>= 4.1.16), ROCR, Rtsne, - signal + signal, + VGAM Suggests: testthat Remotes: diff --git a/NAMESPACE b/NAMESPACE index 20ca9f20..7efc6821 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -51,6 +51,7 @@ export(mlClassificationBoosting) export(mlClassificationDecisionTree) export(mlClassificationKnn) export(mlClassificationLda) +export(mlClassificationLogisticMultinomial) export(mlClassificationNaiveBayes) export(mlClassificationNeuralNetwork) export(mlClassificationRandomForest) diff --git a/R/commonMachineLearningClassification.R b/R/commonMachineLearningClassification.R index 0408e4f9..6345d19d 100644 --- a/R/commonMachineLearningClassification.R +++ b/R/commonMachineLearningClassification.R @@ -32,7 +32,8 @@ "mutationMethod", "survivalMethod", "elitismProportion", "candidates", # Neural network "noOfTrees", "maxTrees", "baggingFraction", "noOfPredictors", "numberOfPredictors", # Random forest "complexityParameter", "degree", "gamma", "cost", "tolerance", "epsilon", "maxCost", # Support vector machine - "smoothingParameter" # Naive Bayes + "smoothingParameter", # Naive Bayes + "intercept", "link" # Logistic ) if (includeSaveOptions) { opt <- c(opt, "saveModel", "savePath") } @@ -62,7 +63,7 @@ if (type == "lda" || type == "randomForest" || type == "boosting") { # Require at least 2 features ready <- length(options[["predictors"]][options[["predictors"]] != ""]) >= 2 && options[["target"]] != "" - } else if (type == "knn" || type == "neuralnet" || type == "rpart" || type == "svm" || type == "naivebayes") { + } else if (type == "knn" || type == "neuralnet" || type == "rpart" || type == "svm" || type == "naivebayes" || type == "logistic") { # Require at least 1 feature ready <- length(options[["predictors"]][options[["predictors"]] != ""]) >= 1 && options[["target"]] != "" } @@ -93,7 +94,8 @@ "neuralnet" = .neuralnetClassification(dataset, options, jaspResults), "rpart" = .decisionTreeClassification(dataset, options, jaspResults), "svm" = .svmClassification(dataset, options, jaspResults), - "naivebayes" = .naiveBayesClassification(dataset, options, jaspResults) + "naivebayes" = .naiveBayesClassification(dataset, options, jaspResults), + "logistic" = .logisticMultinomialClassification(dataset, options, jaspResults) ) }) if (isTryError(p)) { # Fail gracefully @@ -116,7 +118,8 @@ "neuralnet" = gettext("Neural Network Classification"), "rpart" = gettext("Decision Tree Classification"), "svm" = gettext("Support Vector Machine Classification"), - "naivebayes" = gettext("Naive Bayes Classification") + "naivebayes" = gettext("Naive Bayes Classification"), + "logistic" = gettext("Logistic / Multinomial Regression Classification") ) tableTitle <- gettextf("Model Summary: %1$s", title) table <- createJaspTable(tableTitle) @@ -147,6 +150,9 @@ table$addColumnInfo(name = "vectors", title = gettext("Support Vectors"), type = "integer") } else if (type == "naivebayes") { table$addColumnInfo(name = "smoothing", title = gettext("Smoothing"), type = "number") + } else if (type == "logistic") { + table$addColumnInfo(name = "family", title = gettext("Family"), type = "string") + table$addColumnInfo(name = "link", title = gettext("Link"), type = "string") } # Add common columns table$addColumnInfo(name = "nTrain", title = gettext("n(Train)"), type = "integer")
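Reviewer note: the paste0(toupper(substr(x, 1, 1)), substr(x, 2, nchar(x))) idiom used to capitalize the family and link labels recurs in the model summary rows below and again in R/mlPrediction.R. A small shared helper would keep the three call sites in sync; a minimal sketch, where the name .mlCapitalizeFirst is hypothetical and not part of this patch:

.mlCapitalizeFirst <- function(x) {
  # Capitalize only the first character, e.g. "binomial" -> "Binomial"
  paste0(toupper(substr(x, 1, 1)), substr(x, 2, nchar(x)))
}
.mlCapitalizeFirst("cloglog") # "Cloglog"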
@@ -164,7 +170,7 @@ } # If no analysis is run, specify the required variables in a footnote if (!ready) { - table$addFootnote(gettextf("Please provide a target variable and at least %i feature variable(s).", if (type == "knn" || type == "neuralnet" || type == "rpart" || type == "svm") 1L else 2L)) + table$addFootnote(gettextf("Please provide a target variable and at least %i feature variable(s).", if (type == "knn" || type == "neuralnet" || type == "rpart" || type == "svm" || type == "naivebayes" || type == "logistic") 1L else 2L)) } if (options[["savePath"]] != "") { validNames <- (length(grep(" ", decodeColNames(colnames(dataset)))) == 0) && (length(grep("_", decodeColNames(colnames(dataset)))) == 0) @@ -312,6 +318,22 @@ testAcc = classificationResult[["testAcc"]] ) table$addRows(row) + } else if (type == "logistic") { + if (classificationResult[["family"]] == "binomial") { + table$title <- gettext("Model Summary: Logistic Regression Classification") + } else { + table$title <- gettext("Model Summary: Multinomial Regression Classification") + } + family <- classificationResult[["family"]] + link <- classificationResult[["link"]] + row <- data.frame( + family = paste0(toupper(substr(family, 1, 1)), substr(family, 2, nchar(family))), + link = paste0(toupper(substr(link, 1, 1)), substr(link, 2, nchar(link))), + nTrain = nTrain, + nTest = classificationResult[["ntest"]], + testAcc = classificationResult[["testAcc"]] + ) + table$addRows(row) } # Save the applied model if requested if (options[["saveModel"]] && options[["savePath"]] != "") { @@ -564,6 +586,26 @@ fit <- e1071::naiveBayes(formula, data = dataset, laplace = options[["smoothingParameter"]]) predictions <- as.factor(max.col(predict(fit, newdata = grid, type = "raw"))) levels(predictions) <- unique(dataset[, options[["target"]]]) + } else if (type == "logistic") { + if (classificationResult[["family"]] == "binomial") { + fit <- stats::glm(formula, data = dataset, family = stats::binomial(link = options[["link"]])) + predictions <- as.factor(round(predict(fit, grid, type = "response"), 0)) + levels(predictions) <- unique(dataset[, options[["target"]]]) + } else { + fit <- VGAM::vglm(formula, data = dataset, family = VGAM::multinomial()) + logodds <- predict(fit, newdata = grid) + ncategories <- ncol(logodds) + 1 + probabilities <- matrix(0, nrow = nrow(logodds), ncol = ncategories) + for (i in seq_len(ncategories - 1)) { + probabilities[, i] <- exp(logodds[, i]) + } + probabilities[, ncategories] <- 1 + row_sums <- rowSums(probabilities) + probabilities <- probabilities / row_sums + predicted_columns <- apply(probabilities, 1, which.max) + categories <- levels(dataset[[options[["target"]]]]) + predictions <- as.factor(categories[predicted_columns]) + } } shapes <- rep(21, nrow(dataset)) if (type == "svm") { @@ -703,6 +745,9 @@ } else if (type == "naivebayes") { fit <- e1071::naiveBayes(formula = formula, data = typeData, laplace = options[["smoothingParameter"]]) score <- max.col(predict(fit, test, type = "raw")) + } else if (type == "logistic") { + fit <- stats::glm(formula, data = typeData, family = stats::binomial(link = options[["link"]])) + score <- round(predict(fit, test, type = "response"), 0) + } pred <- ROCR::prediction(score, actual.class) nbperf <- ROCR::performance(pred, "tpr", "fpr") @@ -1120,3 +1165,9 @@ score <- max.col(predict(fit, test, type = "raw")) return(score) } + +.calcAUCScore.logisticClassification <- function(AUCformula, test, typeData, options, jaspResults, ...) { + fit <- stats::glm(AUCformula, data = typeData, family = stats::binomial(link = options[["link"]])) + score <- round(predict(fit, test, type = "response"), 0) + return(score) +}
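Reviewer note: .calcAUCScore.logisticClassification() above (like the type == "logistic" branch that feeds ROCR::prediction()) rounds the response probabilities to hard 0/1 scores, so the one-vs-rest ROC curve collapses to a single operating point. That mirrors the hard-score convention this file already uses for Naive Bayes, so it may be deliberate; if a full curve is wanted, the unrounded probabilities would provide it. A minimal sketch of that variant, offered as a suggestion rather than part of the patch:

.calcAUCScore.logisticClassification <- function(AUCformula, test, typeData, options, jaspResults, ...) {
  fit <- stats::glm(AUCformula, data = typeData, family = stats::binomial(link = options[["link"]]))
  # Continuous scores in [0, 1] let ROCR::performance(pred, "tpr", "fpr") sweep all thresholds
  score <- predict(fit, test, type = "response")
  return(score)
}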
diff --git a/R/mlClassificationLogisticMultinomial.R b/R/mlClassificationLogisticMultinomial.R new file mode 100644 index 00000000..465e7ed5 --- /dev/null +++ b/R/mlClassificationLogisticMultinomial.R @@ -0,0 +1,285 @@ +# +# Copyright (C) 2013-2021 University of Amsterdam +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# + +mlClassificationLogisticMultinomial <- function(jaspResults, dataset, options, ...) { + + # Preparatory work + dataset <- .mlClassificationReadData(dataset, options) + .mlClassificationErrorHandling(dataset, options, type = "logistic") + + # Check if analysis is ready to run + ready <- .mlClassificationReady(options, type = "logistic") + + # Compute results and create the model summary table + .mlClassificationTableSummary(dataset, options, jaspResults, ready, position = 1, type = "logistic") + + # If the user wants to add the classes to the data set + .mlClassificationAddPredictionsToData(dataset, options, jaspResults, ready) + + # Add test set indicator to data + .mlAddTestIndicatorToData(options, jaspResults, ready, purpose = "classification") + + # Create the data split plot + .mlPlotDataSplit(dataset, options, jaspResults, ready, position = 2, purpose = "classification", type = "logistic") + + # Create the confusion table + .mlClassificationTableConfusion(dataset, options, jaspResults, ready, position = 3) + + # Create the class proportions table + .mlClassificationTableProportions(dataset, options, jaspResults, ready, position = 4) + + # Create the validation measures table + .mlClassificationTableMetrics(dataset, options, jaspResults, ready, position = 5) + + # Create the variable importance table + .mlTableFeatureImportance(options, jaspResults, ready, position = 6, purpose = "classification") + + # Create the shap table + .mlTableShap(dataset, options, jaspResults, ready, position = 7, purpose = "classification") + + # Create the regression coefficients table + .mlClassificationLogisticTableCoef(options, jaspResults, ready, position = 8) + + # Create the ROC curve + .mlClassificationPlotRoc(dataset, options, jaspResults, ready, position = 10, type = "logistic") # position + 1 for regression equation + + # Create the Andrews curves + .mlClassificationPlotAndrews(dataset, options, jaspResults, ready, position = 11) + + # Decision boundaries + .mlClassificationPlotBoundaries(dataset, options, jaspResults, ready, position = 12, type = "logistic") +} + +.logisticMultinomialClassification <- function(dataset, options, jaspResults, ready) { + # Import model formula from jaspResults + formula <- jaspResults[["formula"]]$object + # Split the data into training and test sets + if (options[["holdoutData"]] == "testSetIndicator" && options[["testSetIndicatorVariable"]] != "") { + # Select observations according to a user-specified indicator (included when indicator = 1)
+ trainingIndex <- which(dataset[, options[["testSetIndicatorVariable"]]] == 0) + } else { + # Sample a percentage of the total data set + trainingIndex <- sample.int(nrow(dataset), size = ceiling((1 - options[["testDataManual"]]) * nrow(dataset))) + } + trainingSet <- dataset[trainingIndex, ] + # Create the generated test set indicator + testIndicatorColumn <- rep(1, nrow(dataset)) + testIndicatorColumn[trainingIndex] <- 0 + # Just create a train and a test set (no optimization) + testSet <- dataset[-trainingIndex, ] + # Create the formula + if (options[["intercept"]]) { + formula <- formula(paste(options[["target"]], "~ 1 + ", paste(options[["predictors"]], collapse = " + "))) + } else { + formula <- formula(paste(options[["target"]], "~ 0 + ", paste(options[["predictors"]], collapse = " + "))) + } + if (nlevels(trainingSet[[options[["target"]]]]) == 2) { + family <- "binomial" + linkFunction <- options[["link"]] + trainingFit <- stats::glm(formula, data = trainingSet, family = stats::binomial(link = linkFunction)) + # Use the specified model to make predictions for dataset + testPredictions <- .mlClassificationLogisticPredictions(trainingSet, options, stats::predict(trainingFit, newdata = testSet, type = "response")) + dataPredictions <- .mlClassificationLogisticPredictions(trainingSet, options, stats::predict(trainingFit, newdata = dataset, type = "response")) + } else { + family <- "multinomial" + linkFunction <- "logit" + trainingFit <- VGAM::vglm(formula, data = trainingSet, family = VGAM::multinomial()) + # Use the specified model to make predictions for dataset + testPredictions <- .mlClassificationMultinomialPredictions(trainingSet, options, VGAM::predict(trainingFit, newdata = testSet)) + dataPredictions <- .mlClassificationMultinomialPredictions(trainingSet, options, VGAM::predict(trainingFit, newdata = dataset)) + } + # Create results object + result <- list() + result[["formula"]] <- formula + result[["family"]] <- family + result[["link"]] <- linkFunction + if (family == "binomial") { + result[["model"]] <- trainingFit + result[["model"]]$link <- result[["link"]] + } else { + model <- lapply(slotNames(trainingFit), function(x) slot(trainingFit, x)) + names(model) <- slotNames(trainingFit) + model[["original"]] <- trainingFit + model[["target"]] <- trainingSet[[options[["target"]]]] + class(model) <- "vglm" + result[["model"]] <- model + } + result[["confTable"]] <- table("Pred" = testPredictions, "Real" = testSet[, options[["target"]]]) + result[["testAcc"]] <- sum(diag(prop.table(result[["confTable"]]))) + result[["auc"]] <- .classificationCalcAUC(testSet, trainingSet, options, "logisticClassification") + result[["ntrain"]] <- nrow(trainingSet) + result[["ntest"]] <- nrow(testSet) + result[["testReal"]] <- testSet[, options[["target"]]] + result[["testPred"]] <- testPredictions + result[["train"]] <- trainingSet + result[["test"]] <- testSet + result[["testIndicatorColumn"]] <- testIndicatorColumn + result[["classes"]] <- dataPredictions + if (family == "binomial") { + result[["explainer"]] <- DALEX::explain(result[["model"]], type = "classification", data = result[["train"]], y = result[["train"]][, options[["target"]]], predict_function = function(model, data) data.frame(1 - predict(model, newdata = data, type = "response"), predict(model, newdata = data, type = "response"))) + result[["explainer_fi"]] <- DALEX::explain(result[["model"]], type = "classification", data = result[["train"]], y = as.numeric(result[["train"]][, options[["target"]]]) - 1, predict_function = function(model, data) round(predict(model, newdata = data, type = "response"), 0) + 1) + } else { + result[["explainer"]] <- DALEX::explain(result[["model"]][["original"]], type = "multiclass", data = result[["train"]], y = result[["train"]][, options[["target"]]], predict_function = function(model, data) VGAM::predict(model, data, type = "response")) + result[["explainer_fi"]] <- result[["explainer"]] + } + return(result) +}
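Reviewer note: the multinomial branch above (and the matching conversions in the decision-boundary code and in R/mlPrediction.R) turns VGAM's log-odds into class probabilities by hand. VGAM::multinomial() uses the last factor level as the reference category, so predict() on the fitted vglm returns log(p_j / p_K) for the first K - 1 categories. A minimal standalone sketch of that conversion, using iris purely as illustration:

library(VGAM)
fit <- vglm(Species ~ Sepal.Length + Sepal.Width, family = multinomial(), data = iris)
logodds <- predict(fit, newdata = iris)      # n x (K - 1) matrix of log(p_j / p_K)
expodds <- cbind(exp(logodds), 1)            # the reference category contributes exp(0) = 1
probabilities <- expodds / rowSums(expodds)  # same result as the loop used in this patch
# Sanity check against VGAM's built-in probability predictions:
all.equal(unname(probabilities), unname(predict(fit, newdata = iris, type = "response")))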
+ +.mlClassificationLogisticTableCoef <- function(options, jaspResults, ready, position) { + if (!is.null(jaspResults[["coefTable"]]) || !options[["coefTable"]]) { + return() + } + table <- createJaspTable(gettext("Regression Coefficients")) + table$position <- position + table$dependOn(options = c("coefTable", "coefTableConfInt", "coefTableConfIntLevel", "formula", .mlClassificationDependencies())) + table$addColumnInfo(name = "var", title = "", type = "string") + table$addColumnInfo(name = "coefs", title = gettextf("Coefficient (%s)", "\u03B2"), type = "number") + table$addColumnInfo(name = "se", title = gettext("Standard Error"), type = "number") + table$addColumnInfo(name = "z", title = gettext("z"), type = "number") + table$addColumnInfo(name = "p", title = gettext("p"), type = "pvalue") + if (options[["coefTableConfInt"]]) { + overtitle <- gettextf("%1$s%% Confidence interval", round(options[["coefTableConfIntLevel"]] * 100, 3)) + table$addColumnInfo(name = "lower", title = gettext("Lower"), type = "number", overtitle = overtitle) + table$addColumnInfo(name = "upper", title = gettext("Upper"), type = "number", overtitle = overtitle) + } + if (options[["scaleVariables"]]) { + table$addFootnote(gettext("The regression coefficients for numeric features are standardized.")) + } else { + table$addFootnote(gettext("The regression coefficients are unstandardized.")) + } + jaspResults[["coefTable"]] <- table + if (!ready) { + if (options[["target"]] == "" && length(unlist(options[["predictors"]])) > 0) { + table[["var"]] <- c(if (options[["intercept"]]) "(Intercept)" else NULL, options[["predictors"]]) + } + return() + } + classificationResult <- jaspResults[["classificationResult"]]$object + model <- classificationResult[["model"]] + if (classificationResult[["family"]] == "binomial") { + estimates <- coef(summary(model)) + conf_int <- confint(model, level = options[["coefTableConfIntLevel"]]) + if (!options[["intercept"]] && length(options[["predictors"]]) == 1) { + coefs <- cbind(estimates, conf_int[1], conf_int[2]) + } else { + coefs <- cbind(estimates, conf_int) + } + colnames(coefs) <- c("est", "se", "z", "p", "lower", "upper") + vars <- rownames(coefs) + for (i in seq_along(vars)) { + if (!(vars[i] %in% options[["predictors"]]) && vars[i] != "(Intercept)") { + for (j in options[["predictors"]]) { + vars[i] <- gsub(pattern = j, replacement = paste0(j, " ("), x = vars[i]) + } + vars[i] <- paste0(vars[i], ")") + } + } + rownames(coefs) <- vars + } else { + estimates <- VGAM::coef(VGAM::summaryvglm(model[["original"]])) + conf_int <- confint(model[["original"]], level = options[["coefTableConfIntLevel"]]) + coefs <- cbind(estimates, conf_int) + colnames(coefs) <- c("est", "se", "z", "p", "lower", "upper") + vars <- rownames(coefs) + for (i in seq_along(vars)) { + for (j in c("(Intercept)", options[["predictors"]])) { + if (!grepl(j, vars[i])) { + next + } + splitvar <- strsplit(vars[i], split = ":")[[1]] + if (grepl(paste0(j, "[A-Za-z]+:"), vars[i])) { + repl_part1 <- paste0(gsub(pattern = j, replacement = paste0(j, " ("), x = splitvar[1]), ")") + } else {
+ repl_part1 <- j + } + repl_part2 <- levels(factor(classificationResult[["train"]][[options[["target"]]]]))[as.numeric(splitvar[2])] + vars[i] <- paste0(repl_part1, " : ", repl_part2) + } + } + rownames(coefs) <- vars + } + table[["var"]] <- rownames(coefs) + table[["coefs"]] <- as.numeric(coefs[, "est"]) + table[["se"]] <- as.numeric(coefs[, "se"]) + table[["z"]] <- as.numeric(coefs[, "z"]) + table[["p"]] <- as.numeric(coefs[, "p"]) + if (options[["coefTableConfInt"]]) { + table[["lower"]] <- coefs[, "lower"] + table[["upper"]] <- coefs[, "upper"] + } + if (options[["formula"]]) { + if (classificationResult[["family"]] == "binomial") { + one_cat <- levels(factor(classificationResult[["train"]][[options[["target"]]]]))[2] + if (options[["intercept"]]) { + regform <- paste0(options[["link"]], "(p", options[["target"]], " = ", one_cat, ") = ", round(as.numeric(coefs[, 1])[1], 3)) + start <- 2 + } else { + regform <- paste0(options[["link"]], "(p", options[["target"]], " = ", one_cat, ") = ") + start <- 1 + } + for (i in start:nrow(coefs)) { + regform <- paste0(regform, if (round(as.numeric(coefs[, 1])[i], 3) < 0) " - " else (if (!options[["intercept"]] && i == 1) "" else " + "), abs(round(as.numeric(coefs[, 1])[i], 3)), " x ", rownames(coefs)[i]) + } + } else { + regform <- NULL + nlevs <- nlevels(classificationResult[["train"]][[options[["target"]]]]) + baseline_cat <- levels(classificationResult[["train"]][[options[["target"]]]])[nlevs] + for (i in seq_len(nlevs - 1)) { + current_cat <- levels(classificationResult[["train"]][[options[["target"]]]])[i] + if (options[["intercept"]]) { + part <- paste0("log(p", options[["target"]], " = ", current_cat, " / p", options[["target"]], " = ", baseline_cat, ") = ", round(as.numeric(coefs[, 1])[i], 3)) + start <- nlevs - 1 + i + } else { + part <- paste0("log(p", options[["target"]], " = ", current_cat, " / p", options[["target"]], " = ", baseline_cat, ") = ") + start <- i + } + for (j in seq(start, nrow(coefs), by = nlevs - 1)) { + part <- paste0(part, if (round(as.numeric(coefs[, 1])[j], 3) < 0) " - " else (if (!options[["intercept"]] && j == i) "" else " + "), abs(round(as.numeric(coefs[, 1])[j], 3)), " x ", strsplit(rownames(coefs)[j], " : ")[[1]][1]) + } + if (i < nlevs - 1) { + regform <- paste0(regform, part, "\n\n") + } else { + regform <- paste0(regform, part) + } + } + } + formula <- createJaspHtml(gettextf("Regression equation:\n%1$s", regform), "p") + formula$position <- position + 1 + formula$dependOn(options = c("coefTable", "formula"), optionsFromObject = jaspResults[["classificationResult"]]) + jaspResults[["regressionFormula"]] <- formula + } +} + +.mlClassificationLogisticPredictions <- function(trainingSet, options, probabilities) { + categories <- levels(trainingSet[[options[["target"]]]]) + predicted_categories <- factor(categories[round(probabilities, 0) + 1], levels = levels(trainingSet[[options[["target"]]]])) + return(predicted_categories) +} + +.mlClassificationMultinomialPredictions <- function(trainingSet, options, logodds) { + ncategories <- ncol(logodds) + 1 + probabilities <- matrix(0, nrow = nrow(logodds), ncol = ncategories) + for (i in seq_len(ncategories - 1)) { + probabilities[, i] <- exp(logodds[, i]) + } + probabilities[, ncategories] <- 1 + row_sums <- rowSums(probabilities) + probabilities <- probabilities / row_sums + predicted_columns <- apply(probabilities, 1, which.max) + categories <- levels(trainingSet[[options[["target"]]]]) + predicted_categories <- factor(categories[predicted_columns], levels =
levels(trainingSet[[options[["target"]]]])) + return(predicted_categories) +} diff --git a/R/mlPrediction.R b/R/mlPrediction.R index ce7f45f1..00043aba 100644 --- a/R/mlPrediction.R +++ b/R/mlPrediction.R @@ -75,6 +75,12 @@ is.jaspMachineLearning <- function(x) { .mlPredictionGetModelType.naiveBayes <- function(model) { gettext("Naive Bayes") } +.mlPredictionGetModelType.glm <- function(model) { + gettext("Logistic regression") +} +.mlPredictionGetModelType.vglm <- function(model) { + gettext("Multinomial regression") +} # S3 method to make predictions using the model .mlPredictionGetPredictions <- function(model, dataset) { @@ -135,6 +141,23 @@ is.jaspMachineLearning <- function(x) { .mlPredictionGetPredictions.naiveBayes <- function(model, dataset) { as.character(e1071:::predict.naiveBayes(model, newdata = dataset, type = "class")) } +.mlPredictionGetPredictions.glm <- function(model, dataset) { + as.character(levels(as.factor(model$model[, 1]))[round(predict(model, newdata = dataset, type = "response"), 0) + 1]) +} +.mlPredictionGetPredictions.vglm <- function(model, dataset) { + model[["original"]]@terms$terms <- model[["terms"]] + logodds <- predict(model[["original"]], newdata = dataset) + ncategories <- ncol(logodds) + 1 + probabilities <- matrix(0, nrow = nrow(logodds), ncol = ncategories) + for (i in seq_len(ncategories - 1)) { + probabilities[, i] <- exp(logodds[, i]) + } + probabilities[, ncategories] <- 1 + row_sums <- rowSums(probabilities) + probabilities <- probabilities / row_sums + predicted_columns <- apply(probabilities, 1, which.max) + as.character(levels(as.factor(model$target))[predicted_columns]) +} # S3 method to make find out number of observations in training data .mlPredictionGetTrainingN <- function(model) { @@ -170,6 +193,12 @@ is.jaspMachineLearning <- function(x) { .mlPredictionGetTrainingN.naiveBayes <- function(model) { nrow(model[["data"]]) } +.mlPredictionGetTrainingN.glm <- function(model) { + nrow(model[["data"]]) +} +.mlPredictionGetTrainingN.vglm <- function(model) { + nrow(model[["x"]]) +} # S3 method to decode the model variables in the result object # so that they can be matched to variables in the prediction analysis @@ -229,6 +258,17 @@ is.jaspMachineLearning <- function(x) { names(model[["tables"]]) <- decodeColNames(names(model[["tables"]])) return(model) } +.decodeJaspMLobject.glm <- function(model) { + vars <- all.vars(stats::terms(model)) + formula <- formula(paste(decodeColNames(vars[1]), "~", paste0(decodeColNames(vars[-1]), collapse = " + "))) + model$terms <- stats::terms(formula) + return(model) +} +.decodeJaspMLobject.vglm <- function(model) { + formula <- formula(paste(decodeColNames(strsplit(as.character(model$terms), " ")[[1]][1]), "~", paste0(decodeColNames(strsplit(strsplit(as.character(model$terms), split = " ~ ")[[1]][2], split = " + ", fixed = TRUE)[[1]]), collapse = " + "))) + model$terms <- stats::terms(formula) + return(model) +} .mlPredictionReadModel <- function(options) { if (options[["trainedModelFilePath"]] != "") { @@ -238,7 +278,7 @@ is.jaspMachineLearning <- function(x) { if (!is.jaspMachineLearning(model)) { jaspBase:::.quitAnalysis(gettext("Error: The trained model is not created in JASP.")) } - if (!(any(c("kknn", "lda", "gbm", "randomForest", "cv.glmnet", "nn", "rpart", "svm", "lm", "naiveBayes") %in% class(model)))) { + if (!(any(c("kknn", "lda", "gbm", "randomForest", "cv.glmnet", "nn", "rpart", "svm", "lm", "naiveBayes", "glm", "vglm") %in% class(model)))) { jaspBase:::.quitAnalysis(gettextf("The trained model 
(type: %1$s) is currently not supported in JASP.", paste(class(model), collapse = ", "))) } if (model[["jaspVersion"]] != .baseCitation) { @@ -326,6 +366,9 @@ is.jaspMachineLearning <- function(x) { table$addColumnInfo(name = "mtry", title = gettext("Features per split"), type = "integer") } else if (inherits(model, "cv.glmnet")) { table$addColumnInfo(name = "lambda", title = "\u03BB", type = "number") + } else if (inherits(model, "glm") || inherits(model, "vglm")) { + table$addColumnInfo(name = "family", title = gettext("Family"), type = "string") + table$addColumnInfo(name = "link", title = gettext("Link"), type = "string") } table$addColumnInfo(name = "ntrain", title = gettext("n(Train)"), type = "integer") table$addColumnInfo(name = "nnew", title = gettext("n(New)"), type = "integer") @@ -344,6 +387,12 @@ is.jaspMachineLearning <- function(x) { row[["mtry"]] <- model[["mtry"]] } else if (inherits(model, "cv.glmnet")) { row[["lambda"]] <- model[["lambda.min"]] + } else if (inherits(model, "glm")) { + row[["family"]] <- gettext("Binomial") + row[["link"]] <- paste0(toupper(substr(model[["link"]], 1, 1)), substr(model[["link"]], 2, nchar(model[["link"]]))) + } else if (inherits(model, "vglm")) { + row[["family"]] <- gettext("Multinomial") + row[["link"]] <- gettext("Logit") } if (length(presentVars) > 0) { row[["nnew"]] <- nrow(dataset) diff --git a/R/mlRegressionLinear.R b/R/mlRegressionLinear.R index 2026cc77..e3b9b54f 100644 --- a/R/mlRegressionLinear.R +++ b/R/mlRegressionLinear.R @@ -101,7 +101,7 @@ mlRegressionLinear <- function(jaspResults, dataset, options, ...) { start <- 1 } for (i in start:nrow(coefs)) { - regform <- paste0(regform, if (round(as.numeric(coefs[, 1])[i], 3) < 0) " - " else " + ", abs(round(as.numeric(coefs[, 1])[i], 3)), " x ", vars[i]) + regform <- paste0(regform, if (round(as.numeric(coefs[, 1])[i], 3) < 0) " - " else (if (!options[["intercept"]] && i == 1) "" else " + "), abs(round(as.numeric(coefs[, 1])[i], 3)), " x ", vars[i]) } # Create results object result <- list() diff --git a/R/mlRegressionRegularized.R b/R/mlRegressionRegularized.R index b2a11b6d..2e5af926 100644 --- a/R/mlRegressionRegularized.R +++ b/R/mlRegressionRegularized.R @@ -166,12 +166,14 @@ mlRegressionRegularized <- function(jaspResults, dataset, options, ...) 
{ if (options[["intercept"]]) { regform <- paste0(options[["target"]], " = ", round(as.numeric(coefs[, 1])[1], 3)) start <- 2 + form_coefs <- coefs } else { regform <- paste0(options[["target"]], " = ") start <- 1 + form_coefs <- coefs[-1, , drop = FALSE] # There is still a row with (Intercept) but its value is 0 } - for (i in start:nrow(coefs)) { - regform <- paste0(regform, if (round(as.numeric(coefs[, 1])[i], 3) < 0) " - " else " + ", abs(round(as.numeric(coefs[, 1])[i], 3)), " x ", rownames(coefs)[i]) + for (i in start:nrow(form_coefs)) { + regform <- paste0(regform, if (round(as.numeric(form_coefs[, 1])[i], 3) < 0) " - " else (if (!options[["intercept"]] && i == 1) "" else " + "), abs(round(as.numeric(form_coefs[, 1])[i], 3)), " x ", rownames(form_coefs)[i]) } result <- list() result[["model"]] <- trainingFit diff --git a/inst/Description.qml b/inst/Description.qml index 52cea7c1..72422c39 100644 --- a/inst/Description.qml +++ b/inst/Description.qml @@ -98,6 +98,12 @@ Description func: "mlClassificationLda" } Analysis + { + menu: qsTr("Logistic / Multinomial") + title: qsTr("Logistic / Multinomial Regression Classification") + func: "mlClassificationLogisticMultinomial" + } + Analysis { menu: qsTr("Naive Bayes") title: qsTr("Naive Bayes Classification") diff --git a/inst/qml/mlClassificationLogisticMultinomial.qml b/inst/qml/mlClassificationLogisticMultinomial.qml new file mode 100644 index 00000000..f63502c4 --- /dev/null +++ b/inst/qml/mlClassificationLogisticMultinomial.qml @@ -0,0 +1,99 @@ +// +// Copyright (C) 2013-2021 University of Amsterdam +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as +// published by the Free Software Foundation, either version 3 of the +// License, or (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public +// License along with this program. If not, see +// . +// + +import QtQuick 2.8 +import QtQuick.Layouts 1.3 +import JASP.Controls 1.0 +import JASP.Widgets 1.0 + +import "./common/ui" as UI +import "./common/tables" as TAB +import "./common/figures" as FIG +import "./common/analyses/regularized" as REGU + +Form +{ + info: qsTr("Logistic regression is a statistical method used to model the relationship between a binary target variable (with two possible outcomes) and one or more feature variables. It predicts the probability of a specific outcome by using a logistic function, which ensures that the predicted probabilities are between 0 and 1. Multinomial regression extends logistic regression to handle target variables with more than two categories. 
diff --git a/inst/qml/mlClassificationLogisticMultinomial.qml b/inst/qml/mlClassificationLogisticMultinomial.qml new file mode 100644 index 00000000..f63502c4 --- /dev/null +++ b/inst/qml/mlClassificationLogisticMultinomial.qml @@ -0,0 +1,99 @@ +// +// Copyright (C) 2013-2021 University of Amsterdam +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as +// published by the Free Software Foundation, either version 3 of the +// License, or (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public +// License along with this program. If not, see +// <http://www.gnu.org/licenses/>. +// + +import QtQuick 2.8 +import QtQuick.Layouts 1.3 +import JASP.Controls 1.0 +import JASP.Widgets 1.0 + +import "./common/ui" as UI +import "./common/tables" as TAB +import "./common/figures" as FIG +import "./common/analyses/regularized" as REGU + +Form +{ + info: qsTr("Logistic regression is a statistical method used to model the relationship between a binary target variable (with two possible outcomes) and one or more feature variables. It predicts the probability of a specific outcome by using a logistic function, which ensures that the predicted probabilities are between 0 and 1. Multinomial regression extends logistic regression to handle target variables with more than two categories. Instead of predicting binary outcomes, multinomial regression is used for scenarios where the target variable has three or more unordered categories.") + + UI.VariablesFormClassification { id: vars } + + Group + { + title: qsTr("Tables") + + TAB.ConfusionMatrix { } + TAB.ClassProportions { } + TAB.ModelPerformance { } + TAB.FeatureImportance { } + TAB.ExplainPredictions { } + REGU.CoefficientTable { confint: true } + } + + Group + { + title: qsTr("Plots") + + FIG.DataSplit { } + FIG.RocCurve { } + FIG.AndrewsCurve { } + FIG.DecisionBoundary { } + } + + UI.ExportResults { enabled: vars.predictorCount > 0 && vars.targetCount > 0 } + UI.DataSplit { trainingValidationSplit: false } + + Section + { + title: qsTr("Training Parameters") + + Group + { + title: qsTr("Algorithmic Settings") + + DropDown + { + name: "link" + indexDefaultValue: 0 + label: qsTr("Link function (for binary classification)") + values: + [ + { label: qsTr("Logit"), value: "logit"}, + { label: qsTr("Probit"), value: "probit"}, + { label: qsTr("Cauchit"), value: "cauchit"}, + { label: qsTr("Complementary log-log"), value: "cloglog"}, + { label: qsTr("Log"), value: "log"} + ] + } + REGU.Intercept { } + UI.ScaleVariables { } + UI.SetSeed { } + } + + RadioButtonGroup + { + name: "modelOptimization" + visible: false + + RadioButton + { + name: "manual" + checked: true + } + } + } +} diff --git a/tests/testthat/_snaps/mlclassificationlogisticmultinomial/data-split-1.svg b/tests/testthat/_snaps/mlclassificationlogisticmultinomial/data-split-1.svg new file mode 100644 index 00000000..2a5fca82 --- /dev/null +++ b/tests/testthat/_snaps/mlclassificationlogisticmultinomial/data-split-1.svg @@ -0,0 +1,42 @@ [42-line SVG snapshot; markup stripped in extraction; recoverable text: "Train: 160", "Test: 40", "Total: 200", caption "data-split-1"] diff --git a/tests/testthat/_snaps/mlclassificationlogisticmultinomial/data-split-2.svg b/tests/testthat/_snaps/mlclassificationlogisticmultinomial/data-split-2.svg new file mode 100644 index 00000000..35b13f56 --- /dev/null +++ b/tests/testthat/_snaps/mlclassificationlogisticmultinomial/data-split-2.svg @@ -0,0 +1,42 @@ [42-line SVG snapshot; markup stripped in extraction; recoverable text: "Train: 120", "Test: 30", "Total: 150", caption "data-split-2"] diff --git a/tests/testthat/helper-ml.R b/tests/testthat/helper-ml.R index 04b8528f..0c82a34a 100644 --- a/tests/testthat/helper-ml.R +++ b/tests/testthat/helper-ml.R @@ -27,7 +27,7 @@ mlOptions <- function(analysis) { files <- c(files, list.files(testthat::test_path("..", "..", "inst", "qml", "common", "analyses", "randomforest"), full.names = TRUE)) } else if (analysis %in% c("mlClassificationSvm", "mlRegressionSvm")) { files <- c(files, list.files(testthat::test_path("..", "..", "inst", "qml", "common", "analyses", "svm"), full.names = TRUE)) - } else if (analysis %in% c("mlRegressionLinear", "mlRegressionRegularized")) { + } else if (analysis %in% c("mlClassificationLogisticMultinomial", "mlRegressionLinear", "mlRegressionRegularized")) { files <- c(files, list.files(testthat::test_path("..", "..", "inst", "qml", "common", "analyses", "regularized"), full.names = TRUE)) } options <- lapply(files, jaspTools:::readQML) |> diff --git a/tests/testthat/test-mlclassificationlogisticmultinomial.R b/tests/testthat/test-mlclassificationlogisticmultinomial.R new file mode 100644 index 00000000..38434d4f --- /dev/null +++ b/tests/testthat/test-mlclassificationlogisticmultinomial.R @@ -0,0 +1,218 @@ +context("Machine Learning Logistic / Multinomial Regression Classification") + +# Test logistic regression model
############################################ +options <- initMlOptions("mlClassificationLogisticMultinomial") +options$addIndicator <- FALSE +options$addPredictions <- FALSE +options$coefTable <- TRUE +options$coefTableConfInt <- TRUE +options$classProportionsTable <- TRUE +options$holdoutData <- "holdoutManual" +options$link <- "logit" +options$modelOptimization <- "manual" +options$modelValid <- "validationManual" +options$predictionsColumn <- "" +options$predictors <- c("x", "y") +options$predictors.types <- rep("scale", 2) +options$saveModel <- FALSE +options$savePath <- "" +options$setSeed <- TRUE +options$target <- "color" +options$target.types <- "nominal" +options$testDataManual <- 0.2 +options$testIndicatorColumn <- "" +options$testSetIndicatorVariable <- "" +options$validationDataManual <- 0.2 +options$validationMeasures <- TRUE +options$tableShap <- TRUE +options$fromIndex <- 1 +options$toIndex <- 5 +options$featureImportanceTable <- TRUE +options$seed <- 2 +set.seed(1) +results <- jaspTools::runAnalysis("mlClassificationLogisticMultinomial", "spiral.csv", options) + +test_that("Class Proportions table results match", { + table <- results[["results"]][["classProportionsTable"]][["data"]] + jaspTools::expect_equal_tables(table, + list(0.5, "Black", 0.575, 0.48125, 0.5, "Red", 0.425, 0.51875)) +}) + +test_that("Model Summary: Logistic Regression Classification table results match", { + table <- results[["results"]][["classificationTable"]][["data"]] + jaspTools::expect_equal_tables(table, + list("Binomial", "Logit", 40, 160, 0.675)) +}) + +test_that("Regression Coefficients table results match", { + table <- results[["results"]][["coefTable"]][["data"]] + jaspTools::expect_equal_tables(table, + list(0.0784909051640228, -0.241691087574101, 0.630847847956443, 0.163341057208965, + 0.400211265504598, "(Intercept)", 0.480533838247466, -0.0733280693358763, + -0.389515022216725, 0.647229030486731, 0.160239637348256, 0.241355688318239, + "x", -0.457615047995329, -0.520574613112014, -0.864252780481388, + 0.00221841580161717, 0.170160350681228, -0.194437228429846, + "y", -3.0593179376272)) +}) + +test_that("Confusion Matrix table results match", { + table <- results[["results"]][["confusionTable"]][["data"]] + jaspTools::expect_equal_tables(table, + list("Observed", "Black", 14, 9, "", "Red", 4, 13)) +}) + +test_that("Feature Importance Metrics table results match", { + table <- results[["results"]][["featureImportanceTable"]][["data"]] + jaspTools::expect_equal_tables(table, + list(0.49553121577218, "y", 0.367228915662651, "x")) +}) + +test_that("Data Split plot matches", { + plotName <- results[["results"]][["plotDataSplit"]][["data"]] + testPlot <- results[["state"]][["figures"]][[plotName]][["obj"]] + jaspTools::expect_equal_plots(testPlot, "data-split-1") +}) + +test_that("Additive Explanations for Predictions of Test Set Cases table results match", { + table <- results[["results"]][["tableShap"]][["data"]] + jaspTools::expect_equal_tables(table, + list(0.518749999999986, 1, "Red (0.511)", -0.0104280206091216, 0.00223442804446772, + 0.481250000000014, 2, "Black (0.505)", 0.0113934073181472, 0.0126826432048778, + 0.518749999999986, 3, "Red (0.694)", -0.0215827660760237, 0.196894941069013, + 0.518749999999986, 4, "Red (0.707)", -0.0148699072618508, 0.203086871619157, + 0.481250000000014, 5, "Black (0.584)", 0.00703631455756359, + 0.0957186107999987)) +}) + +test_that("Model Performance Metrics table results match", { + table <- results[["results"]][["validationMeasures"]][["data"]] + 
jaspTools::expect_equal_tables(table, + list(0.675, 0.686700767263427, 0.682926829268293, 0.222222222222222, + 0.391304347826087, 0.409090909090909, 0.235294117647059, "Black", + 0.371036713180216, 0.590909090909091, 0.777777777777778, 0.608695652173913, + 0.45, 23, 0.764705882352941, 0.823529411764706, 0.675, 0.686700767263427, + 0.666666666666667, 0.409090909090909, 0.235294117647059, 0.222222222222222, + 0.391304347826087, "Red", 0.371036713180216, 0.777777777777778, + 0.590909090909091, 0.764705882352941, 0.55, 17, 0.608695652173913, + 0.590909090909091, 0.675, 0.686700767263427, 0.676016260162602, + 0.315656565656566, 0.313299232736573, 0.315656565656566, 0.313299232736573, + "Average / Total", 0.371036713180216, 0.684343434343434, 0.698358585858586, + 0.675, 1, 40, 0.686700767263427, 0.707219251336898)) +}) + +# Test multinomial regression model ############################################ +options <- initMlOptions("mlClassificationLogisticMultinomial") +options$addIndicator <- FALSE +options$addPredictions <- FALSE +options$coefTable <- TRUE +options$coefTableConfInt <- TRUE +options$classProportionsTable <- TRUE +options$holdoutData <- "holdoutManual" +options$link <- "logit" +options$modelOptimization <- "manual" +options$modelValid <- "validationManual" +options$predictionsColumn <- "" +options$predictors <- c("Sepal.Length", "Sepal.Width", "Petal.Length", "Petal.Width") +options$predictors.types <- rep("scale", 4) +options$saveModel <- FALSE +options$savePath <- "" +options$setSeed <- TRUE +options$target <- "Species" +options$target.types <- "nominal" +options$testDataManual <- 0.2 +options$testIndicatorColumn <- "" +options$testSetIndicatorVariable <- "" +options$validationDataManual <- 0.2 +options$validationMeasures <- TRUE +options$tableShap <- TRUE +options$fromIndex <- 1 +options$toIndex <- 5 +options$featureImportanceTable <- TRUE +set.seed(1) +results <- jaspTools::runAnalysis("mlClassificationLogisticMultinomial", "iris.csv", options) + +test_that("Class Proportions table results match", { + table <- results[["results"]][["classProportionsTable"]][["data"]] + jaspTools::expect_equal_tables(table, + list(0.333333333333333, "setosa", 0.333333333333333, 0.333333333333333, + 0.333333333333333, "versicolor", 0.266666666666667, 0.35, 0.333333333333333, + "virginica", 0.4, 0.316666666666667)) +}) + +test_that("Model Summary: Multinomial Regression Classification table results match", { + table <- results[["results"]][["classificationTable"]][["data"]] + jaspTools::expect_equal_tables(table, + list("Multinomial", "Logit", 30, 120, 1)) +}) + +test_that("Regression Coefficients table results match", { + table <- results[["results"]][["coefTable"]][["data"]] + jaspTools::expect_equal_tables(table, + list(-0.42267961226173, -4222.39134720741, 0.999843438690878, 2154.10522892129, + 4221.54598798288, "(Intercept) : setosa", -0.000196220503337896, + 17.8212683047785, 0.307905475852188, 0.0461059069011903, 8.93555339132221, + 35.3346311337047, "(Intercept) : versicolor", 1.99442245200903, + 7.95554619079237, -9967.35637439074, 0.998752813237481, 5089.53837890162, + 9983.26746677233, "Sepal.Length : setosa", 0.00156311743787445, + 1.89635371704686, -2.10497557567811, 0.352947362657274, 2.04153205073509, + 5.89768300977184, "Sepal.Length : versicolor", 0.928887555972511, + 4.70317443290632, -2533.61853897752, 0.997102445459767, 1295.08589618604, + 2543.02488784333, "Sepal.Width : setosa", 0.00363155405116905, + 2.30296713825466, -1.92502652612214, 0.285708816748704, 2.15717926335722, + 
6.53096080263146, "Sepal.Width : versicolor", 1.0675826424692, + -38.7099474348317, -15129.1258351686, 0.995988491696961, 7699.33325651139, + 15051.705940299, "Petal.Length : setosa", -0.00502770124959773, + -14.4906493502573, -31.1295653321292, 0.087838572245474, 8.48939884258976, + 2.14826663161466, "Petal.Length : versicolor", -1.7069111275065, + -24.205993975534, -10348.1025945441, 0.996333380037363, 5267.39097350877, + 10299.690606593, "Petal.Width : setosa", -0.00459544280978437, + -12.7574125117412, -27.3435062790441, 0.0864846111196134, 7.44202132404277, + 1.82868125556169, "Petal.Width : versicolor", -1.7142402522452 + )) +}) + +test_that("Confusion Matrix table results match", { + table <- results[["results"]][["confusionTable"]][["data"]] + jaspTools::expect_equal_tables(table, + list("Observed", "setosa", 10, 0, 0, "", "versicolor", 0, 8, 0, "", + "virginica", 0, 0, 12)) +}) + +test_that("Feature Importance Metrics table results match", { + table <- results[["results"]][["featureImportanceTable"]][["data"]] + jaspTools::expect_equal_tables(table, + list(530.419652531233, "Petal.Length", 258.101247632355, "Petal.Width", + 11.6632506085855, "Sepal.Width", 10.7849556008181, "Sepal.Length" + )) +}) + +test_that("Data Split plot matches", { + plotName <- results[["results"]][["plotDataSplit"]][["data"]] + testPlot <- results[["state"]][["figures"]][[plotName]][["obj"]] + jaspTools::expect_equal_plots(testPlot, "data-split-2") +}) + +test_that("Additive Explanations for Predictions of Test Set Cases table results match", { + table <- results[["results"]][["tableShap"]][["data"]] + jaspTools::expect_equal_tables(table, + list(0.629688901507271, 0.0369777616411127, 2.01684224876431e-11, 3.29787464004028e-09, + 0.333333333528117, 1, "setosa (1)", 0.608365336494419, 0.0583013168725895, + 5.75901548671709e-11, 1.30466177861166e-08, 0.333333333528117, + 2, "setosa (1)", 0.578930644244146, 0.0877359705540646, -1.91210501876427e-08, + 6.96919006948349e-08, 0.333333333528117, 3, "setosa (1)", 0.644504774615733, + 0.0221618748599225, 2.48312481687663e-11, 1.69711121822402e-08, + 0.333333333528117, 4, "setosa (1)", 0.544084722419656, 0.122581029133867, + 2.34213792804638e-09, 9.1257324152938e-07, 0.333333333528117, + 5, "setosa (1)")) +}) + +test_that("Model Performance Metrics table results match", { + table <- results[["results"]][["validationMeasures"]][["data"]] + jaspTools::expect_equal_tables(table, + list(1, 1, 1, 0, 0, 0, 0, "setosa", 1, 1, 1, 1, 0.333333333333333, + 10, 1, "", 1, 0.613636363636364, 1, 0, 0, 0, 0, "versicolor", + 1, 1, 1, 1, 0.266666666666667, 8, 1, "", 1, 1, 1, 0, + 0, 0, 0, "virginica", 1, 1, 1, 1, 0.4, 12, 1, "", 1, + 0.871212121212121, 1, 0, 0, 0, 0, "Average / Total", 1, 1, 1, + 1, 1, 30, 1, "")) +}) \ No newline at end of file