Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 8 additions & 30 deletions R/commonMachineLearningClassification.R
Original file line number Diff line number Diff line change
Expand Up @@ -167,16 +167,9 @@
if (!ready) {
table$addFootnote(gettextf("Please provide a target variable and at least %i feature variable(s).", if (type == "knn" || type == "neuralnet" || type == "rpart" || type == "svm" || type == "logistic") 1L else 2L))
}
if (options[["saveModel"]]) {
validNames <- (length(grep(" ", decodeColNames(colnames(dataset)))) == 0) && (length(grep("_", decodeColNames(colnames(dataset)))) == 0)
if (options[["savePath"]] != "" && validNames) {
table$addFootnote(gettextf("The trained model is saved as <i>%1$s</i>.", basename(options[["savePath"]])))
} else if (options[["savePath"]] != "" && !validNames) {
table$addFootnote(gettext("The trained model is <b>not</b> saved because the some of the variable names in the model contain spaces (i.e., ' ') or underscores (i.e., '_'). Please remove all such characters from the variable names and try saving the model again."))
} else {
table$addFootnote(gettext("The trained model is not saved until a file name is specified under 'Save as'."))
}
}

.mlAddSaveModelInfo(table, options)

jaspResults[["classificationTable"]] <- table
if (!ready) {
return()
Expand Down Expand Up @@ -330,26 +323,11 @@
)
table$addRows(row)
}
# Save the applied model if requested
if (options[["saveModel"]] && options[["savePath"]] != "") {
validNames <- (length(grep(" ", decodeColNames(colnames(dataset)))) == 0) && (length(grep("_", decodeColNames(colnames(dataset)))) == 0)
if (!validNames) {
return()
}
model <- classificationResult[["model"]]
model[["jaspVars"]] <- list()
model[["jaspVars"]]$decoded <- list(target = decodeColNames(options[["target"]]), predictors = decodeColNames(options[["predictors"]]))
model[["jaspVars"]]$encoded = list(target = options[["target"]], predictors = options[["predictors"]])
model[["jaspScaling"]] <- attr(dataset, "jaspScaling")
model[["jaspVersion"]] <- .baseCitation
model[["explainer"]] <- classificationResult[["explainer"]]
class(model) <- c(class(classificationResult[["model"]]), "jaspClassification", "jaspMachineLearning")
path <- options[["savePath"]]
if (!endsWith(path, ".jaspML")) {
path <- paste0(path, ".jaspML")
}
saveRDS(model, file = path)
}

# Save the model if requested
saveResult <- .mlSaveModelToDisk(options, classificationResult, dataset, class = "jaspClassification")
.mlPossiblyShowSaveResult(table, saveResult, options)

}

.mlClassificationTableConfusion <- function(dataset, options, jaspResults, ready, position) {
Expand Down
69 changes: 50 additions & 19 deletions R/commonMachineLearningRegression.R
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@
if (length(factorsWithNewLevels) > 0) {
setType <- switch(type, "test" = gettext("test set"), "validation" = gettext("validation set"), "prediction" = gettext("new dataset"))
additionalMessage <- switch(type,
"test" = gettext(" or use a different test set (e.g., automatically by setting a different seed or manually by specifying the test set indicator)"),
"test" = gettext(" or use a different test set (e.g., automatically by setting a different seed or manually by specifying the test set indicator)"),
"validation" = gettext(" or use a different validation set by setting a different seed"),
"prediction" = "")
factorMessage <- paste(sapply(factorsWithNewLevels, function(i) {
Expand Down Expand Up @@ -337,16 +337,9 @@
if (!ready) {
table$addFootnote(gettextf("Please provide a target variable and at least %d feature variable(s).", if (type == "knn" || type == "neuralnet" || type == "rpart" || type == "svm" || type == "lm") 1L else 2L))
}
if (options[["saveModel"]]) {
validNames <- (length(grep(" ", decodeColNames(colnames(dataset)))) == 0) && (length(grep("_", decodeColNames(colnames(dataset)))) == 0)
if (options[["savePath"]] != "" && validNames) {
table$addFootnote(gettextf("The trained model is saved as <i>%1$s</i>.", basename(options[["savePath"]])))
} else if (options[["savePath"]] != "" && !validNames) {
table$addFootnote(gettext("The trained model is <b>not</b> saved because the some of the variable names in the model contain spaces (i.e., ' ') or underscores (i.e., '_'). Please remove all such characters from the variable names and try saving the model again."))
} else {
table$addFootnote(gettext("The trained model is not saved until a file name is specified under 'Save as'."))
}
}

.mlAddSaveModelInfo(table, options)

jaspResults[["regressionTable"]] <- table
if (!ready) {
return()
Expand Down Expand Up @@ -492,25 +485,63 @@
)
table$addRows(row)
}

# Save the model if requested
if (options[["saveModel"]] && options[["savePath"]] != "") {
validNames <- (length(grep(" ", decodeColNames(colnames(dataset)))) == 0) && (length(grep("_", decodeColNames(colnames(dataset)))) == 0)
if (!validNames) {
return()
}
model <- regressionResult[["model"]]
saveResult <- .mlSaveModelToDisk(options, regressionResult, dataset, class = "jaspRegression")
.mlPossiblyShowSaveResult(table, saveResult, options)

}

.mlAddSaveModelInfo <- function(table, options) {
if (options[["saveModel"]] && options[["savePath"]] == "") {
table$addFootnote(gettext("The trained model is not saved until a file name is specified under 'Save as'."))
}
}

.mlSaveModelToDisk <- function(options, mlResult, dataset, class = c("jaspRegression", "jaspClassification")) {

if (!options[["saveModel"]] || options[["savePath"]] == "")
return()

class <- match.arg(class)

error <- try({
model <- mlResult[["model"]]
model[["jaspVars"]] <- list()
model[["jaspVars"]]$decoded <- list(target = decodeColNames(options[["target"]]), predictors = decodeColNames(options[["predictors"]]))
model[["jaspVars"]]$encoded = list(target = options[["target"]], predictors = options[["predictors"]])
model[["jaspScaling"]] <- attr(dataset, "jaspScaling")
model[["jaspVersion"]] <- .baseCitation
model[["explainer"]] <- regressionResult[["explainer"]]
class(model) <- c(class(regressionResult[["model"]]), "jaspRegression", "jaspMachineLearning")
model[["explainer"]] <- mlResult[["explainer"]]
class(model) <- c(class(mlResult[["model"]]), class, "jaspMachineLearning")
path <- options[["savePath"]]
if (!endsWith(path, ".jaspML")) {
path <- paste0(path, ".jaspML")
}
saveRDS(model, file = path)

"success"
})
return(list(exists = file.exists(path), error = error))
}

.mlPossiblyShowSaveResult <- function(table, saveResult, options) {

if (is.null(saveResult))
return()

if (identical(saveResult[["error"]], "success") && isTRUE(saveResult[["exists"]])) {
table$addFootnote(gettextf("The model is saved as <i>%1$s</i>.", basename(options[["savePath"]])))
} else if (isTryError(saveResult[["error"]])) {
msg <- .extractErrorMessage(saveResult[["error"]])
if (grepl(x = msg, pattern = "cannot open the connection", fixed = TRUE) && !dir.exists(dirname(options[["savePath"]]))) {
# likely occurs most often when using a downloaded jasp file that contains a path that is valid another computer
table$addFootnote(gettextf("The model could not be saved because the parent directory '%s' does not exist.", dirname(options[["savePath"]])))
} else
table$addFootnote(gettextf("The model could not be saved because the following error occured: %s", msg))

} else if (!isTRUE(saveResult[["exists"]])) {
table$addFootnote(gettext("The model could not be saved because an unexpected error occured."))
}
}

Expand Down
Loading