Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion R/commonMachineLearningClassification.R
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,6 @@
model[["jaspVars"]]$encoded = list(target = options[["target"]], predictors = options[["predictors"]])
model[["jaspVersion"]] <- .baseCitation
model[["explainer"]] <- classificationResult[["explainer"]]
model <- .decodeJaspMLobject(model)
class(model) <- c(class(classificationResult[["model"]]), "jaspClassification", "jaspMachineLearning")
path <- options[["savePath"]]
if (!endsWith(path, ".jaspML")) {
Expand Down
5 changes: 2 additions & 3 deletions R/commonMachineLearningRegression.R
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,6 @@
model[["jaspVars"]]$encoded = list(target = options[["target"]], predictors = options[["predictors"]])
model[["jaspVersion"]] <- .baseCitation
model[["explainer"]] <- regressionResult[["explainer"]]
model <- .decodeJaspMLobject(model)
class(model) <- c(class(regressionResult[["model"]]), "jaspRegression", "jaspMachineLearning")
path <- options[["savePath"]]
if (!endsWith(path, ".jaspML")) {
Expand Down Expand Up @@ -713,7 +712,7 @@
} else {
purpose <- "classification"
}
predictors <- options[["predictors"]][which(decodeColNames(options[["predictors"]]) %in% model[["jaspVars"]][["decoded"]]$predictors)]
predictors <- model[["jaspVars"]][["encoded"]]$predictors
} else {
predictors <- options[["predictors"]]
}
Expand All @@ -724,7 +723,7 @@
} else {
.mlClassificationDependencies(options)
},
"tableShap", "fromIndex", "toIndex"
"tableShap", "fromIndex", "toIndex", "trainedModelFilePath"
))
table$addColumnInfo(name = "id", title = gettext("Case"), type = "integer")
if (purpose == "regression") {
Expand Down
126 changes: 36 additions & 90 deletions R/mlPrediction.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ mlPrediction <- function(jaspResults, dataset, options, ...) {

# Preparatory work
model <- .mlPredictionReadModel(options)
dataset <- .mlPredictionReadData(dataset, options)
dataset <- .mlPredictionReadData(dataset, options, model)

# Check if analysis is ready
ready <- .mlPredictionReady(model, dataset, options)
Expand Down Expand Up @@ -145,7 +145,6 @@ is.jaspMachineLearning <- function(x) {
as.character(levels(as.factor(model$model[[model[["jaspVars"]][["encoded"]]$target]]))[round(predict(model, newdata = dataset, type = "response"), 0) + 1])
}
.mlPredictionGetPredictions.vglm <- function(model, dataset) {
model[["original"]]@terms$terms <- model[["terms"]]
logodds <- predict(model[["original"]], newdata = dataset)
ncategories <- ncol(logodds) + 1
probabilities <- matrix(0, nrow = nrow(logodds), ncol = ncategories)
Expand Down Expand Up @@ -200,76 +199,6 @@ is.jaspMachineLearning <- function(x) {
nrow(model[["x"]])
}

# S3 method to decode the model variables in the result object
# so that they can be matched to variables in the prediction analysis
.decodeJaspMLobject <- function(model) {
UseMethod(".decodeJaspMLobject", model)
}
.decodeJaspMLobject.kknn <- function(model) {
formula <- formula(paste(decodeColNames(as.character(model$terms)[2]), "~", paste0(decodeColNames(strsplit(as.character(model$terms)[3], split = " + ", fixed = TRUE)[[1]]), collapse = " + ")))
model[["predictive"]]$terms <- stats::terms(formula)
colnames(model[["predictive"]]$data) <- decodeColNames(colnames(model[["predictive"]]$data))
return(model)
}
.decodeJaspMLobject.lda <- function(model) {
formula <- formula(paste(decodeColNames(as.character(model$terms)[2]), "~", paste0(decodeColNames(strsplit(as.character(model$terms)[3], split = " + ", fixed = TRUE)[[1]]), collapse = " + ")))
model$terms <- stats::terms(formula)
return(model)
}
.decodeJaspMLobject.lm <- function(model) {
formula <- formula(paste(decodeColNames(as.character(model$terms)[2]), "~", paste0(decodeColNames(strsplit(as.character(model$terms)[3], split = " + ", fixed = TRUE)[[1]]), collapse = " + ")))
model$terms <- stats::terms(formula)
return(model)
}
.decodeJaspMLobject.gbm <- function(model) {
model[["var.names"]] <- decodeColNames(model[["var.names"]])
return(model)
}
.decodeJaspMLobject.randomForest <- function(model) {
rownames(model$importance) <- decodeColNames(rownames(model$importance))
names(model$forest$xlevels) <- decodeColNames(names(model$forest$xlevels))
formula <- formula(paste("DOESNOTMATTER", "~", paste0(decodeColNames(names(model$forest$xlevels)), collapse = " + ")))
model$terms <- stats::terms(formula)
class(model) <- c(class(model), "randomForest.formula")
return(model)
}
.decodeJaspMLobject.cv.glmnet <- function(model) {
rownames(model[["glmnet.fit"]][["beta"]]) <- decodeColNames(rownames(model[["glmnet.fit"]][["beta"]]))
return(model)
}
.decodeJaspMLobject.nn <- function(model) {
model[["model.list"]]$variables <- decodeColNames(model[["model.list"]]$variables)
return(model)
}
.decodeJaspMLobject.rpart <- function(model) {
formula <- formula(paste(decodeColNames(as.character(model$terms)[2]), "~", paste0(decodeColNames(strsplit(as.character(model$terms)[3], split = " + ", fixed = TRUE)[[1]]), collapse = " + ")))
model$terms <- stats::terms(formula)
model$frame$var <- decodeColNames(model$frame$var)
rownames(model$splits) <- decodeColNames(rownames(model$splits))
return(model)
}
.decodeJaspMLobject.svm <- function(model) {
formula <- formula(paste(decodeColNames(as.character(model$terms)[2]), "~ 0 +", paste0(decodeColNames(strsplit(as.character(model$terms)[3], split = " + ", fixed = TRUE)[[1]]), collapse = " + ")))
model$terms <- stats::terms(formula)
return(model)
}
.decodeJaspMLobject.naiveBayes <- function(model) {
names(model[["isnumeric"]]) <- decodeColNames(names(model[["isnumeric"]]))
names(model[["tables"]]) <- decodeColNames(names(model[["tables"]]))
return(model)
}
.decodeJaspMLobject.glm <- function(model) {
vars <- all.vars(stats::terms(model))
formula <- formula(paste(decodeColNames(vars[1]), "~", paste0(decodeColNames(vars[-1]), collapse = " + ")))
model$terms <- stats::terms(formula)
return(model)
}
.decodeJaspMLobject.vglm <- function(model) {
formula <- formula(paste(decodeColNames(strsplit(as.character(model$terms), " ")[[1]][1]), "~", paste0(decodeColNames(strsplit(strsplit(as.character(model$terms), split = " ~ ")[[1]][2], split = " + ", fixed = TRUE)[[1]]), collapse = " + ")))
model$terms <- stats::terms(formula)
return(model)
}

.mlPredictionReadModel <- function(options) {
if (options[["trainedModelFilePath"]] != "") {
model <- try({
Expand All @@ -293,20 +222,31 @@ is.jaspMachineLearning <- function(x) {
# also define methods for other objects
.mlPredictionReady <- function(model, dataset, options) {
if (!is.null(model)) {
modelVars <- model[["jaspVars"]][["decoded"]]$predictors
presentVars <- decodeColNames(colnames(dataset))
modelVars <- model[["jaspVars"]][["encoded"]]$predictors
presentVars <- colnames(dataset)
ready <- all(modelVars %in% presentVars)
} else {
ready <- FALSE
}
return(ready)
}

.mlPredictionReadData <- function(dataset, options) {
# Ensure names in prediction data match names in training data
.matchDecodedNames <- function(names, model) {
decoded <- model[["jaspVars"]][["decoded"]]$predictors
encoded <- model[["jaspVars"]][["encoded"]]$predictors
matched_indices <- match(decodeColNames(names), decoded)
names[!is.na(matched_indices)] <- encoded[matched_indices[!is.na(matched_indices)]]
return(names)
}

.mlPredictionReadData <- function(dataset, options, model) {
dataset <- jaspBase::excludeNaListwise(dataset, options[["predictors"]])
if (options[["scaleVariables"]] && length(unlist(options[["predictors"]])) > 0) {
dataset <- .scaleNumericData(dataset)
}
dataset <- dataset[, which(decodeColNames(colnames(dataset)) %in% model[["jaspVars"]][["decoded"]]$predictors)] # Filter only predictors to prevent accidental double column names
colnames(dataset) <- .matchDecodedNames(colnames(dataset), model)
return(dataset)
}

Expand All @@ -315,7 +255,7 @@ is.jaspMachineLearning <- function(x) {
return(jaspResults[["predictions"]]$object)
} else {
if (ready) {
colnames(dataset) <- decodeColNames(colnames(dataset))
dataset <- dataset[which(colnames(dataset) %in% model[["jaspVars"]][["encoded"]]$predictors)]
jaspResults[["predictions"]] <- createJaspState(.mlPredictionGetPredictions(model, dataset))
jaspResults[["predictions"]]$dependOn(options = c("loadPath", "predictors", "scaleVariables"))
return(jaspResults[["predictions"]]$object)
Expand Down Expand Up @@ -344,14 +284,16 @@ is.jaspMachineLearning <- function(x) {
if (is.null(model)) {
return()
}
modelVars <- model[["jaspVars"]][["decoded"]]$predictors
presentVars <- decodeColNames(colnames(dataset))
if (!all(modelVars %in% presentVars)) {
missingVars <- modelVars[which(!(modelVars %in% presentVars))]
modelVars_encoded <- model[["jaspVars"]][["encoded"]]$predictors
modelVars_decoded <- model[["jaspVars"]][["decoded"]]$predictors
presentVars_encoded <- colnames(dataset)
presentVars_decoded <- decodeColNames(options[["predictors"]])
if (!all(modelVars_decoded %in% presentVars_decoded)) {
missingVars <- modelVars_decoded[which(!(modelVars_decoded %in% presentVars_decoded))]
table$addFootnote(gettextf("The trained model is not applied because the the following features are missing: <i>%1$s</i>.", paste0(missingVars, collapse = ", ")))
}
if (!all(presentVars %in% modelVars)) {
unusedVars <- presentVars[which(!(presentVars %in% modelVars))]
if (!all(presentVars_decoded %in% modelVars_decoded)) {
unusedVars <- presentVars_decoded[which(!(presentVars_decoded %in% modelVars_decoded))]
table$addFootnote(gettextf("The following features are unused because they are not a feature variable in the trained model: <i>%1$s</i>.", paste0(unusedVars, collapse = ", ")))
}
if (inherits(model, "kknn")) {
Expand Down Expand Up @@ -394,7 +336,7 @@ is.jaspMachineLearning <- function(x) {
row[["family"]] <- gettext("Multinomial")
row[["link"]] <- gettext("Logit")
}
if (length(presentVars) > 0) {
if (length(presentVars_encoded) > 0) {
row[["nnew"]] <- nrow(dataset)
}
table$addRows(row)
Expand All @@ -405,7 +347,7 @@ is.jaspMachineLearning <- function(x) {
return()
}
table <- createJaspTable(gettext("Predictions for New Data"))
table$dependOn(options = c("predictors", "loadPath", "predictionsTable", "predictionsTableFeatures", "scaleVariables", "fromIndex", "toIndex"))
table$dependOn(options = c("predictors", "trainedModelFilePath", "predictionsTable", "predictionsTableFeatures", "scaleVariables", "fromIndex", "toIndex"))
table$position <- position
table$addColumnInfo(name = "row", title = gettext("Case"), type = "integer")
if (!is.null(model) && inherits(model, "jaspClassification")) {
Expand All @@ -422,16 +364,20 @@ is.jaspMachineLearning <- function(x) {
selection <- predictions[indexes]
cols <- list(row = indexes, pred = selection)
if (options[["predictionsTableFeatures"]]) {
for (i in colnames(dataset)) {
if (.columnIsNominal(i)) {
table$addColumnInfo(name = i, title = i, type = "string")
modelVars_encoded <- model[["jaspVars"]][["encoded"]]$predictors
modelVars_decoded <- model[["jaspVars"]][["decoded"]]$predictors
matched_names <- match(colnames(dataset), modelVars_encoded)
for (i in seq_len(ncol(dataset))) {
colName <- modelVars_decoded[matched_names[i]]
if (is.factor(dataset[[i]])) {
table$addColumnInfo(name = colName, title = colName, type = "string")
var <- levels(dataset[[i]])[dataset[[i]]]
} else {
table$addColumnInfo(name = i, title = i, type = "number")
table$addColumnInfo(name = colName, title = colName, type = "number")
var <- dataset[[i]]
}
var <- var[indexes]
cols[[i]] <- var
cols[[colName]] <- var
}
}
table$setData(cols)
Expand All @@ -442,7 +388,7 @@ is.jaspMachineLearning <- function(x) {
predictionsColumn <- rep(NA, max(as.numeric(rownames(dataset))))
predictionsColumn[as.numeric(rownames(dataset))] <- .mlPredictionsState(model, dataset, options, jaspResults, ready)
jaspResults[["predictionsColumn"]] <- createJaspColumn(columnName = options[["predictionsColumn"]])
jaspResults[["predictionsColumn"]]$dependOn(options = c("predictionsColumn", "predictors", "loadPath", "scaleVariables", "addPredictions"))
jaspResults[["predictionsColumn"]]$dependOn(options = c("predictionsColumn", "predictors", "trainedModelFilePath", "scaleVariables", "addPredictions"))
if (inherits(model, "jaspClassification")) jaspResults[["predictionsColumn"]]$setNominal(predictionsColumn)
if (inherits(model, "jaspRegression")) jaspResults[["predictionsColumn"]]$setScale(predictionsColumn)
}
Expand Down
Loading