Skip to content

Commit 404bc87

Browse files
committed
explainers
1 parent 7fbe88c commit 404bc87

File tree

2 files changed

+13
-12
lines changed

2 files changed

+13
-12
lines changed

R/commonMachineLearningClassification.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -324,9 +324,8 @@
324324
table$title <- gettext("Model Summary: Multinomial Regression Classification")
325325
}
326326
family <- classificationResult[["family"]]
327-
family <- paste0(toupper(substr(family, 1, 1)), substr(family, 2, nchar(family)))
328327
row <- data.frame(
329-
family = family,
328+
family = paste0(toupper(substr(family, 1, 1)), substr(family, 2, nchar(family))),
330329
nTrain = nTrain,
331330
nTest = classificationResult[["ntest"]],
332331
testAcc = classificationResult[["testAcc"]]

R/mlClassificationLogisticMultinomial.R

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,11 @@ mlClassificationLogisticMultinomial <- function(jaspResults, dataset, options, .
4545
# Create the validation measures table
4646
.mlClassificationTableMetrics(dataset, options, jaspResults, ready, position = 5)
4747

48-
# # Create the variable importance table
49-
# .mlTableFeatureImportance(options, jaspResults, ready, position = 6, purpose = "classification")
48+
# Create the variable importance table
49+
.mlTableFeatureImportance(options, jaspResults, ready, position = 6, purpose = "classification")
5050

51-
# # Create the shap table
52-
# .mlTableShap(dataset, options, jaspResults, ready, position = 7, purpose = "classification")
51+
# Create the shap table
52+
.mlTableShap(dataset, options, jaspResults, ready, position = 7, purpose = "classification")
5353

5454
.mlClassificationLogisticTableCoef(options, jaspResults, ready, position = 8)
5555

@@ -124,12 +124,14 @@ mlClassificationLogisticMultinomial <- function(jaspResults, dataset, options, .
124124
result[["test"]] <- testSet
125125
result[["testIndicatorColumn"]] <- testIndicatorColumn
126126
result[["classes"]] <- dataPredictions
127-
# result[["explainer"]] <- DALEX::explain(result[["model"]], type = "classification", data = result[["train"]], y = result[["train"]][, options[["target"]]], predict_function = function(model, data) predict(model, newdata = data, type = "raw"))
128-
# if (nlevels(result[["testReal"]]) == 2) {
129-
# result[["explainer_fi"]] <- DALEX::explain(result[["model"]], type = "classification", data = result[["train"]], y = as.numeric(result[["train"]][, options[["target"]]]) - 1, predict_function = function(model, data) predict(model, newdata = data, type = "class"))
130-
# } else {
131-
# result[["explainer_fi"]] <- DALEX::explain(result[["model"]], type = "multiclass", data = result[["train"]], y = result[["train"]][, options[["target"]]] , predict_function = function(model, data) predict(model, newdata = data, type = "raw"))
132-
# }
127+
if (family == "binomial") {
128+
result[["explainer"]] <- DALEX::explain(result[["model"]], type = "classification", data = result[["train"]], y = result[["train"]][, options[["target"]]], predict_function = function(model, data) data.frame(1 - predict(model, newdata = data, type = "response"), predict(model, newdata = data, type = "response")))
129+
result[["explainer_fi"]] <- DALEX::explain(result[["model"]], type = "classification", data = result[["train"]], y = as.numeric(result[["train"]][, options[["target"]]]) - 1, predict_function = function(model, data) round(predict(model, newdata = data, type = "response"), 0) + 1)
130+
} else {
131+
# TODO
132+
result[["explainer"]] <- DALEX::explain(result[["model"]][["original"]], type = "multiclass", data = result[["train"]], y = result[["train"]][, options[["target"]]], predict_function = function(model, data) VGAM::predict(model, data, type = "response"))
133+
result[["explainer_fi"]] <- result[["explainer"]]
134+
}
133135
return(result)
134136
}
135137

0 commit comments

Comments
 (0)