From d4ec9fc789ed05dce8bd519f0d1c62cf45104b12 Mon Sep 17 00:00:00 2001 From: Koen Derks Date: Fri, 22 Nov 2024 11:47:51 +0100 Subject: [PATCH 1/5] Make explain predictions work for new data again --- R/commonMachineLearningRegression.R | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/R/commonMachineLearningRegression.R b/R/commonMachineLearningRegression.R index ede6fc57..723e1496 100644 --- a/R/commonMachineLearningRegression.R +++ b/R/commonMachineLearningRegression.R @@ -714,6 +714,14 @@ purpose <- "classification" } predictors <- options[["predictors"]][which(decodeColNames(options[["predictors"]]) %in% model[["jaspVars"]][["decoded"]]$predictors)] + # Ensure same encoding in prediction data as in training data + for (i in seq_along(predictors)) { + for (j in seq_along(model[["jaspVars"]][["decoded"]]$predictors)) { + if (decodeColNames(predictors[i]) == model[["jaspVars"]][["decoded"]]$predictors[j]) { + predictors[i] <- model[["jaspVars"]][["encoded"]]$predictors[j] + } + } + } } else { predictors <- options[["predictors"]] } @@ -746,6 +754,14 @@ x_test <- result[["test"]][, predictors] } else { explainer <- model[["explainer"]] + # Ensure same encoding in prediction data as in training data + for (i in seq_along(colnames(dataset))) { + for (j in seq_along(model[["jaspVars"]][["decoded"]]$predictors)) { + if (decodeColNames(colnames(dataset)[i]) == model[["jaspVars"]][["decoded"]]$predictors[j]) { + colnames(dataset)[i] <- model[["jaspVars"]][["encoded"]]$predictors[j] + } + } + } x_test <- dataset[, predictors] predictions <- .mlPredictionsState(model, dataset, options, jaspResults, ready)[options[["fromIndex"]]:options[["toIndex"]]] } From 364f2a9e266538c2fb64f5106396d77a9eced3f4 Mon Sep 17 00:00:00 2001 From: Koen Derks Date: Fri, 22 Nov 2024 12:58:27 +0100 Subject: [PATCH 2/5] Use a function --- R/commonMachineLearningRegression.R | 18 ++---------------- R/mlPrediction.R | 8 ++++++++ 2 files changed, 10 insertions(+), 16 deletions(-) diff --git a/R/commonMachineLearningRegression.R b/R/commonMachineLearningRegression.R index 723e1496..13fdfeb4 100644 --- a/R/commonMachineLearningRegression.R +++ b/R/commonMachineLearningRegression.R @@ -714,14 +714,7 @@ purpose <- "classification" } predictors <- options[["predictors"]][which(decodeColNames(options[["predictors"]]) %in% model[["jaspVars"]][["decoded"]]$predictors)] - # Ensure same encoding in prediction data as in training data - for (i in seq_along(predictors)) { - for (j in seq_along(model[["jaspVars"]][["decoded"]]$predictors)) { - if (decodeColNames(predictors[i]) == model[["jaspVars"]][["decoded"]]$predictors[j]) { - predictors[i] <- model[["jaspVars"]][["encoded"]]$predictors[j] - } - } - } + predictors <- .matchDecodedNames(predictors, model) } else { predictors <- options[["predictors"]] } @@ -754,14 +747,7 @@ x_test <- result[["test"]][, predictors] } else { explainer <- model[["explainer"]] - # Ensure same encoding in prediction data as in training data - for (i in seq_along(colnames(dataset))) { - for (j in seq_along(model[["jaspVars"]][["decoded"]]$predictors)) { - if (decodeColNames(colnames(dataset)[i]) == model[["jaspVars"]][["decoded"]]$predictors[j]) { - colnames(dataset)[i] <- model[["jaspVars"]][["encoded"]]$predictors[j] - } - } - } + colnames(dataset) <- .matchDecodedNames(colnames(dataset), model) x_test <- dataset[, predictors] predictions <- .mlPredictionsState(model, dataset, options, jaspResults, ready)[options[["fromIndex"]]:options[["toIndex"]]] } diff --git a/R/mlPrediction.R b/R/mlPrediction.R index 328e6a42..d4ae6b7d 100644 --- a/R/mlPrediction.R +++ b/R/mlPrediction.R @@ -302,6 +302,14 @@ is.jaspMachineLearning <- function(x) { return(ready) } +# Ensure names in prediction data match names in training data +.matchDecodedNames <- function(names, model) { + decoded <- model[["jaspVars"]][["decoded"]]$predictors + encoded <- model[["jaspVars"]][["encoded"]]$predictors + matched_indices <- match(decodeColNames(names), decoded) + names[!is.na(matched_indices)] <- encoded[matched_indices[!is.na(matched_indices)]] +} + .mlPredictionReadData <- function(dataset, options) { dataset <- jaspBase::excludeNaListwise(dataset, options[["predictors"]]) if (options[["scaleVariables"]] && length(unlist(options[["predictors"]])) > 0) { From 22248be2da900847ce098358dc1355243f9f8113 Mon Sep 17 00:00:00 2001 From: Koen Derks Date: Sun, 24 Nov 2024 13:08:26 +0100 Subject: [PATCH 3/5] No need to decode jaspML object anymore! --- R/commonMachineLearningClassification.R | 1 - R/commonMachineLearningRegression.R | 7 +- R/mlPrediction.R | 104 +++++------------------- 3 files changed, 21 insertions(+), 91 deletions(-) diff --git a/R/commonMachineLearningClassification.R b/R/commonMachineLearningClassification.R index bdb12043..74496c54 100644 --- a/R/commonMachineLearningClassification.R +++ b/R/commonMachineLearningClassification.R @@ -342,7 +342,6 @@ model[["jaspVars"]]$encoded = list(target = options[["target"]], predictors = options[["predictors"]]) model[["jaspVersion"]] <- .baseCitation model[["explainer"]] <- classificationResult[["explainer"]] - model <- .decodeJaspMLobject(model) class(model) <- c(class(classificationResult[["model"]]), "jaspClassification", "jaspMachineLearning") path <- options[["savePath"]] if (!endsWith(path, ".jaspML")) { diff --git a/R/commonMachineLearningRegression.R b/R/commonMachineLearningRegression.R index 13fdfeb4..c3edee3a 100644 --- a/R/commonMachineLearningRegression.R +++ b/R/commonMachineLearningRegression.R @@ -473,7 +473,6 @@ model[["jaspVars"]]$encoded = list(target = options[["target"]], predictors = options[["predictors"]]) model[["jaspVersion"]] <- .baseCitation model[["explainer"]] <- regressionResult[["explainer"]] - model <- .decodeJaspMLobject(model) class(model) <- c(class(regressionResult[["model"]]), "jaspRegression", "jaspMachineLearning") path <- options[["savePath"]] if (!endsWith(path, ".jaspML")) { @@ -713,8 +712,7 @@ } else { purpose <- "classification" } - predictors <- options[["predictors"]][which(decodeColNames(options[["predictors"]]) %in% model[["jaspVars"]][["decoded"]]$predictors)] - predictors <- .matchDecodedNames(predictors, model) + predictors <- model[["jaspVars"]][["encoded"]]$predictors } else { predictors <- options[["predictors"]] } @@ -725,7 +723,7 @@ } else { .mlClassificationDependencies(options) }, - "tableShap", "fromIndex", "toIndex" + "tableShap", "fromIndex", "toIndex", "trainedModelFilePath" )) table$addColumnInfo(name = "id", title = gettext("Case"), type = "integer") if (purpose == "regression") { @@ -747,7 +745,6 @@ x_test <- result[["test"]][, predictors] } else { explainer <- model[["explainer"]] - colnames(dataset) <- .matchDecodedNames(colnames(dataset), model) x_test <- dataset[, predictors] predictions <- .mlPredictionsState(model, dataset, options, jaspResults, ready)[options[["fromIndex"]]:options[["toIndex"]]] } diff --git a/R/mlPrediction.R b/R/mlPrediction.R index d4ae6b7d..21095c89 100644 --- a/R/mlPrediction.R +++ b/R/mlPrediction.R @@ -19,7 +19,7 @@ mlPrediction <- function(jaspResults, dataset, options, ...) { # Preparatory work model <- .mlPredictionReadModel(options) - dataset <- .mlPredictionReadData(dataset, options) + dataset <- .mlPredictionReadData(dataset, options, model) # Check if analysis is ready ready <- .mlPredictionReady(model, dataset, options) @@ -145,7 +145,6 @@ is.jaspMachineLearning <- function(x) { as.character(levels(as.factor(model$model[[model[["jaspVars"]][["encoded"]]$target]]))[round(predict(model, newdata = dataset, type = "response"), 0) + 1]) } .mlPredictionGetPredictions.vglm <- function(model, dataset) { - model[["original"]]@terms$terms <- model[["terms"]] logodds <- predict(model[["original"]], newdata = dataset) ncategories <- ncol(logodds) + 1 probabilities <- matrix(0, nrow = nrow(logodds), ncol = ncategories) @@ -200,76 +199,6 @@ is.jaspMachineLearning <- function(x) { nrow(model[["x"]]) } -# S3 method to decode the model variables in the result object -# so that they can be matched to variables in the prediction analysis -.decodeJaspMLobject <- function(model) { - UseMethod(".decodeJaspMLobject", model) -} -.decodeJaspMLobject.kknn <- function(model) { - formula <- formula(paste(decodeColNames(as.character(model$terms)[2]), "~", paste0(decodeColNames(strsplit(as.character(model$terms)[3], split = " + ", fixed = TRUE)[[1]]), collapse = " + "))) - model[["predictive"]]$terms <- stats::terms(formula) - colnames(model[["predictive"]]$data) <- decodeColNames(colnames(model[["predictive"]]$data)) - return(model) -} -.decodeJaspMLobject.lda <- function(model) { - formula <- formula(paste(decodeColNames(as.character(model$terms)[2]), "~", paste0(decodeColNames(strsplit(as.character(model$terms)[3], split = " + ", fixed = TRUE)[[1]]), collapse = " + "))) - model$terms <- stats::terms(formula) - return(model) -} -.decodeJaspMLobject.lm <- function(model) { - formula <- formula(paste(decodeColNames(as.character(model$terms)[2]), "~", paste0(decodeColNames(strsplit(as.character(model$terms)[3], split = " + ", fixed = TRUE)[[1]]), collapse = " + "))) - model$terms <- stats::terms(formula) - return(model) -} -.decodeJaspMLobject.gbm <- function(model) { - model[["var.names"]] <- decodeColNames(model[["var.names"]]) - return(model) -} -.decodeJaspMLobject.randomForest <- function(model) { - rownames(model$importance) <- decodeColNames(rownames(model$importance)) - names(model$forest$xlevels) <- decodeColNames(names(model$forest$xlevels)) - formula <- formula(paste("DOESNOTMATTER", "~", paste0(decodeColNames(names(model$forest$xlevels)), collapse = " + "))) - model$terms <- stats::terms(formula) - class(model) <- c(class(model), "randomForest.formula") - return(model) -} -.decodeJaspMLobject.cv.glmnet <- function(model) { - rownames(model[["glmnet.fit"]][["beta"]]) <- decodeColNames(rownames(model[["glmnet.fit"]][["beta"]])) - return(model) -} -.decodeJaspMLobject.nn <- function(model) { - model[["model.list"]]$variables <- decodeColNames(model[["model.list"]]$variables) - return(model) -} -.decodeJaspMLobject.rpart <- function(model) { - formula <- formula(paste(decodeColNames(as.character(model$terms)[2]), "~", paste0(decodeColNames(strsplit(as.character(model$terms)[3], split = " + ", fixed = TRUE)[[1]]), collapse = " + "))) - model$terms <- stats::terms(formula) - model$frame$var <- decodeColNames(model$frame$var) - rownames(model$splits) <- decodeColNames(rownames(model$splits)) - return(model) -} -.decodeJaspMLobject.svm <- function(model) { - formula <- formula(paste(decodeColNames(as.character(model$terms)[2]), "~ 0 +", paste0(decodeColNames(strsplit(as.character(model$terms)[3], split = " + ", fixed = TRUE)[[1]]), collapse = " + "))) - model$terms <- stats::terms(formula) - return(model) -} -.decodeJaspMLobject.naiveBayes <- function(model) { - names(model[["isnumeric"]]) <- decodeColNames(names(model[["isnumeric"]])) - names(model[["tables"]]) <- decodeColNames(names(model[["tables"]])) - return(model) -} -.decodeJaspMLobject.glm <- function(model) { - vars <- all.vars(stats::terms(model)) - formula <- formula(paste(decodeColNames(vars[1]), "~", paste0(decodeColNames(vars[-1]), collapse = " + "))) - model$terms <- stats::terms(formula) - return(model) -} -.decodeJaspMLobject.vglm <- function(model) { - formula <- formula(paste(decodeColNames(strsplit(as.character(model$terms), " ")[[1]][1]), "~", paste0(decodeColNames(strsplit(strsplit(as.character(model$terms), split = " ~ ")[[1]][2], split = " + ", fixed = TRUE)[[1]]), collapse = " + "))) - model$terms <- stats::terms(formula) - return(model) -} - .mlPredictionReadModel <- function(options) { if (options[["trainedModelFilePath"]] != "") { model <- try({ @@ -293,8 +222,8 @@ is.jaspMachineLearning <- function(x) { # also define methods for other objects .mlPredictionReady <- function(model, dataset, options) { if (!is.null(model)) { - modelVars <- model[["jaspVars"]][["decoded"]]$predictors - presentVars <- decodeColNames(colnames(dataset)) + modelVars <- model[["jaspVars"]][["encoded"]]$predictors + presentVars <- colnames(dataset) ready <- all(modelVars %in% presentVars) } else { ready <- FALSE @@ -308,13 +237,16 @@ is.jaspMachineLearning <- function(x) { encoded <- model[["jaspVars"]][["encoded"]]$predictors matched_indices <- match(decodeColNames(names), decoded) names[!is.na(matched_indices)] <- encoded[matched_indices[!is.na(matched_indices)]] + return(names) } -.mlPredictionReadData <- function(dataset, options) { +.mlPredictionReadData <- function(dataset, options, model) { dataset <- jaspBase::excludeNaListwise(dataset, options[["predictors"]]) if (options[["scaleVariables"]] && length(unlist(options[["predictors"]])) > 0) { dataset <- .scaleNumericData(dataset) } + dataset <- dataset[, which(decodeColNames(colnames(dataset)) %in% model[["jaspVars"]][["decoded"]]$predictors)] # Filter only predictors to prevent accidental double column names + colnames(dataset) <- .matchDecodedNames(colnames(dataset), model) return(dataset) } @@ -323,7 +255,7 @@ is.jaspMachineLearning <- function(x) { return(jaspResults[["predictions"]]$object) } else { if (ready) { - colnames(dataset) <- decodeColNames(colnames(dataset)) + dataset <- dataset[, which(colnames(dataset) %in% model[["jaspVars"]][["encoded"]]$predictors)] jaspResults[["predictions"]] <- createJaspState(.mlPredictionGetPredictions(model, dataset)) jaspResults[["predictions"]]$dependOn(options = c("loadPath", "predictors", "scaleVariables")) return(jaspResults[["predictions"]]$object) @@ -352,14 +284,16 @@ is.jaspMachineLearning <- function(x) { if (is.null(model)) { return() } - modelVars <- model[["jaspVars"]][["decoded"]]$predictors - presentVars <- decodeColNames(colnames(dataset)) - if (!all(modelVars %in% presentVars)) { - missingVars <- modelVars[which(!(modelVars %in% presentVars))] + modelVars_encoded <- model[["jaspVars"]][["encoded"]]$predictors + modelVars_decoded <- model[["jaspVars"]][["decoded"]]$predictors + presentVars_encoded <- colnames(dataset) + presentVars_decoded <- decodeColNames(options[["predictors"]]) + if (!all(modelVars_decoded %in% presentVars_decoded)) { + missingVars <- modelVars_decoded[which(!(modelVars_decoded %in% presentVars_decoded))] table$addFootnote(gettextf("The trained model is not applied because the the following features are missing: %1$s.", paste0(missingVars, collapse = ", "))) } - if (!all(presentVars %in% modelVars)) { - unusedVars <- presentVars[which(!(presentVars %in% modelVars))] + if (!all(presentVars_decoded %in% modelVars_decoded)) { + unusedVars <- presentVars_decoded[which(!(presentVars_decoded %in% modelVars_decoded))] table$addFootnote(gettextf("The following features are unused because they are not a feature variable in the trained model: %1$s.", paste0(unusedVars, collapse = ", "))) } if (inherits(model, "kknn")) { @@ -402,7 +336,7 @@ is.jaspMachineLearning <- function(x) { row[["family"]] <- gettext("Multinomial") row[["link"]] <- gettext("Logit") } - if (length(presentVars) > 0) { + if (length(presentVars_encoded) > 0) { row[["nnew"]] <- nrow(dataset) } table$addRows(row) @@ -413,7 +347,7 @@ is.jaspMachineLearning <- function(x) { return() } table <- createJaspTable(gettext("Predictions for New Data")) - table$dependOn(options = c("predictors", "loadPath", "predictionsTable", "predictionsTableFeatures", "scaleVariables", "fromIndex", "toIndex")) + table$dependOn(options = c("predictors", "trainedModelFilePath", "predictionsTable", "predictionsTableFeatures", "scaleVariables", "fromIndex", "toIndex")) table$position <- position table$addColumnInfo(name = "row", title = gettext("Case"), type = "integer") if (!is.null(model) && inherits(model, "jaspClassification")) { @@ -450,7 +384,7 @@ is.jaspMachineLearning <- function(x) { predictionsColumn <- rep(NA, max(as.numeric(rownames(dataset)))) predictionsColumn[as.numeric(rownames(dataset))] <- .mlPredictionsState(model, dataset, options, jaspResults, ready) jaspResults[["predictionsColumn"]] <- createJaspColumn(columnName = options[["predictionsColumn"]]) - jaspResults[["predictionsColumn"]]$dependOn(options = c("predictionsColumn", "predictors", "loadPath", "scaleVariables", "addPredictions")) + jaspResults[["predictionsColumn"]]$dependOn(options = c("predictionsColumn", "predictors", "trainedModelFilePath", "scaleVariables", "addPredictions")) if (inherits(model, "jaspClassification")) jaspResults[["predictionsColumn"]]$setNominal(predictionsColumn) if (inherits(model, "jaspRegression")) jaspResults[["predictionsColumn"]]$setScale(predictionsColumn) } From 47ffb9b8eef498a322ca7031ad3b42f3d038dfa4 Mon Sep 17 00:00:00 2001 From: Koen Derks Date: Mon, 25 Nov 2024 10:02:51 +0100 Subject: [PATCH 4/5] Ensure correct column names in predictions tabel --- R/mlPrediction.R | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/R/mlPrediction.R b/R/mlPrediction.R index 21095c89..196ce321 100644 --- a/R/mlPrediction.R +++ b/R/mlPrediction.R @@ -364,16 +364,20 @@ is.jaspMachineLearning <- function(x) { selection <- predictions[indexes] cols <- list(row = indexes, pred = selection) if (options[["predictionsTableFeatures"]]) { - for (i in colnames(dataset)) { - if (.columnIsNominal(i)) { - table$addColumnInfo(name = i, title = i, type = "string") + modelVars_encoded <- model[["jaspVars"]][["encoded"]]$predictors + modelVars_decoded <- model[["jaspVars"]][["decoded"]]$predictors + matched_names <- match(colnames(dataset), modelVars_encoded) + for (i in seq_len(ncol(dataset))) { + colName <- modelVars_decoded[matched_names[i]] + if (is.factor(dataset[[i]])) { + table$addColumnInfo(name = colName, title = colName, type = "string") var <- levels(dataset[[i]])[dataset[[i]]] } else { - table$addColumnInfo(name = i, title = i, type = "number") + table$addColumnInfo(name = colName, title = colName, type = "number") var <- dataset[[i]] } var <- var[indexes] - cols[[i]] <- var + cols[[colName]] <- var } } table$setData(cols) From 569ba927d7fc74ebf46d000dbe9fe4c8ec10338d Mon Sep 17 00:00:00 2001 From: Koen Derks Date: Mon, 25 Nov 2024 11:31:47 +0100 Subject: [PATCH 5/5] Don's comment --- R/mlPrediction.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/mlPrediction.R b/R/mlPrediction.R index 196ce321..6ca7d1a5 100644 --- a/R/mlPrediction.R +++ b/R/mlPrediction.R @@ -255,7 +255,7 @@ is.jaspMachineLearning <- function(x) { return(jaspResults[["predictions"]]$object) } else { if (ready) { - dataset <- dataset[, which(colnames(dataset) %in% model[["jaspVars"]][["encoded"]]$predictors)] + dataset <- dataset[which(colnames(dataset) %in% model[["jaspVars"]][["encoded"]]$predictors)] jaspResults[["predictions"]] <- createJaspState(.mlPredictionGetPredictions(model, dataset)) jaspResults[["predictions"]]$dependOn(options = c("loadPath", "predictors", "scaleVariables")) return(jaspResults[["predictions"]]$object)