diff --git a/R/commonMachineLearningClassification.R b/R/commonMachineLearningClassification.R index bdb12043..74496c54 100644 --- a/R/commonMachineLearningClassification.R +++ b/R/commonMachineLearningClassification.R @@ -342,7 +342,6 @@ model[["jaspVars"]]$encoded = list(target = options[["target"]], predictors = options[["predictors"]]) model[["jaspVersion"]] <- .baseCitation model[["explainer"]] <- classificationResult[["explainer"]] - model <- .decodeJaspMLobject(model) class(model) <- c(class(classificationResult[["model"]]), "jaspClassification", "jaspMachineLearning") path <- options[["savePath"]] if (!endsWith(path, ".jaspML")) { diff --git a/R/commonMachineLearningRegression.R b/R/commonMachineLearningRegression.R index ede6fc57..c3edee3a 100644 --- a/R/commonMachineLearningRegression.R +++ b/R/commonMachineLearningRegression.R @@ -473,7 +473,6 @@ model[["jaspVars"]]$encoded = list(target = options[["target"]], predictors = options[["predictors"]]) model[["jaspVersion"]] <- .baseCitation model[["explainer"]] <- regressionResult[["explainer"]] - model <- .decodeJaspMLobject(model) class(model) <- c(class(regressionResult[["model"]]), "jaspRegression", "jaspMachineLearning") path <- options[["savePath"]] if (!endsWith(path, ".jaspML")) { @@ -713,7 +712,7 @@ } else { purpose <- "classification" } - predictors <- options[["predictors"]][which(decodeColNames(options[["predictors"]]) %in% model[["jaspVars"]][["decoded"]]$predictors)] + predictors <- model[["jaspVars"]][["encoded"]]$predictors } else { predictors <- options[["predictors"]] } @@ -724,7 +723,7 @@ } else { .mlClassificationDependencies(options) }, - "tableShap", "fromIndex", "toIndex" + "tableShap", "fromIndex", "toIndex", "trainedModelFilePath" )) table$addColumnInfo(name = "id", title = gettext("Case"), type = "integer") if (purpose == "regression") { diff --git a/R/mlPrediction.R b/R/mlPrediction.R index 328e6a42..6ca7d1a5 100644 --- a/R/mlPrediction.R +++ b/R/mlPrediction.R @@ -19,7 +19,7 @@ 
mlPrediction <- function(jaspResults, dataset, options, ...) { # Preparatory work model <- .mlPredictionReadModel(options) - dataset <- .mlPredictionReadData(dataset, options) + dataset <- .mlPredictionReadData(dataset, options, model) # Check if analysis is ready ready <- .mlPredictionReady(model, dataset, options) @@ -145,7 +145,6 @@ is.jaspMachineLearning <- function(x) { as.character(levels(as.factor(model$model[[model[["jaspVars"]][["encoded"]]$target]]))[round(predict(model, newdata = dataset, type = "response"), 0) + 1]) } .mlPredictionGetPredictions.vglm <- function(model, dataset) { - model[["original"]]@terms$terms <- model[["terms"]] logodds <- predict(model[["original"]], newdata = dataset) ncategories <- ncol(logodds) + 1 probabilities <- matrix(0, nrow = nrow(logodds), ncol = ncategories) @@ -200,76 +199,6 @@ is.jaspMachineLearning <- function(x) { nrow(model[["x"]]) } -# S3 method to decode the model variables in the result object -# so that they can be matched to variables in the prediction analysis -.decodeJaspMLobject <- function(model) { - UseMethod(".decodeJaspMLobject", model) -} -.decodeJaspMLobject.kknn <- function(model) { - formula <- formula(paste(decodeColNames(as.character(model$terms)[2]), "~", paste0(decodeColNames(strsplit(as.character(model$terms)[3], split = " + ", fixed = TRUE)[[1]]), collapse = " + "))) - model[["predictive"]]$terms <- stats::terms(formula) - colnames(model[["predictive"]]$data) <- decodeColNames(colnames(model[["predictive"]]$data)) - return(model) -} -.decodeJaspMLobject.lda <- function(model) { - formula <- formula(paste(decodeColNames(as.character(model$terms)[2]), "~", paste0(decodeColNames(strsplit(as.character(model$terms)[3], split = " + ", fixed = TRUE)[[1]]), collapse = " + "))) - model$terms <- stats::terms(formula) - return(model) -} -.decodeJaspMLobject.lm <- function(model) { - formula <- formula(paste(decodeColNames(as.character(model$terms)[2]), "~", 
paste0(decodeColNames(strsplit(as.character(model$terms)[3], split = " + ", fixed = TRUE)[[1]]), collapse = " + "))) - model$terms <- stats::terms(formula) - return(model) -} -.decodeJaspMLobject.gbm <- function(model) { - model[["var.names"]] <- decodeColNames(model[["var.names"]]) - return(model) -} -.decodeJaspMLobject.randomForest <- function(model) { - rownames(model$importance) <- decodeColNames(rownames(model$importance)) - names(model$forest$xlevels) <- decodeColNames(names(model$forest$xlevels)) - formula <- formula(paste("DOESNOTMATTER", "~", paste0(decodeColNames(names(model$forest$xlevels)), collapse = " + "))) - model$terms <- stats::terms(formula) - class(model) <- c(class(model), "randomForest.formula") - return(model) -} -.decodeJaspMLobject.cv.glmnet <- function(model) { - rownames(model[["glmnet.fit"]][["beta"]]) <- decodeColNames(rownames(model[["glmnet.fit"]][["beta"]])) - return(model) -} -.decodeJaspMLobject.nn <- function(model) { - model[["model.list"]]$variables <- decodeColNames(model[["model.list"]]$variables) - return(model) -} -.decodeJaspMLobject.rpart <- function(model) { - formula <- formula(paste(decodeColNames(as.character(model$terms)[2]), "~", paste0(decodeColNames(strsplit(as.character(model$terms)[3], split = " + ", fixed = TRUE)[[1]]), collapse = " + "))) - model$terms <- stats::terms(formula) - model$frame$var <- decodeColNames(model$frame$var) - rownames(model$splits) <- decodeColNames(rownames(model$splits)) - return(model) -} -.decodeJaspMLobject.svm <- function(model) { - formula <- formula(paste(decodeColNames(as.character(model$terms)[2]), "~ 0 +", paste0(decodeColNames(strsplit(as.character(model$terms)[3], split = " + ", fixed = TRUE)[[1]]), collapse = " + "))) - model$terms <- stats::terms(formula) - return(model) -} -.decodeJaspMLobject.naiveBayes <- function(model) { - names(model[["isnumeric"]]) <- decodeColNames(names(model[["isnumeric"]])) - names(model[["tables"]]) <- decodeColNames(names(model[["tables"]])) - 
return(model) -} -.decodeJaspMLobject.glm <- function(model) { - vars <- all.vars(stats::terms(model)) - formula <- formula(paste(decodeColNames(vars[1]), "~", paste0(decodeColNames(vars[-1]), collapse = " + "))) - model$terms <- stats::terms(formula) - return(model) -} -.decodeJaspMLobject.vglm <- function(model) { - formula <- formula(paste(decodeColNames(strsplit(as.character(model$terms), " ")[[1]][1]), "~", paste0(decodeColNames(strsplit(strsplit(as.character(model$terms), split = " ~ ")[[1]][2], split = " + ", fixed = TRUE)[[1]]), collapse = " + "))) - model$terms <- stats::terms(formula) - return(model) -} - .mlPredictionReadModel <- function(options) { if (options[["trainedModelFilePath"]] != "") { model <- try({ @@ -293,8 +222,8 @@ is.jaspMachineLearning <- function(x) { # also define methods for other objects .mlPredictionReady <- function(model, dataset, options) { if (!is.null(model)) { - modelVars <- model[["jaspVars"]][["decoded"]]$predictors - presentVars <- decodeColNames(colnames(dataset)) + modelVars <- model[["jaspVars"]][["encoded"]]$predictors + presentVars <- colnames(dataset) ready <- all(modelVars %in% presentVars) } else { ready <- FALSE @@ -302,11 +231,22 @@ is.jaspMachineLearning <- function(x) { return(ready) } -.mlPredictionReadData <- function(dataset, options) { +# Ensure names in prediction data match names in training data +.matchDecodedNames <- function(names, model) { + decoded <- model[["jaspVars"]][["decoded"]]$predictors + encoded <- model[["jaspVars"]][["encoded"]]$predictors + matched_indices <- match(decodeColNames(names), decoded) + names[!is.na(matched_indices)] <- encoded[matched_indices[!is.na(matched_indices)]] + return(names) +} + +.mlPredictionReadData <- function(dataset, options, model) { dataset <- jaspBase::excludeNaListwise(dataset, options[["predictors"]]) if (options[["scaleVariables"]] && length(unlist(options[["predictors"]])) > 0) { dataset <- .scaleNumericData(dataset) } + dataset <- dataset[, 
which(decodeColNames(colnames(dataset)) %in% model[["jaspVars"]][["decoded"]]$predictors), drop = FALSE] # Filter only predictors to prevent accidental double column names; drop = FALSE keeps a data frame when a single predictor matches + colnames(dataset) <- .matchDecodedNames(colnames(dataset), model) return(dataset) } @@ -315,7 +255,7 @@ is.jaspMachineLearning <- function(x) { return(jaspResults[["predictions"]]$object) } else { if (ready) { - colnames(dataset) <- decodeColNames(colnames(dataset)) + dataset <- dataset[which(colnames(dataset) %in% model[["jaspVars"]][["encoded"]]$predictors)] jaspResults[["predictions"]] <- createJaspState(.mlPredictionGetPredictions(model, dataset)) - jaspResults[["predictions"]]$dependOn(options = c("loadPath", "predictors", "scaleVariables")) + jaspResults[["predictions"]]$dependOn(options = c("trainedModelFilePath", "predictors", "scaleVariables")) return(jaspResults[["predictions"]]$object) @@ -344,14 +284,16 @@ is.jaspMachineLearning <- function(x) { if (is.null(model)) { return() } - modelVars <- model[["jaspVars"]][["decoded"]]$predictors - presentVars <- decodeColNames(colnames(dataset)) - if (!all(modelVars %in% presentVars)) { - missingVars <- modelVars[which(!(modelVars %in% presentVars))] + modelVars_encoded <- model[["jaspVars"]][["encoded"]]$predictors + modelVars_decoded <- model[["jaspVars"]][["decoded"]]$predictors + presentVars_encoded <- colnames(dataset) + presentVars_decoded <- decodeColNames(options[["predictors"]]) + if (!all(modelVars_decoded %in% presentVars_decoded)) { + missingVars <- modelVars_decoded[which(!(modelVars_decoded %in% presentVars_decoded))] - table$addFootnote(gettextf("The trained model is not applied because the the following features are missing: %1$s.", paste0(missingVars, collapse = ", "))) + table$addFootnote(gettextf("The trained model is not applied because the following features are missing: %1$s.", paste0(missingVars, collapse = ", "))) } - if (!all(presentVars %in% modelVars)) { - unusedVars <- presentVars[which(!(presentVars %in% modelVars))] + if (!all(presentVars_decoded %in% modelVars_decoded)) { + unusedVars <- presentVars_decoded[which(!(presentVars_decoded %in% modelVars_decoded))] table$addFootnote(gettextf("The following features are unused because they are not a feature variable in the trained model: 
%1$s.", paste0(unusedVars, collapse = ", "))) } if (inherits(model, "kknn")) { @@ -394,7 +336,7 @@ is.jaspMachineLearning <- function(x) { row[["family"]] <- gettext("Multinomial") row[["link"]] <- gettext("Logit") } - if (length(presentVars) > 0) { + if (length(presentVars_encoded) > 0) { row[["nnew"]] <- nrow(dataset) } table$addRows(row) @@ -405,7 +347,7 @@ is.jaspMachineLearning <- function(x) { return() } table <- createJaspTable(gettext("Predictions for New Data")) - table$dependOn(options = c("predictors", "loadPath", "predictionsTable", "predictionsTableFeatures", "scaleVariables", "fromIndex", "toIndex")) + table$dependOn(options = c("predictors", "trainedModelFilePath", "predictionsTable", "predictionsTableFeatures", "scaleVariables", "fromIndex", "toIndex")) table$position <- position table$addColumnInfo(name = "row", title = gettext("Case"), type = "integer") if (!is.null(model) && inherits(model, "jaspClassification")) { @@ -422,16 +364,20 @@ is.jaspMachineLearning <- function(x) { selection <- predictions[indexes] cols <- list(row = indexes, pred = selection) if (options[["predictionsTableFeatures"]]) { - for (i in colnames(dataset)) { - if (.columnIsNominal(i)) { - table$addColumnInfo(name = i, title = i, type = "string") + modelVars_encoded <- model[["jaspVars"]][["encoded"]]$predictors + modelVars_decoded <- model[["jaspVars"]][["decoded"]]$predictors + matched_names <- match(colnames(dataset), modelVars_encoded) + for (i in seq_len(ncol(dataset))) { + colName <- modelVars_decoded[matched_names[i]] + if (is.factor(dataset[[i]])) { + table$addColumnInfo(name = colName, title = colName, type = "string") var <- levels(dataset[[i]])[dataset[[i]]] } else { - table$addColumnInfo(name = i, title = i, type = "number") + table$addColumnInfo(name = colName, title = colName, type = "number") var <- dataset[[i]] } var <- var[indexes] - cols[[i]] <- var + cols[[colName]] <- var } } table$setData(cols) @@ -442,7 +388,7 @@ is.jaspMachineLearning <- function(x) { 
predictionsColumn <- rep(NA, max(as.numeric(rownames(dataset)))) predictionsColumn[as.numeric(rownames(dataset))] <- .mlPredictionsState(model, dataset, options, jaspResults, ready) jaspResults[["predictionsColumn"]] <- createJaspColumn(columnName = options[["predictionsColumn"]]) - jaspResults[["predictionsColumn"]]$dependOn(options = c("predictionsColumn", "predictors", "loadPath", "scaleVariables", "addPredictions")) + jaspResults[["predictionsColumn"]]$dependOn(options = c("predictionsColumn", "predictors", "trainedModelFilePath", "scaleVariables", "addPredictions")) if (inherits(model, "jaspClassification")) jaspResults[["predictionsColumn"]]$setNominal(predictionsColumn) if (inherits(model, "jaspRegression")) jaspResults[["predictionsColumn"]]$setScale(predictionsColumn) }