diff --git a/R/commonMachineLearningClassification.R b/R/commonMachineLearningClassification.R index 3e7a01e2..9afb3656 100644 --- a/R/commonMachineLearningClassification.R +++ b/R/commonMachineLearningClassification.R @@ -167,16 +167,9 @@ if (!ready) { table$addFootnote(gettextf("Please provide a target variable and at least %i feature variable(s).", if (type == "knn" || type == "neuralnet" || type == "rpart" || type == "svm" || type == "logistic") 1L else 2L)) } - if (options[["saveModel"]]) { - validNames <- (length(grep(" ", decodeColNames(colnames(dataset)))) == 0) && (length(grep("_", decodeColNames(colnames(dataset)))) == 0) - if (options[["savePath"]] != "" && validNames) { - table$addFootnote(gettextf("The trained model is saved as %1$s.", basename(options[["savePath"]]))) - } else if (options[["savePath"]] != "" && !validNames) { - table$addFootnote(gettext("The trained model is not saved because the some of the variable names in the model contain spaces (i.e., ' ') or underscores (i.e., '_'). Please remove all such characters from the variable names and try saving the model again.")) - } else { - table$addFootnote(gettext("The trained model is not saved until a file name is specified under 'Save as'.")) - } - } + + .mlAddSaveModelInfo(table, options) + jaspResults[["classificationTable"]] <- table if (!ready) { return() @@ -330,26 +323,11 @@ ) table$addRows(row) } - # Save the applied model if requested - if (options[["saveModel"]] && options[["savePath"]] != "") { - validNames <- (length(grep(" ", decodeColNames(colnames(dataset)))) == 0) && (length(grep("_", decodeColNames(colnames(dataset)))) == 0) - if (!validNames) { - return() - } - model <- classificationResult[["model"]] - model[["jaspVars"]] <- list() - model[["jaspVars"]]$decoded <- list(target = decodeColNames(options[["target"]]), predictors = decodeColNames(options[["predictors"]])) - model[["jaspVars"]]$encoded = list(target = options[["target"]], predictors = options[["predictors"]]) - model[["jaspScaling"]] <- attr(dataset, "jaspScaling") - model[["jaspVersion"]] <- .baseCitation - model[["explainer"]] <- classificationResult[["explainer"]] - class(model) <- c(class(classificationResult[["model"]]), "jaspClassification", "jaspMachineLearning") - path <- options[["savePath"]] - if (!endsWith(path, ".jaspML")) { - path <- paste0(path, ".jaspML") - } - saveRDS(model, file = path) - } + + # Save the model if requested + saveResult <- .mlSaveModelToDisk(options, classificationResult, dataset, class = "jaspClassification") + .mlPossiblyShowSaveResult(table, saveResult, options) + } .mlClassificationTableConfusion <- function(dataset, options, jaspResults, ready, position) { diff --git a/R/commonMachineLearningRegression.R b/R/commonMachineLearningRegression.R index 61fd1596..69746a8c 100644 --- a/R/commonMachineLearningRegression.R +++ b/R/commonMachineLearningRegression.R @@ -166,7 +166,7 @@ if (length(factorsWithNewLevels) > 0) { setType <- switch(type, "test" = gettext("test set"), "validation" = gettext("validation set"), "prediction" = gettext("new dataset")) additionalMessage <- switch(type, - "test" = gettext(" or use a different test set (e.g., automatically by setting a different seed or manually by specifying the test set indicator)"), + "test" = gettext(" or use a different test set (e.g., automatically by setting a different seed or manually by specifying the test set indicator)"), "validation" = gettext(" or use a different validation set by setting a different seed"), "prediction" = "") factorMessage <- paste(sapply(factorsWithNewLevels, function(i) { @@ -337,16 +337,9 @@ if (!ready) { table$addFootnote(gettextf("Please provide a target variable and at least %d feature variable(s).", if (type == "knn" || type == "neuralnet" || type == "rpart" || type == "svm" || type == "lm") 1L else 2L)) } - if (options[["saveModel"]]) { - validNames <- (length(grep(" ", decodeColNames(colnames(dataset)))) == 0) && (length(grep("_", decodeColNames(colnames(dataset)))) == 0) - if (options[["savePath"]] != "" && validNames) { - table$addFootnote(gettextf("The trained model is saved as %1$s.", basename(options[["savePath"]]))) - } else if (options[["savePath"]] != "" && !validNames) { - table$addFootnote(gettext("The trained model is not saved because the some of the variable names in the model contain spaces (i.e., ' ') or underscores (i.e., '_'). Please remove all such characters from the variable names and try saving the model again.")) - } else { - table$addFootnote(gettext("The trained model is not saved until a file name is specified under 'Save as'.")) - } - } + + .mlAddSaveModelInfo(table, options) + jaspResults[["regressionTable"]] <- table if (!ready) { return() @@ -492,25 +485,63 @@ ) table$addRows(row) } + # Save the model if requested - if (options[["saveModel"]] && options[["savePath"]] != "") { - validNames <- (length(grep(" ", decodeColNames(colnames(dataset)))) == 0) && (length(grep("_", decodeColNames(colnames(dataset)))) == 0) - if (!validNames) { - return() - } - model <- regressionResult[["model"]] + saveResult <- .mlSaveModelToDisk(options, regressionResult, dataset, class = "jaspRegression") + .mlPossiblyShowSaveResult(table, saveResult, options) + +} + +.mlAddSaveModelInfo <- function(table, options) { + if (options[["saveModel"]] && options[["savePath"]] == "") { + table$addFootnote(gettext("The trained model is not saved until a file name is specified under 'Save as'.")) + } +} + +.mlSaveModelToDisk <- function(options, mlResult, dataset, class = c("jaspRegression", "jaspClassification")) { + + if (!options[["saveModel"]] || options[["savePath"]] == "") + return() + + class <- match.arg(class) + + error <- try({ + model <- mlResult[["model"]] model[["jaspVars"]] <- list() model[["jaspVars"]]$decoded <- list(target = decodeColNames(options[["target"]]), predictors = decodeColNames(options[["predictors"]])) model[["jaspVars"]]$encoded = list(target = options[["target"]], predictors = options[["predictors"]]) model[["jaspScaling"]] <- attr(dataset, "jaspScaling") model[["jaspVersion"]] <- .baseCitation - model[["explainer"]] <- regressionResult[["explainer"]] - class(model) <- c(class(regressionResult[["model"]]), "jaspRegression", "jaspMachineLearning") + model[["explainer"]] <- mlResult[["explainer"]] + class(model) <- c(class(mlResult[["model"]]), class, "jaspMachineLearning") path <- options[["savePath"]] if (!endsWith(path, ".jaspML")) { path <- paste0(path, ".jaspML") } saveRDS(model, file = path) + + "success" + }) + return(list(exists = file.exists(path), error = error)) +} + +.mlPossiblyShowSaveResult <- function(table, saveResult, options) { + + if (is.null(saveResult)) + return() + + if (identical(saveResult[["error"]], "success") && isTRUE(saveResult[["exists"]])) { + table$addFootnote(gettextf("The model is saved as %1$s.", basename(options[["savePath"]]))) + } else if (isTryError(saveResult[["error"]])) { + msg <- .extractErrorMessage(saveResult[["error"]]) + if (grepl(x = msg, pattern = "cannot open the connection", fixed = TRUE) && !dir.exists(dirname(options[["savePath"]]))) { + # likely occurs most often when using a downloaded jasp file that contains a path that is valid another computer + table$addFootnote(gettextf("The model could not be saved because the parent directory '%s' does not exist.", dirname(options[["savePath"]]))) + } else + table$addFootnote(gettextf("The model could not be saved because the following error occured: %s", msg)) + + } else if (!isTRUE(saveResult[["exists"]])) { + table$addFootnote(gettext("The model could not be saved because an unexpected error occured.")) } }