diff --git a/R/commonMachineLearningClassification.R b/R/commonMachineLearningClassification.R
index daafcc1b..a7fbaad4 100644
--- a/R/commonMachineLearningClassification.R
+++ b/R/commonMachineLearningClassification.R
@@ -33,7 +33,8 @@
     "noOfTrees", "maxTrees", "baggingFraction", "noOfPredictors", "numberOfPredictors", # Random forest
     "complexityParameter", "degree", "gamma", "cost", "tolerance", "epsilon", "maxCost", # Support vector machine
     "smoothingParameter", # Naive Bayes
-    "intercept", "link" # Logistic
+    "intercept", "link", # Logistic
+    "balanceLabels", "balanceSamplingMethod" # Common
   )
   if (includeSaveOptions) {
     opt <- c(opt, "saveModel", "savePath")
@@ -41,6 +42,49 @@
   return(opt)
 }
 
+# Function balancing the size of classes of a discrete dependent variable in a dataset
+.mlBalanceDataset <- function(dataset, options) {
+  # To balance the classes, this function uses either under- or oversampling to adjust
+  # the size of each class to either the minimum or maximum class size found in the data.
+  # The sampling method is random sampling.
+  #
+  # dataset: data frame containing the column named by options[["target"]].
+  # options: list providing "target", "balanceLabels" and "balanceSamplingMethod"
+  #          ("minSample" undersamples; any other value oversamples).
+  # Returns the dataset unchanged when balancing is switched off, otherwise a data frame
+  # in which every class that occurs in the data has the same number of rows.
+
+  # Ensures that if the option is not selected, balancing will not occur.
+  if (!isTRUE(options[["balanceLabels"]]) || nrow(dataset) == 0L) {
+    return(dataset)
+  }
+
+  classes <- dataset[[options[["target"]]]]
+  # drop = TRUE: factor levels without rows in this subset must not count as a
+  # class of size zero, otherwise undersampling would empty the whole dataset.
+  splitData <- split(dataset, classes, drop = TRUE)
+  classSizes <- vapply(splitData, nrow, integer(1))
+  undersample <- identical(options[["balanceSamplingMethod"]], "minSample")
+  n <- if (undersample) min(classSizes) else max(classSizes)
+
+  balancedSplits <- lapply(
+    X = splitData,
+    FUN = function(df) {
+      if (undersample) {
+        rows <- sample.int(nrow(df), size = n, replace = FALSE)
+      } else {
+        # Keep every original row and only draw the shortfall with replacement,
+        # so no observation is discarded while oversampling.
+        rows <- c(seq_len(nrow(df)), sample.int(nrow(df), size = n - nrow(df), replace = TRUE))
+      }
+      df[rows, , drop = FALSE]
+    }
+  )
+  balancedDataset <- do.call(rbind, balancedSplits)
+
+  return(balancedDataset)
+}
+
 .mlClassificationReadData <- function(dataset, options) {
   dataset <- .readDataClassificationRegressionAnalyses(dataset, options, include_weights = FALSE)
   if (options[["target"]] != "") {
diff --git a/R/commonMachineLearningRegression.R b/R/commonMachineLearningRegression.R
index 61fd1596..40ede4c5 100644
--- a/R/commonMachineLearningRegression.R
+++ b/R/commonMachineLearningRegression.R
@@ -166,7 +166,7 @@
   if (length(factorsWithNewLevels) > 0) {
     setType <- switch(type, "test" = gettext("test set"), "validation" = gettext("validation set"), "prediction" = gettext("new dataset"))
     additionalMessage <- switch(type,
-      "test" = gettext(" or use a different test set (e.g., automatically by setting a different seed or manually by specifying the test set indicator)"),
+      "test" = gettext(" or use a different test set (e.g., automatically by setting a different seed or manually by specifying the test set indicator)"),
       "validation" = gettext(" or use a different validation set by setting a different seed"),
       "prediction" = "")
     factorMessage <- paste(sapply(factorsWithNewLevels, function(i) {
@@ -597,7 +597,7 @@
   }
   plot <- createJaspPlot(plot = NULL, title = gettext("Data Split"), width = 800, height = 30)
   plot$position <- position
-  plot$dependOn(options = c("dataSplitPlot", "target", "predictors", "trainingDataManual", "modelValid", "testSetIndicatorVariable", "testSetIndicator", "validationDataManual", "holdoutData", "testDataManual", 
"modelOptimization")) + plot$dependOn(options = c("balanceSamplingMethod", "balanceLabels", "dataSplitPlot", "target", "predictors", "trainingDataManual", "modelValid", "testSetIndicatorVariable", "testSetIndicator", "validationDataManual", "holdoutData", "testDataManual", "modelOptimization")) jaspResults[["plotDataSplit"]] <- plot if (!ready) { return() diff --git a/R/mlClassificationLda.R b/R/mlClassificationLda.R index 5d1f27ca..6ea85e10 100644 --- a/R/mlClassificationLda.R +++ b/R/mlClassificationLda.R @@ -110,7 +110,7 @@ mlClassificationLda <- function(jaspResults, dataset, options, ...) { # Sample a percentage of the total data set trainingIndex <- sample.int(nrow(dataset), size = ceiling((1 - options[["testDataManual"]]) * nrow(dataset))) } - trainingSet <- dataset[trainingIndex, ] + trainingSet <- .mlBalanceDataset(dataset[trainingIndex, ], options) testSet <- dataset[-trainingIndex, ] # Check for factor levels in the test set that are not in the training set .checkForNewFactorLevelsInPredictionSet(trainingSet, testSet, "test") diff --git a/R/mlClassificationLogisticMultinomial.R b/R/mlClassificationLogisticMultinomial.R index 71f5af13..51555cbb 100644 --- a/R/mlClassificationLogisticMultinomial.R +++ b/R/mlClassificationLogisticMultinomial.R @@ -74,7 +74,9 @@ mlClassificationLogisticMultinomial <- function(jaspResults, dataset, options, . 
# Sample a percentage of the total data set trainingIndex <- sample.int(nrow(dataset), size = ceiling((1 - options[["testDataManual"]]) * nrow(dataset))) } - trainingSet <- dataset[trainingIndex, ] + # Create training set with optional balanced classes + trainingSet <- .mlBalanceDataset(dataset[trainingIndex, ], options) + # Create the generated test set indicator testIndicatorColumn <- rep(1, nrow(dataset)) testIndicatorColumn[trainingIndex] <- 0 diff --git a/R/mlClassificationNaiveBayes.R b/R/mlClassificationNaiveBayes.R index d2f7a9e0..0602bed2 100644 --- a/R/mlClassificationNaiveBayes.R +++ b/R/mlClassificationNaiveBayes.R @@ -75,7 +75,7 @@ mlClassificationNaiveBayes <- function(jaspResults, dataset, options, ...) { # Sample a percentage of the total data set trainingIndex <- sample.int(nrow(dataset), size = ceiling((1 - options[["testDataManual"]]) * nrow(dataset))) } - trainingSet <- dataset[trainingIndex, ] + trainingSet <- .mlBalanceDataset(dataset[trainingIndex, ], options) # Create the generated test set indicator testIndicatorColumn <- rep(1, nrow(dataset)) testIndicatorColumn[trainingIndex] <- 0 diff --git a/inst/qml/common/ui/DataSplit.qml b/inst/qml/common/ui/DataSplit.qml index 02faac76..06cb0033 100644 --- a/inst/qml/common/ui/DataSplit.qml +++ b/inst/qml/common/ui/DataSplit.qml @@ -25,6 +25,7 @@ Section property alias leaveOneOutVisible: leaveOneOut.visible property alias kFoldsVisible: kFolds.visible property alias trainingValidationSplit: trainingValidationSplit.visible + property alias balanceTargetClasses: balanceTargetClasses.visible title: qsTr("Data Split Preferences") @@ -157,4 +158,32 @@ Section info: qsTr("Partition the remaining data in *n* parts.") } } + + CheckBox + { + id: balanceTargetClasses + name: "balanceLabels" + label: qsTr("Balance sample size of target classes") + info: qsTr("When clicked, the dataset is balanced to have the same sample size for all classes of the target variable. 
This is done either through over- or undersampling") + + RadioButtonGroup + { + name: "balanceSamplingMethod" + + RadioButton + { + value: "minSample" + label: qsTr("Undersample") + checked: true + info: qsTr("Balances the target classes by undersampling to match the size of the smallest class.") + } + + RadioButton + { + value: "maxSample" + label: qsTr("Oversample") + info: qsTr("Balances the target classes by oversampling to match the size of the largest class. This is done by sampling with replacement for smaller classes.") + } + } + } } diff --git a/inst/qml/mlClassificationBoosting.qml b/inst/qml/mlClassificationBoosting.qml index 2670c08d..f1310dda 100644 --- a/inst/qml/mlClassificationBoosting.qml +++ b/inst/qml/mlClassificationBoosting.qml @@ -56,7 +56,12 @@ Form } UI.ExportResults { enabled: vars.predictorCount > 1 && vars.targetCount > 0 } - UI.DataSplit { leaveOneOutVisible: false; trainingValidationSplit: !optim.isManual } + UI.DataSplit + { + leaveOneOutVisible: false + trainingValidationSplit: !optim.isManual + balanceTargetClasses: false + } Section { diff --git a/inst/qml/mlClassificationDecisionTree.qml b/inst/qml/mlClassificationDecisionTree.qml index 4da77e1b..82fcda9a 100644 --- a/inst/qml/mlClassificationDecisionTree.qml +++ b/inst/qml/mlClassificationDecisionTree.qml @@ -56,7 +56,13 @@ Form } UI.ExportResults { enabled: vars.predictorCount > 0 && vars.targetCount > 0 } - UI.DataSplit { leaveOneOutVisible: false; kFoldsVisible: false; trainingValidationSplit: !optim.isManual } + UI.DataSplit + { + leaveOneOutVisible: false + kFoldsVisible: false + trainingValidationSplit: !optim.isManual + balanceTargetClasses: false + } Section { diff --git a/inst/qml/mlClassificationKnn.qml b/inst/qml/mlClassificationKnn.qml index 28053f3e..dd8678a4 100644 --- a/inst/qml/mlClassificationKnn.qml +++ b/inst/qml/mlClassificationKnn.qml @@ -55,7 +55,7 @@ Form } UI.ExportResults { enabled: vars.predictorCount > 0 && vars.targetCount > 0 } - UI.DataSplit { 
trainingValidationSplit: !optim.isManual } + UI.DataSplit { trainingValidationSplit: !optim.isManual; balanceTargetClasses: false } Section { diff --git a/inst/qml/mlClassificationNeuralNetwork.qml b/inst/qml/mlClassificationNeuralNetwork.qml index 70fee156..b6a13c68 100644 --- a/inst/qml/mlClassificationNeuralNetwork.qml +++ b/inst/qml/mlClassificationNeuralNetwork.qml @@ -57,7 +57,7 @@ Form } UI.ExportResults { enabled: vars.predictorCount > 0 && vars.targetCount > 0 } - UI.DataSplit { leaveOneOutVisible: false; kFoldsVisible: false } + UI.DataSplit { leaveOneOutVisible: false; kFoldsVisible: false; balanceTargetClasses: false } Section { diff --git a/inst/qml/mlClassificationRandomForest.qml b/inst/qml/mlClassificationRandomForest.qml index 22f7cee1..be91b3df 100644 --- a/inst/qml/mlClassificationRandomForest.qml +++ b/inst/qml/mlClassificationRandomForest.qml @@ -56,7 +56,13 @@ Form } UI.ExportResults { enabled: vars.predictorCount > 1 && vars.targetCount > 0 } - UI.DataSplit { leaveOneOutVisible: false; kFoldsVisible: false; trainingValidationSplit: !optim.isManual } + UI.DataSplit + { + leaveOneOutVisible: false + kFoldsVisible: false + trainingValidationSplit: !optim.isManual + balanceTargetClasses: false + } Section { diff --git a/inst/qml/mlClassificationSvm.qml b/inst/qml/mlClassificationSvm.qml index d3c82ae5..6b430156 100644 --- a/inst/qml/mlClassificationSvm.qml +++ b/inst/qml/mlClassificationSvm.qml @@ -55,7 +55,12 @@ Form } UI.ExportResults { enabled: vars.predictorCount > 0 && vars.targetCount > 0 } - UI.DataSplit { leaveOneOutVisible: false; kFoldsVisible: false; trainingValidationSplit: !optim.isManual } + UI.DataSplit + { + leaveOneOutVisible: false + kFoldsVisible: false + trainingValidationSplit: !optim.isManual + balanceTargetClasses: false } Section { diff --git a/inst/qml/mlRegressionBoosting.qml b/inst/qml/mlRegressionBoosting.qml index 9e411a5e..70a4ced1 100644 --- a/inst/qml/mlRegressionBoosting.qml +++ b/inst/qml/mlRegressionBoosting.qml 
@@ -44,7 +44,7 @@ Form { title: qsTr("Plots") - FIG.DataSplit { } + FIG.DataSplit {} FIG.PredictivePerformance { } BOOSTING.Oob { } BOOSTING.Deviance { } @@ -52,7 +52,12 @@ Form } UI.ExportResults { enabled: vars.predictorCount > 1 && vars.targetCount > 0 } - UI.DataSplit { leaveOneOutVisible: false; trainingValidationSplit: !optim.isManual } + UI.DataSplit + { + leaveOneOutVisible: false + trainingValidationSplit: !optim.isManual + balanceTargetClasses: false + } Section { diff --git a/inst/qml/mlRegressionDecisionTree.qml b/inst/qml/mlRegressionDecisionTree.qml index 12c7a596..86a9d6da 100644 --- a/inst/qml/mlRegressionDecisionTree.qml +++ b/inst/qml/mlRegressionDecisionTree.qml @@ -52,7 +52,13 @@ Form } UI.ExportResults { enabled: vars.predictorCount > 0 && vars.targetCount > 0 } - UI.DataSplit { leaveOneOutVisible: false; kFoldsVisible: false; trainingValidationSplit: !optim.isManual } + UI.DataSplit + { + leaveOneOutVisible: false + kFoldsVisible: false + trainingValidationSplit: !optim.isManual + balanceTargetClasses: false + } Section { diff --git a/inst/qml/mlRegressionKnn.qml b/inst/qml/mlRegressionKnn.qml index e95b7f3b..43622606 100644 --- a/inst/qml/mlRegressionKnn.qml +++ b/inst/qml/mlRegressionKnn.qml @@ -52,7 +52,7 @@ Form UI.ExportResults { enabled: vars.predictorCount > 0 && vars.targetCount > 0 } - UI.DataSplit { trainingValidationSplit: !optim.isManual } + UI.DataSplit { trainingValidationSplit: !optim.isManual ; balanceTargetClasses: false } Section { diff --git a/inst/qml/mlRegressionLinear.qml b/inst/qml/mlRegressionLinear.qml index 3dd82086..b6af5548 100644 --- a/inst/qml/mlRegressionLinear.qml +++ b/inst/qml/mlRegressionLinear.qml @@ -50,7 +50,7 @@ Form } UI.ExportResults { enabled: vars.predictorCount > 0 && vars.targetCount > 0 } - UI.DataSplit { trainingValidationSplit: false } + UI.DataSplit { trainingValidationSplit: false; balanceTargetClasses: false } Section { diff --git a/inst/qml/mlRegressionNeuralNetwork.qml 
b/inst/qml/mlRegressionNeuralNetwork.qml index 07e827be..5f125861 100644 --- a/inst/qml/mlRegressionNeuralNetwork.qml +++ b/inst/qml/mlRegressionNeuralNetwork.qml @@ -53,7 +53,7 @@ Form } UI.ExportResults { enabled: vars.predictorCount > 0 && vars.targetCount > 0 } - UI.DataSplit { leaveOneOutVisible: false; kFoldsVisible: false } + UI.DataSplit { leaveOneOutVisible: false; kFoldsVisible: false; balanceTargetClasses: false } Section { diff --git a/inst/qml/mlRegressionRandomForest.qml b/inst/qml/mlRegressionRandomForest.qml index 6c4cd698..80c3b881 100644 --- a/inst/qml/mlRegressionRandomForest.qml +++ b/inst/qml/mlRegressionRandomForest.qml @@ -52,7 +52,13 @@ Form } UI.ExportResults { enabled: vars.predictorCount > 1 && vars.targetCount > 0 } - UI.DataSplit { leaveOneOutVisible: false; kFoldsVisible: false; trainingValidationSplit: !optim.isManual } + UI.DataSplit + { + leaveOneOutVisible: false + kFoldsVisible: false + trainingValidationSplit: !optim.isManual + balanceTargetClasses: false + } Section { diff --git a/inst/qml/mlRegressionRegularized.qml b/inst/qml/mlRegressionRegularized.qml index ae898b1b..888480f7 100644 --- a/inst/qml/mlRegressionRegularized.qml +++ b/inst/qml/mlRegressionRegularized.qml @@ -85,6 +85,7 @@ Form leaveOneOutVisible: false kFoldsVisible: false trainingValidationSplit: !fixedModel.checked + balanceTargetClasses: false } Section diff --git a/inst/qml/mlRegressionSvm.qml b/inst/qml/mlRegressionSvm.qml index 7e4f62fa..7a39b654 100644 --- a/inst/qml/mlRegressionSvm.qml +++ b/inst/qml/mlRegressionSvm.qml @@ -51,7 +51,13 @@ Form } UI.ExportResults { enabled: vars.predictorCount > 0 && vars.targetCount > 0 } - UI.DataSplit { leaveOneOutVisible: false; kFoldsVisible: false; trainingValidationSplit: !optim.isManual } + UI.DataSplit + { + leaveOneOutVisible: false + kFoldsVisible: false + trainingValidationSplit: !optim.isManual + balanceTargetClasses: false + } Section {