@@ -269,30 +269,38 @@ is.jaspMachineLearning <- function(x) {
269
269
return (names )
270
270
}
271
271
272
# Apply a previously stored centering and scaling to new data.
#
# x        Data frame (or matrix) holding the new observations.
# centers  Named numeric vector; each name identifies a column to rescale and
#          each value is subtracted from that column.
# scales   Named numeric vector with the same names as `centers`; each column
#          is divided by its value after centering.
#
# Returns `x` with every column that is named in `centers` AND present in `x`
# replaced by (value - center) / scale. All other columns are left untouched.
# An input with zero rows is returned unchanged.
.setJaspScaling <- function(x, centers, scales) {
  if (nrow(x) == 0) {
    return(x)
  }
  # Only touch columns that actually exist in the new data: indexing a column
  # that is absent would abort with "undefined columns selected".
  columnsToScale <- intersect(names(centers), colnames(x))
  for (col in columnsToScale) {
    # `[[` extracts the bare scalar, so no names leak into the column values
    x[, col] <- (x[, col] - centers[[col]]) / scales[[col]]
  }
  return(x)
}
281
+
272
282
# Prepare the new data set on which a trained model will make predictions.
#
# dataset  Data frame with the new observations.
# options  Analysis options; only options[["predictors"]] is read here.
# model    Trained model object; this function reads its "jaspVars",
#          "jaspScaling" and "explainer" entries.
#
# Returns NULL when no predictors are selected. Otherwise returns the data set
# restricted to the model's predictors, rescaled with the model's stored
# scaling (if any), and with factor levels aligned to the training data.
.mlPredictionReadData <- function(dataset, options, model) {
  predictors <- options[["predictors"]]
  if (length(predictors) == 0) {
    return(NULL)
  }
  dataset <- jaspBase::excludeNaListwise(dataset, predictors)
  # Select only the predictors in the model to prevent accidental double column names
  decodedNames <- decodeColNames(colnames(dataset))
  inModel <- decodedNames %in% model[["jaspVars"]][["decoded"]]$predictors
  dataset <- dataset[, which(inModel), drop = FALSE]
  if (NCOL(dataset) == 0) {
    # None of the supplied columns is a model predictor; nothing left to prepare
    return(dataset)
  }
  # Ensure the column names in the dataset match those in the training data
  colnames(dataset) <- .matchDecodedNames(colnames(dataset), model)
  # Scale the features with the same scaling as the original dataset
  scaling <- model[["jaspScaling"]]
  if (!is.null(scaling)) {
    dataset <- .setJaspScaling(dataset, scaling[[1]], scaling[[2]])
  }
  # Retrieve the training set
  trainingData <- model[["explainer"]]$data
  # Check for factor levels in the test set that are not in the training set
  .checkForNewFactorLevelsInPredictionSet(trainingData, dataset, "prediction", model)
  # Ensure factor variables in dataset have same levels as those in the training data
  factorNames <- colnames(dataset)[vapply(dataset, is.factor, logical(1))]
  alignLevels <- function(name) {
    factor(dataset[[name]], levels = levels(trainingData[[name]]))
  }
  dataset[factorNames] <- lapply(factorNames, alignLevels)
  return(dataset)
}
@@ -304,7 +312,7 @@ is.jaspMachineLearning <- function(x) {
304
312
if (ready ) {
305
313
dataset <- dataset [which(colnames(dataset ) %in% model [[" jaspVars" ]][[" encoded" ]]$ predictors )]
306
314
jaspResults [[" predictions" ]] <- createJaspState(.mlPredictionGetPredictions(model , dataset ))
307
- jaspResults [[" predictions" ]]$ dependOn(options = c(" loadPath" , " predictors" , " scaleVariables " ))
315
+ jaspResults [[" predictions" ]]$ dependOn(options = c(" loadPath" , " predictors" ))
308
316
return (jaspResults [[" predictions" ]]$ object )
309
317
} else {
310
318
return (NULL )
@@ -331,6 +339,11 @@ is.jaspMachineLearning <- function(x) {
331
339
if (is.null(model )) {
332
340
return ()
333
341
}
342
+ if (is.null(model [[" jaspScaling" ]])) {
343
+ table $ addFootnote(gettext(" The features in the new data are unscaled." ))
344
+ } else {
345
+ table $ addFootnote(gettext(" The features in the new data are scaled." ))
346
+ }
334
347
modelVars_encoded <- model [[" jaspVars" ]][[" encoded" ]]$ predictors
335
348
modelVars_decoded <- model [[" jaspVars" ]][[" decoded" ]]$ predictors
336
349
presentVars_encoded <- colnames(dataset )
@@ -394,7 +407,7 @@ is.jaspMachineLearning <- function(x) {
394
407
return ()
395
408
}
396
409
table <- createJaspTable(gettext(" Predictions for New Data" ))
397
- table $ dependOn(options = c(" predictors" , " trainedModelFilePath" , " predictionsTable" , " predictionsTableFeatures" , " scaleVariables " , " fromIndex" , " toIndex" ))
410
+ table $ dependOn(options = c(" predictors" , " trainedModelFilePath" , " predictionsTable" , " predictionsTableFeatures" , " fromIndex" , " toIndex" ))
398
411
table $ position <- position
399
412
table $ addColumnInfo(name = " row" , title = gettext(" Case" ), type = " integer" )
400
413
if (! is.null(model ) && inherits(model , " jaspClassification" )) {
@@ -438,7 +451,7 @@ is.jaspMachineLearning <- function(x) {
438
451
predictionsColumn <- rep(NA , max(as.numeric(rownames(dataset ))))
439
452
predictionsColumn [as.numeric(rownames(dataset ))] <- predictions [[1 ]]
440
453
jaspResults [[" predictionsColumn" ]] <- createJaspColumn(columnName = options [[" predictionsColumn" ]])
441
- jaspResults [[" predictionsColumn" ]]$ dependOn(options = c(" predictionsColumn" , " predictors" , " trainedModelFilePath" , " scaleVariables " , " addPredictions" ))
454
+ jaspResults [[" predictionsColumn" ]]$ dependOn(options = c(" predictionsColumn" , " predictors" , " trainedModelFilePath" , " addPredictions" ))
442
455
if (inherits(model , " jaspClassification" )) jaspResults [[" predictionsColumn" ]]$ setNominal(predictionsColumn )
443
456
if (inherits(model , " jaspRegression" )) jaspResults [[" predictionsColumn" ]]$ setScale(predictionsColumn )
444
457
}
@@ -451,7 +464,7 @@ is.jaspMachineLearning <- function(x) {
451
464
break
452
465
}
453
466
jaspResults [[colName ]] <- createJaspColumn(columnName = colName )
454
- jaspResults [[colName ]]$ dependOn(options = c(" predictionsColumn" , " predictors" , " trainedModelFilePath" , " scaleVariables " , " addPredictions" , " addProbabilities" ))
467
+ jaspResults [[colName ]]$ dependOn(options = c(" predictionsColumn" , " predictors" , " trainedModelFilePath" , " addPredictions" , " addProbabilities" ))
455
468
jaspResults [[colName ]]$ setScale(predictions [[2 ]][, i ])
456
469
}
457
470
}
0 commit comments