|
501 | 501 | model[["jaspVars"]] <- list()
|
502 | 502 | model[["jaspVars"]]$decoded <- list(target = decodeColNames(options[["target"]]), predictors = decodeColNames(options[["predictors"]]))
|
503 | 503 | model[["jaspVars"]]$encoded = list(target = options[["target"]], predictors = options[["predictors"]])
|
| 504 | + model[["jaspScaling"]] <- .getJaspScaling(dataset[, options[["predictors"]], drop = FALSE]) |
| 505 | + print(model[["jaspScaling"]]) |
504 | 506 | model[["jaspVersion"]] <- .baseCitation
|
505 | 507 | model[["explainer"]] <- regressionResult[["explainer"]]
|
506 | 508 | class(model) <- c(class(regressionResult[["model"]]), "jaspRegression", "jaspMachineLearning")
|
|
683 | 685 | }
|
684 | 686 | }
|
685 | 687 |
|
| 688 | +.getJaspScaling <- function(x) { |
| 689 | + idx <- sapply(x, function(x) is.numeric(x) && length(unique(x)) > 1) |
| 690 | + colNames <- list(encoded = colnames(x), decoded = decodeColNames(colnames(x))) |
| 691 | + cols_to_scale <- colNames[["decoded"]][idx] |
| 692 | + centers <- setNames(numeric(length(cols_to_scale)), cols_to_scale) |
| 693 | + scales <- setNames(numeric(length(cols_to_scale)), cols_to_scale) |
| 694 | + for (col in cols_to_scale) { |
| 695 | + encodedColName <- colNames[["encoded"]][which(colNames[["decoded"]] == col)] |
| 696 | + centers[col] <- mean(x[[encodedColName]]) |
| 697 | + scales[col] <- sd(x[[encodedColName]]) |
| 698 | + } |
| 699 | + return(list(centers, scaling)) |
| 700 | +} |
| 701 | + |
| 702 | +.setJaspScaling <- function(x, center, scale) { |
| 703 | + if (nrow(x) == 0) { |
| 704 | + return(x) |
| 705 | + } |
| 706 | + idx <- sapply(x, function(x) is.numeric(x) && length(unique(x)) > 1) |
| 707 | + colNames <- list(encoded = colnames(x), decoded = decodeColNames(colnames(x))) |
| 708 | + cols_to_scale <- colNames[["decoded"]][idx] |
| 709 | + for (col in cols_to_scale) { |
| 710 | + encodedColName <- colNames[["encoded"]][which(colNames[["decoded"]] == col)] |
| 711 | + x[, encodedColName] <- scale(x[, encodedColName, drop = FALSE], center = center[col], scale = scale[col]) |
| 712 | + } |
| 713 | + attr(x, which = "scaled:center") <- NULL |
| 714 | + attr(x, which = "scaled:scale") <- NULL |
| 715 | + return(x) |
| 716 | +} |
| 717 | + |
686 | 718 | # these could also extend the S3 method scale although that could be somewhat unexpected
|
687 | 719 | .scaleNumericData <- function(x, ...) {
|
688 | 720 | UseMethod(".scaleNumericData", x)
|
|
0 commit comments