Skip to content

Commit 945a1ef

Browse files
committed
Update
1 parent e5a3058 commit 945a1ef

File tree

4 files changed

+38
-53
lines changed

4 files changed

+38
-53
lines changed

R/commonMachineLearningClassification.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,7 @@
340340
model[["jaspVars"]] <- list()
341341
model[["jaspVars"]]$decoded <- list(target = decodeColNames(options[["target"]]), predictors = decodeColNames(options[["predictors"]]))
342342
model[["jaspVars"]]$encoded = list(target = options[["target"]], predictors = options[["predictors"]])
343+
model[["jaspScaling"]] <- attr(dataset, "jaspScaling")
343344
model[["jaspVersion"]] <- .baseCitation
344345
model[["explainer"]] <- classificationResult[["explainer"]]
345346
class(model) <- c(class(classificationResult[["model"]]), "jaspClassification", "jaspMachineLearning")

R/commonMachineLearningRegression.R

Lines changed: 6 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070

7171
# Scale numeric predictors
7272
if (length(unlist(options[["predictors"]])) > 0 && options[["scaleVariables"]]) {
73+
attr(dataset, which = "jaspScaling") <- .getJaspScaling(dataset[, options[["predictors"]], drop = FALSE])
7374
dataset[, options[["predictors"]]] <- .scaleNumericData(dataset[, options[["predictors"]], drop = FALSE])
7475
}
7576
return(dataset)
@@ -501,8 +502,7 @@
501502
model[["jaspVars"]] <- list()
502503
model[["jaspVars"]]$decoded <- list(target = decodeColNames(options[["target"]]), predictors = decodeColNames(options[["predictors"]]))
503504
model[["jaspVars"]]$encoded = list(target = options[["target"]], predictors = options[["predictors"]])
504-
model[["jaspScaling"]] <- .getJaspScaling(dataset[, options[["predictors"]], drop = FALSE])
505-
print(model[["jaspScaling"]])
505+
model[["jaspScaling"]] <- attr(dataset, "jaspScaling")
506506
model[["jaspVersion"]] <- .baseCitation
507507
model[["explainer"]] <- regressionResult[["explainer"]]
508508
class(model) <- c(class(regressionResult[["model"]]), "jaspRegression", "jaspMachineLearning")
@@ -687,32 +687,10 @@
687687

688688
.getJaspScaling <- function(x) {
689689
idx <- sapply(x, function(x) is.numeric(x) && length(unique(x)) > 1)
690-
colNames <- list(encoded = colnames(x), decoded = decodeColNames(colnames(x)))
691-
cols_to_scale <- colNames[["decoded"]][idx]
692-
centers <- setNames(numeric(length(cols_to_scale)), cols_to_scale)
693-
scales <- setNames(numeric(length(cols_to_scale)), cols_to_scale)
694-
for (col in cols_to_scale) {
695-
encodedColName <- colNames[["encoded"]][which(colNames[["decoded"]] == col)]
696-
centers[col] <- mean(x[[encodedColName]])
697-
scales[col] <- sd(x[[encodedColName]])
698-
}
699-
return(list(centers, scaling))
700-
}
701-
702-
.setJaspScaling <- function(x, center, scale) {
703-
if (nrow(x) == 0) {
704-
return(x)
705-
}
706-
idx <- sapply(x, function(x) is.numeric(x) && length(unique(x)) > 1)
707-
colNames <- list(encoded = colnames(x), decoded = decodeColNames(colnames(x)))
708-
cols_to_scale <- colNames[["decoded"]][idx]
709-
for (col in cols_to_scale) {
710-
encodedColName <- colNames[["encoded"]][which(colNames[["decoded"]] == col)]
711-
x[, encodedColName] <- scale(x[, encodedColName, drop = FALSE], center = center[col], scale = scale[col])
712-
}
713-
attr(x, which = "scaled:center") <- NULL
714-
attr(x, which = "scaled:scale") <- NULL
715-
return(x)
690+
cols_to_scale <- colnames(x)[idx]
691+
centers <- sapply(x[cols_to_scale], mean)
692+
scales <- sapply(x[cols_to_scale], sd)
693+
return(list(centers, scales))
716694
}
717695

718696
# these could also extend the S3 method scale although that could be somewhat unexpected

R/mlPrediction.R

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -269,30 +269,38 @@ is.jaspMachineLearning <- function(x) {
269269
return(names)
270270
}
271271

272+
# Apply a previously recorded centering and scaling to new data.
#
# Arguments:
#   x       data frame holding the new observations.
#   centers named numeric vector; the names identify the columns to scale and
#           the values are subtracted from those columns.
#   scales  named numeric vector with the same names as `centers`; the values
#           are the divisors applied after centering.
#
# Returns `x` in which every column named in `centers` that is present in `x`
# has been replaced by (column - center) / scale. All other columns, and an
# `x` without rows, are returned untouched.
.setJaspScaling <- function(x, centers, scales) {
  if (nrow(x) == 0) {
    return(x)
  }
  # Restrict the loop to columns that actually exist in the new data: indexing
  # a column that is absent would abort with "undefined columns selected".
  colsToScale <- intersect(names(centers), colnames(x))
  for (col in colsToScale) {
    # `[[` extracts the bare scalar so no element name leaks into the result.
    x[, col] <- (x[, col] - centers[[col]]) / scales[[col]]
  }
  return(x)
}
281+
272282
.mlPredictionReadData <- function(dataset, options, model) {
273283
if (length(options[["predictors"]]) == 0) {
274284
dataset <- NULL
275285
} else {
276286
dataset <- jaspBase::excludeNaListwise(dataset, options[["predictors"]])
277287
# Select only the predictors in the model to prevent accidental double column names
278288
dataset <- dataset[, which(decodeColNames(colnames(dataset)) %in% model[["jaspVars"]][["decoded"]]$predictors), drop = FALSE]
279-
# Ensure the column names in the dataset match those in the training data
280-
colnames(dataset) <- .matchDecodedNames(colnames(dataset), model)
281-
# Scale the features with the same scaling as the origingal dataset
282-
if (options[["scaleVariables"]] && length(unlist(options[["predictors"]])) > 0) {
283-
if (is.null(model[["jaspScaling"]])) {
284-
dataset <- .scaleNumericData(dataset)
285-
} else {
289+
if (NCOL(dataset) > 0) {
290+
# Ensure the column names in the dataset match those in the training data
291+
colnames(dataset) <- .matchDecodedNames(colnames(dataset), model)
292+
# Scale the features with the same scaling as the original dataset
293+
if (!is.null(model[["jaspScaling"]])) {
286294
dataset <- .setJaspScaling(dataset, model$jaspScaling[[1]], model$jaspScaling[[2]])
287295
}
296+
# Retrieve the training set
297+
trainingSet <- model[["explainer"]]$data
298+
# Check for factor levels in the test set that are not in the training set
299+
.checkForNewFactorLevelsInPredictionSet(trainingSet, dataset, "prediction", model)
300+
# Ensure factor variables in dataset have same levels as those in the training data
301+
factorColumns <- colnames(dataset)[sapply(dataset, is.factor)]
302+
dataset[factorColumns] <- lapply(factorColumns, function(i) factor(dataset[[i]], levels = levels(trainingSet[[i]])))
288303
}
289-
# Retrieve the training set
290-
trainingSet <- model[["explainer"]]$data
291-
# Check for factor levels in the test set that are not in the training set
292-
.checkForNewFactorLevelsInPredictionSet(trainingSet, dataset, "prediction", model)
293-
# Ensure factor variables in dataset have same levels as those in the training data
294-
factorColumns <- colnames(dataset)[sapply(dataset, is.factor)]
295-
dataset[factorColumns] <- lapply(factorColumns, function(i) factor(dataset[[i]], levels = levels(trainingSet[[i]])))
296304
}
297305
return(dataset)
298306
}
@@ -304,7 +312,7 @@ is.jaspMachineLearning <- function(x) {
304312
if (ready) {
305313
dataset <- dataset[which(colnames(dataset) %in% model[["jaspVars"]][["encoded"]]$predictors)]
306314
jaspResults[["predictions"]] <- createJaspState(.mlPredictionGetPredictions(model, dataset))
307-
jaspResults[["predictions"]]$dependOn(options = c("loadPath", "predictors", "scaleVariables"))
315+
jaspResults[["predictions"]]$dependOn(options = c("loadPath", "predictors"))
308316
return(jaspResults[["predictions"]]$object)
309317
} else {
310318
return(NULL)
@@ -331,6 +339,11 @@ is.jaspMachineLearning <- function(x) {
331339
if (is.null(model)) {
332340
return()
333341
}
342+
if (is.null(model[["jaspScaling"]])) {
343+
table$addFootnote(gettext("The features in the new data are unscaled."))
344+
} else {
345+
table$addFootnote(gettext("The features in the new data are scaled."))
346+
}
334347
modelVars_encoded <- model[["jaspVars"]][["encoded"]]$predictors
335348
modelVars_decoded <- model[["jaspVars"]][["decoded"]]$predictors
336349
presentVars_encoded <- colnames(dataset)
@@ -394,7 +407,7 @@ is.jaspMachineLearning <- function(x) {
394407
return()
395408
}
396409
table <- createJaspTable(gettext("Predictions for New Data"))
397-
table$dependOn(options = c("predictors", "trainedModelFilePath", "predictionsTable", "predictionsTableFeatures", "scaleVariables", "fromIndex", "toIndex"))
410+
table$dependOn(options = c("predictors", "trainedModelFilePath", "predictionsTable", "predictionsTableFeatures", "fromIndex", "toIndex"))
398411
table$position <- position
399412
table$addColumnInfo(name = "row", title = gettext("Case"), type = "integer")
400413
if (!is.null(model) && inherits(model, "jaspClassification")) {
@@ -438,7 +451,7 @@ is.jaspMachineLearning <- function(x) {
438451
predictionsColumn <- rep(NA, max(as.numeric(rownames(dataset))))
439452
predictionsColumn[as.numeric(rownames(dataset))] <- predictions[[1]]
440453
jaspResults[["predictionsColumn"]] <- createJaspColumn(columnName = options[["predictionsColumn"]])
441-
jaspResults[["predictionsColumn"]]$dependOn(options = c("predictionsColumn", "predictors", "trainedModelFilePath", "scaleVariables", "addPredictions"))
454+
jaspResults[["predictionsColumn"]]$dependOn(options = c("predictionsColumn", "predictors", "trainedModelFilePath", "addPredictions"))
442455
if (inherits(model, "jaspClassification")) jaspResults[["predictionsColumn"]]$setNominal(predictionsColumn)
443456
if (inherits(model, "jaspRegression")) jaspResults[["predictionsColumn"]]$setScale(predictionsColumn)
444457
}
@@ -451,7 +464,7 @@ is.jaspMachineLearning <- function(x) {
451464
break
452465
}
453466
jaspResults[[colName]] <- createJaspColumn(columnName = colName)
454-
jaspResults[[colName]]$dependOn(options = c("predictionsColumn", "predictors", "trainedModelFilePath", "scaleVariables", "addPredictions", "addProbabilities"))
467+
jaspResults[[colName]]$dependOn(options = c("predictionsColumn", "predictors", "trainedModelFilePath", "addPredictions", "addProbabilities"))
455468
jaspResults[[colName]]$setScale(predictions[[2]][, i])
456469
}
457470
}

inst/qml/mlPrediction.qml

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,6 @@ Form
5757
}
5858
}
5959

60-
Group
61-
{
62-
title: qsTr("Algorithmic Settings")
63-
64-
UI.ScaleVariables { }
65-
}
66-
6760
Group
6861
{
6962
title: qsTr("Tables")

0 commit comments

Comments
 (0)