Skip to content

Commit c22f8f0

Browse files
authored
Store all variables in saved model (#380)
1 parent 73000ff commit c22f8f0

File tree

3 files changed

+14
-10
lines changed

3 files changed

+14
-10
lines changed

R/commonMachineLearningClassification.R

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,9 @@
342342
return()
343343
}
344344
model <- classificationResult[["model"]]
345-
model[["jaspVars"]] <- decodeColNames(options[["predictors"]])
345+
model[["jaspVars"]] <- list()
346+
model[["jaspVars"]]$decoded <- list(target = decodeColNames(options[["target"]]), predictors = decodeColNames(options[["predictors"]]))
347+
model[["jaspVars"]]$encoded = list(target = options[["target"]], predictors = options[["predictors"]])
346348
model[["jaspVersion"]] <- .baseCitation
347349
model[["explainer"]] <- classificationResult[["explainer"]]
348350
model <- .decodeJaspMLobject(model)

R/commonMachineLearningRegression.R

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,9 @@
454454
return()
455455
}
456456
model <- regressionResult[["model"]]
457-
model[["jaspVars"]] <- decodeColNames(options[["predictors"]])
457+
model[["jaspVars"]] <- list()
458+
model[["jaspVars"]]$decoded <- list(target = decodeColNames(options[["target"]]), predictors = decodeColNames(options[["predictors"]]))
459+
model[["jaspVars"]]$encoded = list(target = options[["target"]], predictors = options[["predictors"]])
458460
model[["jaspVersion"]] <- .baseCitation
459461
model[["explainer"]] <- regressionResult[["explainer"]]
460462
model <- .decodeJaspMLobject(model)
@@ -697,7 +699,7 @@
697699
} else {
698700
purpose <- "classification"
699701
}
700-
predictors <- options[["predictors"]][which(decodeColNames(options[["predictors"]]) %in% model[["jaspVars"]])]
702+
predictors <- options[["predictors"]][which(decodeColNames(options[["predictors"]]) %in% model[["jaspVars"]][["decoded"]]$predictors)]
701703
} else {
702704
predictors <- options[["predictors"]]
703705
}

R/mlPrediction.R

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -119,21 +119,21 @@ is.jaspMachineLearning <- function(x) {
119119
}
120120
.mlPredictionGetPredictions.nn <- function(model, dataset) {
121121
if (inherits(model, "jaspClassification")) {
122-
as.character(levels(factor(model[["data"]][, 1]))[max.col(neuralnet:::predict.nn(model, newdata = dataset))])
122+
as.character(levels(factor(model[["data"]][[model[["jaspVars"]][["encoded"]]$target]]))[max.col(neuralnet:::predict.nn(model, newdata = dataset))])
123123
} else if (inherits(model, "jaspRegression")) {
124124
as.numeric(neuralnet:::predict.nn(model, newdata = dataset))
125125
}
126126
}
127127
.mlPredictionGetPredictions.rpart <- function(model, dataset) {
128128
if (inherits(model, "jaspClassification")) {
129-
as.character(levels(factor(model[["data"]][, 1]))[max.col(rpart:::predict.rpart(model, newdata = dataset))])
129+
as.character(levels(factor(model[["data"]][[model[["jaspVars"]][["encoded"]]$target]]))[max.col(rpart:::predict.rpart(model, newdata = dataset))])
130130
} else if (inherits(model, "jaspRegression")) {
131131
as.numeric(rpart:::predict.rpart(model, newdata = dataset))
132132
}
133133
}
134134
.mlPredictionGetPredictions.svm <- function(model, dataset) {
135135
if (inherits(model, "jaspClassification")) {
136-
as.character(levels(factor(model[["data"]][, 1]))[e1071:::predict.svm(model, newdata = dataset)])
136+
as.character(levels(factor(model[["data"]][[model[["jaspVars"]][["encoded"]]$target]]))[e1071:::predict.svm(model, newdata = dataset)])
137137
} else if (inherits(model, "jaspRegression")) {
138138
as.numeric(e1071:::predict.svm(model, newdata = dataset))
139139
}
@@ -142,7 +142,7 @@ is.jaspMachineLearning <- function(x) {
142142
as.character(e1071:::predict.naiveBayes(model, newdata = dataset, type = "class"))
143143
}
144144
.mlPredictionGetPredictions.glm <- function(model, dataset) {
145-
as.character(levels(as.factor(model$model[, 1]))[round(predict(model, newdata = dataset, type = "response"), 0) + 1])
145+
as.character(levels(as.factor(model$model[[model[["jaspVars"]][["encoded"]]$target]]))[round(predict(model, newdata = dataset, type = "response"), 0) + 1])
146146
}
147147
.mlPredictionGetPredictions.vglm <- function(model, dataset) {
148148
model[["original"]]@terms$terms <- model[["terms"]]
@@ -293,7 +293,7 @@ is.jaspMachineLearning <- function(x) {
293293
# also define methods for other objects
294294
.mlPredictionReady <- function(model, dataset, options) {
295295
if (!is.null(model)) {
296-
modelVars <- model[["jaspVars"]]
296+
modelVars <- model[["jaspVars"]][["decoded"]]$predictors
297297
presentVars <- decodeColNames(colnames(dataset))
298298
ready <- all(modelVars %in% presentVars)
299299
} else {
@@ -344,7 +344,7 @@ is.jaspMachineLearning <- function(x) {
344344
if (is.null(model)) {
345345
return()
346346
}
347-
modelVars <- model[["jaspVars"]]
347+
modelVars <- model[["jaspVars"]][["decoded"]]$predictors
348348
presentVars <- decodeColNames(colnames(dataset))
349349
if (!all(modelVars %in% presentVars)) {
350350
missingVars <- modelVars[which(!(modelVars %in% presentVars))]
@@ -422,7 +422,7 @@ is.jaspMachineLearning <- function(x) {
422422
selection <- predictions[indexes]
423423
cols <- list(row = indexes, pred = selection)
424424
if (options[["predictionsTableFeatures"]]) {
425-
for (i in encodeColNames(model[["jaspVars"]])) {
425+
for (i in model[["jaspVars"]][["encoded"]]$predictors) {
426426
if (.columnIsNominal(i)) {
427427
table$addColumnInfo(name = i, title = i, type = "string")
428428
var <- levels(dataset[[i]])[dataset[[i]]]

0 commit comments

Comments
 (0)