Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions R/commonMachineLearningClassification.R
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@
model[["jaspVars"]] <- list()
model[["jaspVars"]]$decoded <- list(target = decodeColNames(options[["target"]]), predictors = decodeColNames(options[["predictors"]]))
model[["jaspVars"]]$encoded = list(target = options[["target"]], predictors = options[["predictors"]])
model[["jaspScaling"]] <- attr(dataset, "jaspScaling")
model[["jaspVersion"]] <- .baseCitation
model[["explainer"]] <- classificationResult[["explainer"]]
class(model) <- c(class(classificationResult[["model"]]), "jaspClassification", "jaspMachineLearning")
Expand Down
10 changes: 10 additions & 0 deletions R/commonMachineLearningRegression.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@

# Scale numeric predictors
if (length(unlist(options[["predictors"]])) > 0 && options[["scaleVariables"]]) {
attr(dataset, which = "jaspScaling") <- .getJaspScaling(dataset[, options[["predictors"]], drop = FALSE])
dataset[, options[["predictors"]]] <- .scaleNumericData(dataset[, options[["predictors"]], drop = FALSE])
}
return(dataset)
Expand Down Expand Up @@ -501,6 +502,7 @@
model[["jaspVars"]] <- list()
model[["jaspVars"]]$decoded <- list(target = decodeColNames(options[["target"]]), predictors = decodeColNames(options[["predictors"]]))
model[["jaspVars"]]$encoded = list(target = options[["target"]], predictors = options[["predictors"]])
model[["jaspScaling"]] <- attr(dataset, "jaspScaling")
model[["jaspVersion"]] <- .baseCitation
model[["explainer"]] <- regressionResult[["explainer"]]
class(model) <- c(class(regressionResult[["model"]]), "jaspRegression", "jaspMachineLearning")
Expand Down Expand Up @@ -683,6 +685,14 @@
}
}

.getJaspScaling <- function(x) {
idx <- sapply(x, function(x) is.numeric(x) && length(unique(x)) > 1)
cols_to_scale <- colnames(x)[idx]
centers <- sapply(x[cols_to_scale], mean)
scales <- sapply(x[cols_to_scale], sd)
return(list(centers, scales))
}

# these could also extend the S3 method scale although that could be somewhat unexpected
.scaleNumericData <- function(x, ...) {
UseMethod(".scaleNumericData", x)
Expand Down
50 changes: 34 additions & 16 deletions R/mlPrediction.R
Original file line number Diff line number Diff line change
Expand Up @@ -269,25 +269,38 @@ is.jaspMachineLearning <- function(x) {
return(names)
}

.setJaspScaling <- function(x, centers, scales) {
if (nrow(x) == 0) {
return(x)
}
for (col in names(centers)) {
x[, col] <- (x[, col] - centers[col]) / scales[col]
}
return(x)
}

.mlPredictionReadData <- function(dataset, options, model) {
if (length(options[["predictors"]]) == 0) {
dataset <- NULL
} else {
dataset <- jaspBase::excludeNaListwise(dataset, options[["predictors"]])
if (options[["scaleVariables"]] && length(unlist(options[["predictors"]])) > 0) {
dataset <- .scaleNumericData(dataset)
}
# Select only the predictors in the model to prevent accidental double column names
dataset <- dataset[, which(decodeColNames(colnames(dataset)) %in% model[["jaspVars"]][["decoded"]]$predictors), drop = FALSE]
# Ensure the column names in the dataset match those in the training data
colnames(dataset) <- .matchDecodedNames(colnames(dataset), model)
# Retrieve the training set
trainingSet <- model[["explainer"]]$data
# Check for factor levels in the test set that are not in the training set
.checkForNewFactorLevelsInPredictionSet(trainingSet, dataset, "prediction", model)
# Ensure factor variables in dataset have same levels as those in the training data
factorColumns <- colnames(dataset)[sapply(dataset, is.factor)]
dataset[factorColumns] <- lapply(factorColumns, function(i) factor(dataset[[i]], levels = levels(trainingSet[[i]])))
if (NCOL(dataset) > 0) {
# Ensure the column names in the dataset match those in the training data
colnames(dataset) <- .matchDecodedNames(colnames(dataset), model)
# Scale the features with the same scaling as the original dataset
if (!is.null(model[["jaspScaling"]])) {
dataset <- .setJaspScaling(dataset, model$jaspScaling[[1]], model$jaspScaling[[2]])
}
# Retrieve the training set
trainingSet <- model[["explainer"]]$data
# Check for factor levels in the test set that are not in the training set
.checkForNewFactorLevelsInPredictionSet(trainingSet, dataset, "prediction", model)
# Ensure that factor variables in the dataset have their levels ordered the same way as in the training data
factorColumns <- colnames(dataset)[sapply(dataset, is.factor)]
dataset[factorColumns] <- lapply(factorColumns, function(i) factor(dataset[[i]], levels = levels(trainingSet[[i]])))
}
}
return(dataset)
}
Expand All @@ -299,7 +312,7 @@ is.jaspMachineLearning <- function(x) {
if (ready) {
dataset <- dataset[which(colnames(dataset) %in% model[["jaspVars"]][["encoded"]]$predictors)]
jaspResults[["predictions"]] <- createJaspState(.mlPredictionGetPredictions(model, dataset))
jaspResults[["predictions"]]$dependOn(options = c("loadPath", "predictors", "scaleVariables"))
jaspResults[["predictions"]]$dependOn(options = c("loadPath", "predictors"))
return(jaspResults[["predictions"]]$object)
} else {
return(NULL)
Expand All @@ -326,6 +339,11 @@ is.jaspMachineLearning <- function(x) {
if (is.null(model)) {
return()
}
if (is.null(model[["jaspScaling"]])) {
table$addFootnote(gettext("The features in the new data are unscaled."))
} else {
table$addFootnote(gettext("The features in the new data are scaled."))
}
modelVars_encoded <- model[["jaspVars"]][["encoded"]]$predictors
modelVars_decoded <- model[["jaspVars"]][["decoded"]]$predictors
presentVars_encoded <- colnames(dataset)
Expand Down Expand Up @@ -389,7 +407,7 @@ is.jaspMachineLearning <- function(x) {
return()
}
table <- createJaspTable(gettext("Predictions for New Data"))
table$dependOn(options = c("predictors", "trainedModelFilePath", "predictionsTable", "predictionsTableFeatures", "scaleVariables", "fromIndex", "toIndex"))
table$dependOn(options = c("predictors", "trainedModelFilePath", "predictionsTable", "predictionsTableFeatures", "fromIndex", "toIndex"))
table$position <- position
table$addColumnInfo(name = "row", title = gettext("Case"), type = "integer")
if (!is.null(model) && inherits(model, "jaspClassification")) {
Expand Down Expand Up @@ -433,7 +451,7 @@ is.jaspMachineLearning <- function(x) {
predictionsColumn <- rep(NA, max(as.numeric(rownames(dataset))))
predictionsColumn[as.numeric(rownames(dataset))] <- predictions[[1]]
jaspResults[["predictionsColumn"]] <- createJaspColumn(columnName = options[["predictionsColumn"]])
jaspResults[["predictionsColumn"]]$dependOn(options = c("predictionsColumn", "predictors", "trainedModelFilePath", "scaleVariables", "addPredictions"))
jaspResults[["predictionsColumn"]]$dependOn(options = c("predictionsColumn", "predictors", "trainedModelFilePath", "addPredictions"))
if (inherits(model, "jaspClassification")) jaspResults[["predictionsColumn"]]$setNominal(predictionsColumn)
if (inherits(model, "jaspRegression")) jaspResults[["predictionsColumn"]]$setScale(predictionsColumn)
}
Expand All @@ -446,7 +464,7 @@ is.jaspMachineLearning <- function(x) {
break
}
jaspResults[[colName]] <- createJaspColumn(columnName = colName)
jaspResults[[colName]]$dependOn(options = c("predictionsColumn", "predictors", "trainedModelFilePath", "scaleVariables", "addPredictions", "addProbabilities"))
jaspResults[[colName]]$dependOn(options = c("predictionsColumn", "predictors", "trainedModelFilePath", "addPredictions", "addProbabilities"))
jaspResults[[colName]]$setScale(predictions[[2]][, i])
}
}
Expand Down
7 changes: 0 additions & 7 deletions inst/qml/mlPrediction.qml
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,6 @@ Form
}
}

Group
{
title: qsTr("Algorithmic Settings")

UI.ScaleVariables { }
}

Group
{
title: qsTr("Tables")
Expand Down
Loading