|
166 | 166 | if (length(factorsWithNewLevels) > 0) {
|
167 | 167 | setType <- switch(type, "test" = gettext("test set"), "validation" = gettext("validation set"), "prediction" = gettext("new dataset"))
|
168 | 168 | additionalMessage <- switch(type,
|
169 |
| - "test" = gettext(" or use a different test set (e.g., automatically by setting a different seed or manually by specifying the test set indicator)"), |
| 169 | + "test" = gettext(" or use a different test set (e.g., automatically by setting a different seed or manually by specifying the test set indicator)"), |
170 | 170 | "validation" = gettext(" or use a different validation set by setting a different seed"),
|
171 | 171 | "prediction" = "")
|
172 | 172 | factorMessage <- paste(sapply(factorsWithNewLevels, function(i) {
|
|
337 | 337 | if (!ready) {
|
338 | 338 | table$addFootnote(gettextf("Please provide a target variable and at least %d feature variable(s).", if (type == "knn" || type == "neuralnet" || type == "rpart" || type == "svm" || type == "lm") 1L else 2L))
|
339 | 339 | }
|
340 |
| - if (options[["saveModel"]]) { |
341 |
| - validNames <- (length(grep(" ", decodeColNames(colnames(dataset)))) == 0) && (length(grep("_", decodeColNames(colnames(dataset)))) == 0) |
342 |
| - if (options[["savePath"]] != "" && validNames) { |
343 |
| - table$addFootnote(gettextf("The trained model is saved as <i>%1$s</i>.", basename(options[["savePath"]]))) |
344 |
| - } else if (options[["savePath"]] != "" && !validNames) { |
345 |
| - table$addFootnote(gettext("The trained model is <b>not</b> saved because the some of the variable names in the model contain spaces (i.e., ' ') or underscores (i.e., '_'). Please remove all such characters from the variable names and try saving the model again.")) |
346 |
| - } else { |
347 |
| - table$addFootnote(gettext("The trained model is not saved until a file name is specified under 'Save as'.")) |
348 |
| - } |
349 |
| - } |
| 340 | + |
| 341 | + .mlAddSaveModelInfo(table, options) |
| 342 | + |
350 | 343 | jaspResults[["regressionTable"]] <- table
|
351 | 344 | if (!ready) {
|
352 | 345 | return()
|
|
492 | 485 | )
|
493 | 486 | table$addRows(row)
|
494 | 487 | }
|
| 488 | + |
495 | 489 | # Save the model if requested
|
496 |
| - if (options[["saveModel"]] && options[["savePath"]] != "") { |
497 |
| - validNames <- (length(grep(" ", decodeColNames(colnames(dataset)))) == 0) && (length(grep("_", decodeColNames(colnames(dataset)))) == 0) |
498 |
| - if (!validNames) { |
499 |
| - return() |
500 |
| - } |
501 |
| - model <- regressionResult[["model"]] |
| 490 | + saveResult <- .mlSaveModelToDisk(options, regressionResult, dataset, class = "jaspRegression") |
| 491 | + .mlPossiblyShowSaveResult(table, saveResult, options) |
| 492 | + |
| 493 | + |
| 494 | +} |
| 495 | + |
| 496 | +.mlAddSaveModelInfo <- function(table, options) { |
| 497 | + if (options[["saveModel"]] && options[["savePath"]] == "") { |
| 498 | + table$addFootnote(gettext("The trained model is not saved until a file name is specified under 'Save as'.")) |
| 499 | + } |
| 500 | +} |
| 501 | + |
| 502 | +.mlSaveModelToDisk <- function(options, mlResult, dataset, class = c("jaspRegression", "jaspClassification")) { |
| 503 | + |
| 504 | + if (!options[["saveModel"]] || options[["savePath"]] == "") |
| 505 | + return() |
| 506 | + |
| 507 | + class <- match.arg(class) |
| 508 | + |
| 509 | + error <- try({ |
| 510 | + model <- mlResult[["model"]] |
502 | 511 | model[["jaspVars"]] <- list()
|
503 | 512 | model[["jaspVars"]]$decoded <- list(target = decodeColNames(options[["target"]]), predictors = decodeColNames(options[["predictors"]]))
|
504 | 513 | model[["jaspVars"]]$encoded = list(target = options[["target"]], predictors = options[["predictors"]])
|
505 | 514 | model[["jaspScaling"]] <- attr(dataset, "jaspScaling")
|
506 | 515 | model[["jaspVersion"]] <- .baseCitation
|
507 |
| - model[["explainer"]] <- regressionResult[["explainer"]] |
508 |
| - class(model) <- c(class(regressionResult[["model"]]), "jaspRegression", "jaspMachineLearning") |
| 516 | + model[["explainer"]] <- mlResult[["explainer"]] |
| 517 | + class(model) <- c(class(mlResult[["model"]]), class, "jaspMachineLearning") |
509 | 518 | path <- options[["savePath"]]
|
510 | 519 | if (!endsWith(path, ".jaspML")) {
|
511 | 520 | path <- paste0(path, ".jaspML")
|
512 | 521 | }
|
513 | 522 | saveRDS(model, file = path)
|
| 523 | + |
| 524 | + "success" |
| 525 | + }) |
| 526 | + return(list(exists = file.exists(path), error = error)) |
| 527 | +} |
| 528 | + |
| 529 | +.mlPossiblyShowSaveResult <- function(table, saveResult, options) { |
| 530 | + |
| 531 | + if (is.null(saveResult)) |
| 532 | + return() |
| 533 | + |
| 534 | + if (identical(saveResult[["error"]], "success") && isTRUE(saveResult[["exists"]])) { |
| 535 | + table$addFootnote(gettextf("The model is saved as <i>%1$s</i>.", basename(options[["savePath"]]))) |
| 536 | + } else if (!identical(saveResult[["error"]], "success")) { |
| 537 | + table$addFootnote(gettextf("The model could not be saved because the following error occured: %s", saveResult[["error"]][["message"]])) |
| 538 | + } else if (!isTRUE(saveResult[["exists"]]) && !is.null(saveResult[["error"]])) { |
| 539 | + table$addFootnote(gettextf("The model could not be saved because an unexpected error occured.")) |
514 | 540 | }
|
515 | 541 | }
|
516 | 542 |
|
|
0 commit comments