Skip to content

Commit e5a3058

Browse files
committed
Start scaling
1 parent 1e61829 commit e5a3058

File tree

2 files changed

+40
-3
lines changed

2 files changed

+40
-3
lines changed

R/commonMachineLearningRegression.R

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,8 @@
501501
model[["jaspVars"]] <- list()
502502
model[["jaspVars"]]$decoded <- list(target = decodeColNames(options[["target"]]), predictors = decodeColNames(options[["predictors"]]))
503503
model[["jaspVars"]]$encoded = list(target = options[["target"]], predictors = options[["predictors"]])
504+
model[["jaspScaling"]] <- .getJaspScaling(dataset[, options[["predictors"]], drop = FALSE])
505+
print(model[["jaspScaling"]])
504506
model[["jaspVersion"]] <- .baseCitation
505507
model[["explainer"]] <- regressionResult[["explainer"]]
506508
class(model) <- c(class(regressionResult[["model"]]), "jaspRegression", "jaspMachineLearning")
@@ -683,6 +685,36 @@
683685
}
684686
}
685687

688+
.getJaspScaling <- function(x) {
689+
idx <- sapply(x, function(x) is.numeric(x) && length(unique(x)) > 1)
690+
colNames <- list(encoded = colnames(x), decoded = decodeColNames(colnames(x)))
691+
cols_to_scale <- colNames[["decoded"]][idx]
692+
centers <- setNames(numeric(length(cols_to_scale)), cols_to_scale)
693+
scales <- setNames(numeric(length(cols_to_scale)), cols_to_scale)
694+
for (col in cols_to_scale) {
695+
encodedColName <- colNames[["encoded"]][which(colNames[["decoded"]] == col)]
696+
centers[col] <- mean(x[[encodedColName]])
697+
scales[col] <- sd(x[[encodedColName]])
698+
}
699+
return(list(centers, scaling))
700+
}
701+
702+
.setJaspScaling <- function(x, center, scale) {
703+
if (nrow(x) == 0) {
704+
return(x)
705+
}
706+
idx <- sapply(x, function(x) is.numeric(x) && length(unique(x)) > 1)
707+
colNames <- list(encoded = colnames(x), decoded = decodeColNames(colnames(x)))
708+
cols_to_scale <- colNames[["decoded"]][idx]
709+
for (col in cols_to_scale) {
710+
encodedColName <- colNames[["encoded"]][which(colNames[["decoded"]] == col)]
711+
x[, encodedColName] <- scale(x[, encodedColName, drop = FALSE], center = center[col], scale = scale[col])
712+
}
713+
attr(x, which = "scaled:center") <- NULL
714+
attr(x, which = "scaled:scale") <- NULL
715+
return(x)
716+
}
717+
686718
# these could also extend the S3 method scale although that could be somewhat unexpected
687719
.scaleNumericData <- function(x, ...) {
688720
UseMethod(".scaleNumericData", x)

R/mlPrediction.R

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -274,13 +274,18 @@ is.jaspMachineLearning <- function(x) {
274274
dataset <- NULL
275275
} else {
276276
dataset <- jaspBase::excludeNaListwise(dataset, options[["predictors"]])
277-
if (options[["scaleVariables"]] && length(unlist(options[["predictors"]])) > 0) {
278-
dataset <- .scaleNumericData(dataset)
279-
}
280277
# Select only the predictors in the model to prevent accidental double column names
281278
dataset <- dataset[, which(decodeColNames(colnames(dataset)) %in% model[["jaspVars"]][["decoded"]]$predictors), drop = FALSE]
282279
# Ensure the column names in the dataset match those in the training data
283280
colnames(dataset) <- .matchDecodedNames(colnames(dataset), model)
281+
# Scale the features with the same scaling as the origingal dataset
282+
if (options[["scaleVariables"]] && length(unlist(options[["predictors"]])) > 0) {
283+
if (is.null(model[["jaspScaling"]])) {
284+
dataset <- .scaleNumericData(dataset)
285+
} else {
286+
dataset <- .setJaspScaling(dataset, model$jaspScaling[[1]], model$jaspScaling[[2]])
287+
}
288+
}
284289
# Retrieve the training set
285290
trainingSet <- model[["explainer"]]$data
286291
# Check for factor levels in the test set that are not in the training set

0 commit comments

Comments
 (0)