Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion R/commonMachineLearningClassification.R
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,6 @@
model[["jaspVars"]]$encoded = list(target = options[["target"]], predictors = options[["predictors"]])
model[["jaspVersion"]] <- .baseCitation
model[["explainer"]] <- classificationResult[["explainer"]]
model <- .decodeJaspMLobject(model)
class(model) <- c(class(classificationResult[["model"]]), "jaspClassification", "jaspMachineLearning")
path <- options[["savePath"]]
if (!endsWith(path, ".jaspML")) {
Expand Down
5 changes: 2 additions & 3 deletions R/commonMachineLearningRegression.R
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,6 @@
model[["jaspVars"]]$encoded = list(target = options[["target"]], predictors = options[["predictors"]])
model[["jaspVersion"]] <- .baseCitation
model[["explainer"]] <- regressionResult[["explainer"]]
model <- .decodeJaspMLobject(model)
class(model) <- c(class(regressionResult[["model"]]), "jaspRegression", "jaspMachineLearning")
path <- options[["savePath"]]
if (!endsWith(path, ".jaspML")) {
Expand Down Expand Up @@ -713,7 +712,7 @@
} else {
purpose <- "classification"
}
predictors <- options[["predictors"]][which(decodeColNames(options[["predictors"]]) %in% model[["jaspVars"]][["decoded"]]$predictors)]
predictors <- model[["jaspVars"]][["encoded"]]$predictors
} else {
predictors <- options[["predictors"]]
}
Expand All @@ -724,7 +723,7 @@
} else {
.mlClassificationDependencies(options)
},
"tableShap", "fromIndex", "toIndex"
"tableShap", "fromIndex", "toIndex", "trainedModelFilePath"
))
table$addColumnInfo(name = "id", title = gettext("Case"), type = "integer")
if (purpose == "regression") {
Expand Down
126 changes: 36 additions & 90 deletions R/mlPrediction.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ mlPrediction <- function(jaspResults, dataset, options, ...) {

# Preparatory work
model <- .mlPredictionReadModel(options)
dataset <- .mlPredictionReadData(dataset, options)
dataset <- .mlPredictionReadData(dataset, options, model)

# Check if analysis is ready
ready <- .mlPredictionReady(model, dataset, options)
Expand Down Expand Up @@ -145,7 +145,6 @@ is.jaspMachineLearning <- function(x) {
as.character(levels(as.factor(model$model[[model[["jaspVars"]][["encoded"]]$target]]))[round(predict(model, newdata = dataset, type = "response"), 0) + 1])
}
.mlPredictionGetPredictions.vglm <- function(model, dataset) {
model[["original"]]@terms$terms <- model[["terms"]]
logodds <- predict(model[["original"]], newdata = dataset)
ncategories <- ncol(logodds) + 1
probabilities <- matrix(0, nrow = nrow(logodds), ncol = ncategories)
Expand Down Expand Up @@ -200,76 +199,6 @@ is.jaspMachineLearning <- function(x) {
nrow(model[["x"]])
}

# S3 generic that decodes the variable names stored inside a fitted model
# object, so that they can be matched to the variables supplied in the
# prediction analysis. Dispatch happens on the class of `model`; each method
# returns the modified model object.
.decodeJaspMLobject <- function(model) {
  UseMethod(".decodeJaspMLobject")
}
# kknn method: rebuilds the terms stored under model[["predictive"]] from the
# decoded response and predictor names found in model$terms, and decodes the
# column names of the data stored under model[["predictive"]].
# Returns the modified model.
.decodeJaspMLobject.kknn <- function(model) {
  # as.character() on the terms yields c("~", <response>, <right-hand side>)
  termParts <- as.character(model$terms)
  response <- decodeColNames(termParts[2])
  predictors <- decodeColNames(strsplit(termParts[3], split = " + ", fixed = TRUE)[[1]])
  decodedFormula <- stats::formula(paste(response, "~", paste0(predictors, collapse = " + ")))
  model[["predictive"]]$terms <- stats::terms(decodedFormula)
  dataNames <- colnames(model[["predictive"]]$data)
  colnames(model[["predictive"]]$data) <- decodeColNames(dataNames)
  model
}
# lda method: replaces model$terms with terms built from the decoded response
# and predictor names. Returns the modified model.
.decodeJaspMLobject.lda <- function(model) {
  # as.character() on the terms yields c("~", <response>, <right-hand side>)
  termParts <- as.character(model$terms)
  response <- decodeColNames(termParts[2])
  predictors <- decodeColNames(strsplit(termParts[3], split = " + ", fixed = TRUE)[[1]])
  decodedFormula <- stats::formula(paste(response, "~", paste0(predictors, collapse = " + ")))
  model$terms <- stats::terms(decodedFormula)
  model
}
# lm method: replaces model$terms with terms built from the decoded response
# and predictor names. Returns the modified model.
.decodeJaspMLobject.lm <- function(model) {
  # as.character() on the terms yields c("~", <response>, <right-hand side>)
  termParts <- as.character(model$terms)
  response <- decodeColNames(termParts[2])
  predictors <- decodeColNames(strsplit(termParts[3], split = " + ", fixed = TRUE)[[1]])
  decodedFormula <- stats::formula(paste(response, "~", paste0(predictors, collapse = " + ")))
  model$terms <- stats::terms(decodedFormula)
  model
}
# gbm method: decodes the variable names stored in model[["var.names"]].
# Returns the modified model.
.decodeJaspMLobject.gbm <- function(model) {
  encodedNames <- model[["var.names"]]
  model[["var.names"]] <- decodeColNames(encodedNames)
  model
}
# randomForest method: decodes the row names of the importance matrix and the
# names of model$forest$xlevels, attaches a terms object built from those
# predictor names, and appends the "randomForest.formula" class.
# Returns the modified model.
.decodeJaspMLobject.randomForest <- function(model) {
  importanceNames <- rownames(model$importance)
  rownames(model$importance) <- decodeColNames(importanceNames)
  levelNames <- names(model$forest$xlevels)
  names(model$forest$xlevels) <- decodeColNames(levelNames)
  # NOTE(review): the xlevels names were already decoded on the line above;
  # the second decodeColNames() call is kept to preserve existing behaviour.
  predictors <- decodeColNames(names(model$forest$xlevels))
  # The left-hand side of the formula is a fixed placeholder name
  decodedFormula <- stats::formula(paste("DOESNOTMATTER", "~", paste0(predictors, collapse = " + ")))
  model$terms <- stats::terms(decodedFormula)
  class(model) <- append(class(model), "randomForest.formula")
  model
}
# cv.glmnet method: decodes the row names of the coefficient matrix stored in
# model[["glmnet.fit"]][["beta"]]. Returns the modified model.
.decodeJaspMLobject.cv.glmnet <- function(model) {
  encodedNames <- rownames(model[["glmnet.fit"]][["beta"]])
  rownames(model[["glmnet.fit"]][["beta"]]) <- decodeColNames(encodedNames)
  model
}
# nn method: decodes the variable names stored in
# model[["model.list"]]$variables. Returns the modified model.
.decodeJaspMLobject.nn <- function(model) {
  encodedNames <- model[["model.list"]]$variables
  model[["model.list"]]$variables <- decodeColNames(encodedNames)
  model
}
# rpart method: replaces model$terms with terms built from the decoded
# response and predictor names, and decodes the variable names stored in
# model$frame$var and in the row names of model$splits.
# Returns the modified model.
.decodeJaspMLobject.rpart <- function(model) {
  # as.character() on the terms yields c("~", <response>, <right-hand side>)
  termParts <- as.character(model$terms)
  response <- decodeColNames(termParts[2])
  predictors <- decodeColNames(strsplit(termParts[3], split = " + ", fixed = TRUE)[[1]])
  decodedFormula <- stats::formula(paste(response, "~", paste0(predictors, collapse = " + ")))
  model$terms <- stats::terms(decodedFormula)
  frameVars <- model$frame$var
  model$frame$var <- decodeColNames(frameVars)
  splitNames <- rownames(model$splits)
  rownames(model$splits) <- decodeColNames(splitNames)
  model
}
# svm method: replaces model$terms with terms built from the decoded response
# and predictor names; the rebuilt formula has no intercept ("~ 0 +").
# Returns the modified model.
.decodeJaspMLobject.svm <- function(model) {
  # as.character() on the terms yields c("~", <response>, <right-hand side>)
  termParts <- as.character(model$terms)
  response <- decodeColNames(termParts[2])
  predictors <- decodeColNames(strsplit(termParts[3], split = " + ", fixed = TRUE)[[1]])
  decodedFormula <- stats::formula(paste(response, "~ 0 +", paste0(predictors, collapse = " + ")))
  model$terms <- stats::terms(decodedFormula)
  model
}
# naiveBayes method: decodes the element names of model[["isnumeric"]] and
# model[["tables"]]. Returns the modified model.
.decodeJaspMLobject.naiveBayes <- function(model) {
  for (field in c("isnumeric", "tables")) {
    encodedNames <- names(model[[field]])
    names(model[[field]]) <- decodeColNames(encodedNames)
  }
  model
}
# glm method: takes the variables of the model's terms (first one is used as
# the response, the rest as predictors), decodes them, and replaces
# model$terms with terms built from the decoded names.
# Returns the modified model.
.decodeJaspMLobject.glm <- function(model) {
  termVars <- all.vars(stats::terms(model))
  response <- decodeColNames(termVars[1])
  predictors <- decodeColNames(termVars[-1])
  decodedFormula <- stats::formula(paste(response, "~", paste0(predictors, collapse = " + ")))
  model$terms <- stats::terms(decodedFormula)
  model
}
# vglm method: parses the text form of model$terms, decodes the response
# (the first space-delimited token) and the predictors (the " + "-separated
# pieces after " ~ "), and replaces model$terms with terms built from the
# decoded names. Returns the modified model.
.decodeJaspMLobject.vglm <- function(model) {
  termsText <- as.character(model$terms)
  response <- decodeColNames(strsplit(termsText, " ")[[1]][1])
  rightHandSide <- strsplit(termsText, split = " ~ ")[[1]][2]
  predictors <- decodeColNames(strsplit(rightHandSide, split = " + ", fixed = TRUE)[[1]])
  decodedFormula <- stats::formula(paste(response, "~", paste0(predictors, collapse = " + ")))
  model$terms <- stats::terms(decodedFormula)
  model
}

.mlPredictionReadModel <- function(options) {
if (options[["trainedModelFilePath"]] != "") {
model <- try({
Expand All @@ -293,20 +222,31 @@ is.jaspMachineLearning <- function(x) {
# also define methods for other objects
# Check whether the prediction analysis can run.
#
# model:   trained model object, or NULL
# dataset: new data; only its column names are inspected
# options: analysis options (not used by this check)
#
# Returns FALSE when `model` is NULL; otherwise TRUE only if every encoded
# predictor name stored in the model is among the column names of `dataset`.
.mlPredictionReady <- function(model, dataset, options) {
  if (is.null(model)) {
    return(FALSE)
  }
  # Compare encoded names directly. The earlier decoded-name assignments were
  # overwritten before use, so they are dropped.
  modelVars <- model[["jaspVars"]][["encoded"]]$predictors
  presentVars <- colnames(dataset)
  all(modelVars %in% presentVars)
}

.mlPredictionReadData <- function(dataset, options) {
# Ensure names in prediction data match names in training data.
#
# names: character vector of column names from the prediction data
# model: trained model carrying jaspVars$decoded$predictors and
#        jaspVars$encoded$predictors
#
# Every entry of `names` whose decoded form equals one of the model's decoded
# predictors is replaced by the model's encoded predictor at the same
# position; all other entries are returned unchanged.
.matchDecodedNames <- function(names, model) {
  jaspVars <- model[["jaspVars"]]
  trainedDecoded <- jaspVars[["decoded"]]$predictors
  trainedEncoded <- jaspVars[["encoded"]]$predictors
  position <- match(decodeColNames(names), trainedDecoded)
  found <- which(!is.na(position))
  names[found] <- trainedEncoded[position[found]]
  names
}

# Prepare the new data for prediction.
#
# dataset: the new data (assumed to be a data frame -- TODO confirm)
# options: analysis options; reads `predictors` and `scaleVariables`
# model:   trained model carrying jaspVars$decoded$predictors
#
# Steps: drop rows with missing values in the selected predictors, optionally
# scale the data, keep only columns whose decoded name is a predictor of the
# trained model, and rename those columns via .matchDecodedNames().
# Returns the prepared dataset.
.mlPredictionReadData <- function(dataset, options, model) {
  dataset <- jaspBase::excludeNaListwise(dataset, options[["predictors"]])
  if (options[["scaleVariables"]] && length(unlist(options[["predictors"]])) > 0) {
    dataset <- .scaleNumericData(dataset)
  }
  # Filter only predictors to prevent accidental double column names.
  isModelPredictor <- decodeColNames(colnames(dataset)) %in% model[["jaspVars"]][["decoded"]]$predictors
  # drop = FALSE keeps a data frame when exactly one column is selected;
  # otherwise the result would collapse to a vector and the colnames<- call
  # below would fail.
  dataset <- dataset[, isModelPredictor, drop = FALSE]
  colnames(dataset) <- .matchDecodedNames(colnames(dataset), model)
  return(dataset)
}

Expand All @@ -315,7 +255,7 @@ is.jaspMachineLearning <- function(x) {
return(jaspResults[["predictions"]]$object)
} else {
if (ready) {
colnames(dataset) <- decodeColNames(colnames(dataset))
dataset <- dataset[, which(colnames(dataset) %in% model[["jaspVars"]][["encoded"]]$predictors)]
jaspResults[["predictions"]] <- createJaspState(.mlPredictionGetPredictions(model, dataset))
jaspResults[["predictions"]]$dependOn(options = c("loadPath", "predictors", "scaleVariables"))
return(jaspResults[["predictions"]]$object)
Expand Down Expand Up @@ -344,14 +284,16 @@ is.jaspMachineLearning <- function(x) {
if (is.null(model)) {
return()
}
modelVars <- model[["jaspVars"]][["decoded"]]$predictors
presentVars <- decodeColNames(colnames(dataset))
if (!all(modelVars %in% presentVars)) {
missingVars <- modelVars[which(!(modelVars %in% presentVars))]
modelVars_encoded <- model[["jaspVars"]][["encoded"]]$predictors
modelVars_decoded <- model[["jaspVars"]][["decoded"]]$predictors
presentVars_encoded <- colnames(dataset)
presentVars_decoded <- decodeColNames(options[["predictors"]])
if (!all(modelVars_decoded %in% presentVars_decoded)) {
missingVars <- modelVars_decoded[which(!(modelVars_decoded %in% presentVars_decoded))]
table$addFootnote(gettextf("The trained model is not applied because the the following features are missing: <i>%1$s</i>.", paste0(missingVars, collapse = ", ")))
}
if (!all(presentVars %in% modelVars)) {
unusedVars <- presentVars[which(!(presentVars %in% modelVars))]
if (!all(presentVars_decoded %in% modelVars_decoded)) {
unusedVars <- presentVars_decoded[which(!(presentVars_decoded %in% modelVars_decoded))]
table$addFootnote(gettextf("The following features are unused because they are not a feature variable in the trained model: <i>%1$s</i>.", paste0(unusedVars, collapse = ", ")))
}
if (inherits(model, "kknn")) {
Expand Down Expand Up @@ -394,7 +336,7 @@ is.jaspMachineLearning <- function(x) {
row[["family"]] <- gettext("Multinomial")
row[["link"]] <- gettext("Logit")
}
if (length(presentVars) > 0) {
if (length(presentVars_encoded) > 0) {
row[["nnew"]] <- nrow(dataset)
}
table$addRows(row)
Expand All @@ -405,7 +347,7 @@ is.jaspMachineLearning <- function(x) {
return()
}
table <- createJaspTable(gettext("Predictions for New Data"))
table$dependOn(options = c("predictors", "loadPath", "predictionsTable", "predictionsTableFeatures", "scaleVariables", "fromIndex", "toIndex"))
table$dependOn(options = c("predictors", "trainedModelFilePath", "predictionsTable", "predictionsTableFeatures", "scaleVariables", "fromIndex", "toIndex"))
table$position <- position
table$addColumnInfo(name = "row", title = gettext("Case"), type = "integer")
if (!is.null(model) && inherits(model, "jaspClassification")) {
Expand All @@ -422,16 +364,20 @@ is.jaspMachineLearning <- function(x) {
selection <- predictions[indexes]
cols <- list(row = indexes, pred = selection)
if (options[["predictionsTableFeatures"]]) {
for (i in colnames(dataset)) {
if (.columnIsNominal(i)) {
table$addColumnInfo(name = i, title = i, type = "string")
modelVars_encoded <- model[["jaspVars"]][["encoded"]]$predictors
modelVars_decoded <- model[["jaspVars"]][["decoded"]]$predictors
matched_names <- match(colnames(dataset), modelVars_encoded)
for (i in seq_len(ncol(dataset))) {
colName <- modelVars_decoded[matched_names[i]]
if (is.factor(dataset[[i]])) {
table$addColumnInfo(name = colName, title = colName, type = "string")
var <- levels(dataset[[i]])[dataset[[i]]]
} else {
table$addColumnInfo(name = i, title = i, type = "number")
table$addColumnInfo(name = colName, title = colName, type = "number")
var <- dataset[[i]]
}
var <- var[indexes]
cols[[i]] <- var
cols[[colName]] <- var
}
}
table$setData(cols)
Expand All @@ -442,7 +388,7 @@ is.jaspMachineLearning <- function(x) {
predictionsColumn <- rep(NA, max(as.numeric(rownames(dataset))))
predictionsColumn[as.numeric(rownames(dataset))] <- .mlPredictionsState(model, dataset, options, jaspResults, ready)
jaspResults[["predictionsColumn"]] <- createJaspColumn(columnName = options[["predictionsColumn"]])
jaspResults[["predictionsColumn"]]$dependOn(options = c("predictionsColumn", "predictors", "loadPath", "scaleVariables", "addPredictions"))
jaspResults[["predictionsColumn"]]$dependOn(options = c("predictionsColumn", "predictors", "trainedModelFilePath", "scaleVariables", "addPredictions"))
if (inherits(model, "jaspClassification")) jaspResults[["predictionsColumn"]]$setNominal(predictionsColumn)
if (inherits(model, "jaspRegression")) jaspResults[["predictionsColumn"]]$setScale(predictionsColumn)
}
Expand Down
Loading