diff --git a/R/commonMachineLearningClassification.R b/R/commonMachineLearningClassification.R
index 3e7a01e2..9afb3656 100644
--- a/R/commonMachineLearningClassification.R
+++ b/R/commonMachineLearningClassification.R
@@ -167,16 +167,9 @@
if (!ready) {
table$addFootnote(gettextf("Please provide a target variable and at least %i feature variable(s).", if (type == "knn" || type == "neuralnet" || type == "rpart" || type == "svm" || type == "logistic") 1L else 2L))
}
- if (options[["saveModel"]]) {
- validNames <- (length(grep(" ", decodeColNames(colnames(dataset)))) == 0) && (length(grep("_", decodeColNames(colnames(dataset)))) == 0)
- if (options[["savePath"]] != "" && validNames) {
- table$addFootnote(gettextf("The trained model is saved as %1$s.", basename(options[["savePath"]])))
- } else if (options[["savePath"]] != "" && !validNames) {
- table$addFootnote(gettext("The trained model is not saved because the some of the variable names in the model contain spaces (i.e., ' ') or underscores (i.e., '_'). Please remove all such characters from the variable names and try saving the model again."))
- } else {
- table$addFootnote(gettext("The trained model is not saved until a file name is specified under 'Save as'."))
- }
- }
+
+ .mlAddSaveModelInfo(table, options)
+
jaspResults[["classificationTable"]] <- table
if (!ready) {
return()
@@ -330,26 +323,11 @@
)
table$addRows(row)
}
- # Save the applied model if requested
- if (options[["saveModel"]] && options[["savePath"]] != "") {
- validNames <- (length(grep(" ", decodeColNames(colnames(dataset)))) == 0) && (length(grep("_", decodeColNames(colnames(dataset)))) == 0)
- if (!validNames) {
- return()
- }
- model <- classificationResult[["model"]]
- model[["jaspVars"]] <- list()
- model[["jaspVars"]]$decoded <- list(target = decodeColNames(options[["target"]]), predictors = decodeColNames(options[["predictors"]]))
- model[["jaspVars"]]$encoded = list(target = options[["target"]], predictors = options[["predictors"]])
- model[["jaspScaling"]] <- attr(dataset, "jaspScaling")
- model[["jaspVersion"]] <- .baseCitation
- model[["explainer"]] <- classificationResult[["explainer"]]
- class(model) <- c(class(classificationResult[["model"]]), "jaspClassification", "jaspMachineLearning")
- path <- options[["savePath"]]
- if (!endsWith(path, ".jaspML")) {
- path <- paste0(path, ".jaspML")
- }
- saveRDS(model, file = path)
- }
+
+ # Save the model if requested
+ saveResult <- .mlSaveModelToDisk(options, classificationResult, dataset, class = "jaspClassification")
+ .mlPossiblyShowSaveResult(table, saveResult, options)
+
}
.mlClassificationTableConfusion <- function(dataset, options, jaspResults, ready, position) {
diff --git a/R/commonMachineLearningRegression.R b/R/commonMachineLearningRegression.R
index 61fd1596..69746a8c 100644
--- a/R/commonMachineLearningRegression.R
+++ b/R/commonMachineLearningRegression.R
@@ -166,7 +166,7 @@
if (length(factorsWithNewLevels) > 0) {
setType <- switch(type, "test" = gettext("test set"), "validation" = gettext("validation set"), "prediction" = gettext("new dataset"))
additionalMessage <- switch(type,
- "test" = gettext(" or use a different test set (e.g., automatically by setting a different seed or manually by specifying the test set indicator)"),
+ "test" = gettext(" or use a different test set (e.g., automatically by setting a different seed or manually by specifying the test set indicator)"),
"validation" = gettext(" or use a different validation set by setting a different seed"),
"prediction" = "")
factorMessage <- paste(sapply(factorsWithNewLevels, function(i) {
@@ -337,16 +337,9 @@
if (!ready) {
table$addFootnote(gettextf("Please provide a target variable and at least %d feature variable(s).", if (type == "knn" || type == "neuralnet" || type == "rpart" || type == "svm" || type == "lm") 1L else 2L))
}
- if (options[["saveModel"]]) {
- validNames <- (length(grep(" ", decodeColNames(colnames(dataset)))) == 0) && (length(grep("_", decodeColNames(colnames(dataset)))) == 0)
- if (options[["savePath"]] != "" && validNames) {
- table$addFootnote(gettextf("The trained model is saved as %1$s.", basename(options[["savePath"]])))
- } else if (options[["savePath"]] != "" && !validNames) {
- table$addFootnote(gettext("The trained model is not saved because the some of the variable names in the model contain spaces (i.e., ' ') or underscores (i.e., '_'). Please remove all such characters from the variable names and try saving the model again."))
- } else {
- table$addFootnote(gettext("The trained model is not saved until a file name is specified under 'Save as'."))
- }
- }
+
+ .mlAddSaveModelInfo(table, options)
+
jaspResults[["regressionTable"]] <- table
if (!ready) {
return()
@@ -492,25 +485,63 @@
)
table$addRows(row)
}
+
# Save the model if requested
- if (options[["saveModel"]] && options[["savePath"]] != "") {
- validNames <- (length(grep(" ", decodeColNames(colnames(dataset)))) == 0) && (length(grep("_", decodeColNames(colnames(dataset)))) == 0)
- if (!validNames) {
- return()
- }
- model <- regressionResult[["model"]]
+ saveResult <- .mlSaveModelToDisk(options, regressionResult, dataset, class = "jaspRegression")
+ .mlPossiblyShowSaveResult(table, saveResult, options)
+
+}
+
+.mlAddSaveModelInfo <- function(table, options) {
+ if (options[["saveModel"]] && options[["savePath"]] == "") {
+ table$addFootnote(gettext("The trained model is not saved until a file name is specified under 'Save as'."))
+ }
+}
+
+.mlSaveModelToDisk <- function(options, mlResult, dataset, class = c("jaspRegression", "jaspClassification")) {
+
+ if (!options[["saveModel"]] || options[["savePath"]] == "")
+ return()
+
+ class <- match.arg(class)
+
+ error <- try({
+ model <- mlResult[["model"]]
model[["jaspVars"]] <- list()
model[["jaspVars"]]$decoded <- list(target = decodeColNames(options[["target"]]), predictors = decodeColNames(options[["predictors"]]))
model[["jaspVars"]]$encoded = list(target = options[["target"]], predictors = options[["predictors"]])
model[["jaspScaling"]] <- attr(dataset, "jaspScaling")
model[["jaspVersion"]] <- .baseCitation
- model[["explainer"]] <- regressionResult[["explainer"]]
- class(model) <- c(class(regressionResult[["model"]]), "jaspRegression", "jaspMachineLearning")
+ model[["explainer"]] <- mlResult[["explainer"]]
+ class(model) <- c(class(mlResult[["model"]]), class, "jaspMachineLearning")
path <- options[["savePath"]]
if (!endsWith(path, ".jaspML")) {
path <- paste0(path, ".jaspML")
}
saveRDS(model, file = path)
+
+ "success"
+ })
+ return(list(exists = file.exists(path), error = error))
+}
+
+.mlPossiblyShowSaveResult <- function(table, saveResult, options) {
+
+ if (is.null(saveResult))
+ return()
+
+ if (identical(saveResult[["error"]], "success") && isTRUE(saveResult[["exists"]])) {
+ table$addFootnote(gettextf("The model is saved as %1$s.", basename(options[["savePath"]])))
+ } else if (isTryError(saveResult[["error"]])) {
+ msg <- .extractErrorMessage(saveResult[["error"]])
+ if (grepl(x = msg, pattern = "cannot open the connection", fixed = TRUE) && !dir.exists(dirname(options[["savePath"]]))) {
+ # likely occurs most often when using a downloaded jasp file that contains a path that is valid another computer
+ table$addFootnote(gettextf("The model could not be saved because the parent directory '%s' does not exist.", dirname(options[["savePath"]])))
+ } else
+ table$addFootnote(gettextf("The model could not be saved because the following error occured: %s", msg))
+
+ } else if (!isTRUE(saveResult[["exists"]])) {
+ table$addFootnote(gettext("The model could not be saved because an unexpected error occured."))
}
}