Skip to content

Commit aa50ddd

Browse files
authored
New features (#43)
* move arguments * add filtering out metrics * update documentation * refactor code with tidyverse style * documentation change * add consistency in plot and doc * Add R versions and change linux version * add instruction for adding new metric * refactor code * Add readme info * remove mac 3.5 * update version * change link
1 parent a25b22d commit aa50ddd

File tree

119 files changed

+3735
-3306
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

119 files changed

+3735
-3306
lines changed

.Rbuildignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,6 @@
66
^codecov\.yml$
77
^\.github$
88
^LICENSE$
9+
^doc$
10+
^Meta$
11+
^CRAN-RELEASE$

.github/workflows/R-CMD-check.yaml

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,18 @@ jobs:
3232
matrix:
3333
config:
3434
- {os: windows-latest, r: 'devel'}
35+
- {os: windows-latest, r: '4.1'}
3536
- {os: windows-latest, r: '4.0'}
3637
- {os: windows-latest, r: '3.6'}
38+
- {os: windows-latest, r: '3.5'}
39+
- {os: macOS-latest, r: '4.1'}
3740
- {os: macOS-latest, r: '4.0'}
3841
- {os: macOS-latest, r: '3.6'}
39-
- {os: ubuntu-16.04, r: '4.0', rspm: "https://demo.rstudiopm.com/all/__linux__/xenial/latest"}
40-
- {os: ubuntu-16.04, r: '3.6', rspm: "https://demo.rstudiopm.com/all/__linux__/xenial/latest"}
42+
- {os: ubuntu-18.04, r: '4.1', vdiffr: true, xref: true, rspm: "https://packagemanager.rstudio.com/cran/__linux__/bionic/latest"}
43+
- {os: ubuntu-18.04, r: '4.0', vdiffr: true, xref: true, rspm: "https://packagemanager.rstudio.com/cran/__linux__/bionic/latest"}
44+
- {os: ubuntu-18.04, r: '3.6', vdiffr: true, xref: true, rspm: "https://packagemanager.rstudio.com/cran/__linux__/bionic/latest"}
45+
- {os: ubuntu-18.04, r: '3.5', vdiffr: true, xref: true, rspm: "https://packagemanager.rstudio.com/cran/__linux__/bionic/latest"}
46+
4147

4248
env:
4349
R_REMOTES_NO_ERRORS_FROM_WARNINGS: true

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,6 @@
33
.RData
44
.Ruserdata
55
inst/doc
6+
CRAN-RELEASE
7+
/doc/
8+
/Meta/

DESCRIPTION

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
Package: fairmodels
22
Type: Package
33
Title: Flexible Tool for Bias Detection, Visualization, and Mitigation
4-
Version: 1.1.1
4+
Version: 1.2.0
55
Authors@R:
66
c(person("Jakub", "Wiśniewski", role = c("aut", "cre"),
77
email = "[email protected]"),
@@ -15,18 +15,19 @@ Depends: R (>= 3.5)
1515
Imports:
1616
DALEX,
1717
ggplot2,
18+
scales,
19+
stats,
1820
patchwork,
19-
ggdendro,
20-
ggrepel,
21-
scales
2221
Suggests:
2322
ranger,
2423
gbm,
2524
knitr,
2625
rmarkdown,
2726
covr,
2827
testthat,
29-
spelling
28+
spelling,
29+
ggdendro,
30+
ggrepel,
3031
RoxygenNote: 7.1.1.9001
3132
VignetteBuilder: knitr
3233
URL: https://fairmodels.drwhy.ai/

NAMESPACE

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -55,17 +55,5 @@ export(stack_metrics)
5555
import(ggplot2)
5656
import(patchwork)
5757
importFrom(DALEX,model_performance)
58-
importFrom(DALEX,theme_drwhy)
59-
importFrom(DALEX,theme_drwhy_vertical)
60-
importFrom(ggdendro,dendro_data)
61-
importFrom(ggdendro,segment)
62-
importFrom(ggrepel,geom_text_repel)
63-
importFrom(stats,binomial)
64-
importFrom(stats,dist)
65-
importFrom(stats,ecdf)
66-
importFrom(stats,glm)
67-
importFrom(stats,hclust)
68-
importFrom(stats,median)
6958
importFrom(stats,na.omit)
70-
importFrom(stats,quantile)
7159
importFrom(utils,head)

NEWS.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,12 @@
1+
# fairmodels 1.2.0
2+
* Added filtering of metrics when plotting and printing a `fairness_object`.
3+
* Added the ability to add a custom measure function to the print method of `fairness_object`.
4+
* Refactored code with tidyverse style
5+
* Changed the order of metrics in `metric_scores` plot to match the ones in `fairness_check`
6+
* Added instructions for creating a custom metric to the README
7+
* Added references to vignettes
8+
* Enhanced the advanced vignette
9+
110
# fairmodels 1.1.1
211
* Fixed error which appeared when 2 fairness objects had the same labels in them. Now if this appears it throws an error. [(#41)](https://github.com/ModelOriented/fairmodels/issues/41)
312
* `privileged` parameter is now converted to character. [(#41)](https://github.com/ModelOriented/fairmodels/issues/41)

R/all_cutoffs.R

Lines changed: 49 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -7,98 +7,106 @@
77
#' @param grid_points numeric, grid for cutoffs to test. Number of points between 0 and 1 spread evenly
88
#' @param fairness_metrics character, name of parity_loss metric or vector of multiple metrics names. Full names can be found in \code{fairness_check} documentation.
99
#'
10+
#' @import ggplot2
11+
#'
1012
#' @return \code{all_cutoffs} object, \code{data.frame} containing information about label, metric and parity_loss at particular cutoff
1113
#' @export
1214
#'
1315
#' @examples
1416
#' data("german")
1517
#'
16-
#' y_numeric <- as.numeric(german$Risk) -1
18+
#' y_numeric <- as.numeric(german$Risk) - 1
1719
#'
18-
#' lm_model <- glm(Risk~.,
19-
#' data = german,
20-
#' family=binomial(link="logit"))
20+
#' lm_model <- glm(Risk ~ .,
21+
#' data = german,
22+
#' family = binomial(link = "logit")
23+
#' )
2124
#'
22-
#' explainer_lm <- DALEX::explain(lm_model, data = german[,-1], y = y_numeric)
25+
#' explainer_lm <- DALEX::explain(lm_model, data = german[, -1], y = y_numeric)
2326
#'
2427
#' fobject <- fairness_check(explainer_lm,
25-
#' protected = german$Sex,
26-
#' privileged = "male")
28+
#' protected = german$Sex,
29+
#' privileged = "male"
30+
#' )
2731
#'
2832
#' ac <- all_cutoffs(fobject)
2933
#' plot(ac)
30-
#'
3134
#' \donttest{
32-
#' rf_model <- ranger::ranger(Risk ~.,
33-
#' data = german,
34-
#' probability = TRUE,
35-
#' num.trees = 100,
36-
#' seed = 1)
35+
#' rf_model <- ranger::ranger(Risk ~ .,
36+
#' data = german,
37+
#' probability = TRUE,
38+
#' num.trees = 100,
39+
#' seed = 1
40+
#' )
3741
#'
3842
#'
3943
#' explainer_rf <- DALEX::explain(rf_model,
40-
#' data = german[,-1],
41-
#' y = y_numeric)
44+
#' data = german[, -1],
45+
#' y = y_numeric
46+
#' )
4247
#'
4348
#' fobject <- fairness_check(explainer_rf, fobject)
4449
#'
4550
#' ac <- all_cutoffs(fobject)
4651
#'
4752
#' plot(ac)
4853
#' }
49-
54+
#'
5055
all_cutoffs <- function(x,
5156
grid_points = 101,
52-
fairness_metrics = c('ACC', 'TPR', 'PPV', 'FPR', 'STP')){
53-
57+
fairness_metrics = c("ACC", "TPR", "PPV", "FPR", "STP")) {
5458
stopifnot(class(x) == "fairness_object")
5559

5660
# error if not in metrics
5761
lapply(fairness_metrics, assert_parity_metrics)
5862

59-
if (! is.numeric(grid_points) | length(grid_points) > 1) stop("grid points must be single numeric value")
63+
if (!is.numeric(grid_points) |
64+
length(grid_points) > 1) {
65+
stop("grid points must be single numeric value")
66+
}
6067

6168

6269

6370
explainers <- x$explainers
64-
n_exp <- length(explainers)
65-
cutoffs <- seq(0,1, length.out = grid_points)
66-
protected <- x$protected
71+
n_exp <- length(explainers)
72+
cutoffs <- seq(0, 1, length.out = grid_points)
73+
protected <- x$protected
6774
privileged <- x$privileged
6875

6976
n_subgroups <- length(levels(protected))
7077
cutoff_data <- data.frame()
7178

7279
# custom cutoffs will give messages (0 in matrices, NA in metrics) numerous times,
7380
# so for code below they will be suppressed
74-
parity_loss_metric_data <- matrix(nrow = n_exp, ncol = 12)
75-
76-
suppressMessages(
77-
for (i in seq_along(explainers)){
78-
for (custom_cutoff in cutoffs){
81+
parity_loss_metric_data <- matrix(nrow = n_exp, ncol = 12)
7982

80-
custom_cutoff_vec <- as.list(rep(custom_cutoff, n_subgroups))
83+
suppressMessages(for (i in seq_along(explainers)) {
84+
for (custom_cutoff in cutoffs) {
85+
custom_cutoff_vec <- as.list(rep(custom_cutoff, n_subgroups))
8186
names(custom_cutoff_vec) <- levels(protected)
82-
explainer <- explainers[[i]]
87+
explainer <- explainers[[i]]
8388

8489

85-
group_matrices <- group_matrices(protected = protected,
86-
probs = explainer$y_hat,
87-
preds = explainer$y,
88-
cutoff = custom_cutoff_vec)
90+
group_matrices <- group_matrices(
91+
protected = protected,
92+
probs = explainer$y_hat,
93+
preds = explainer$y,
94+
cutoff = custom_cutoff_vec
95+
)
8996

9097
# like in create fobject
91-
gmm <- calculate_group_fairness_metrics(group_matrices)
92-
parity_loss <- calculate_parity_loss(gmm, privileged)
98+
gmm <- calculate_group_fairness_metrics(group_matrices)
99+
parity_loss <- calculate_parity_loss(gmm, privileged)
93100
parity_loss <- parity_loss[names(parity_loss) %in% fairness_metrics]
94101

95-
to_add <- data.frame(parity_loss = as.numeric(parity_loss),
96-
metric = names(parity_loss),
97-
cutoff = rep(custom_cutoff, length(parity_loss)),
98-
label = x$label[i])
99-
100-
cutoff_data <- rbind(cutoff_data , to_add)
102+
to_add <- data.frame(
103+
parity_loss = as.numeric(parity_loss),
104+
metric = names(parity_loss),
105+
cutoff = rep(custom_cutoff, length(parity_loss)),
106+
label = x$label[i]
107+
)
101108

109+
cutoff_data <- rbind(cutoff_data, to_add)
102110
}
103111
})
104112

@@ -107,18 +115,3 @@ all_cutoffs <- function(x,
107115

108116
return(all_cutoffs)
109117
}
110-
111-
112-
113-
114-
115-
116-
117-
118-
119-
120-
121-
122-
123-
124-

R/calculate_group_fairness_metrics.R

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,50 +4,50 @@
44
#'
55
#' @param x object of class \code{group_matrices}
66
#'
7+
#'
78
#' @return \code{group_metric_matrix} object
89
#' It's a \code{data.frame} with metrics as row names and scores for those metrics for each subgroup in columns
910
#' @export
1011
#'
1112

12-
calculate_group_fairness_metrics <- function(x){
13-
stopifnot( "group_matrices" %in% class(x) )
13+
calculate_group_fairness_metrics <- function(x) {
14+
stopifnot("group_matrices" %in% class(x))
1415

15-
group_metric_matrix <- matrix(0, nrow = 12 , ncol = length(x))
16+
group_metric_matrix <- matrix(0, nrow = 13, ncol = length(x))
1617
colnames(group_metric_matrix) <- names(x)
17-
rownames(group_metric_matrix) <- c("TPR","TNR","PPV","NPV","FNR","FPR","FDR","FOR","TS","STP","ACC","F1")
18+
rownames(group_metric_matrix) <- c("TPR", "TNR", "PPV", "NPV", "FNR", "FPR", "FDR", "FOR", "TS", "STP", "ACC", "F1", "NEW_METRIC")
1819

19-
for (i in seq_along(x)){
20+
for (i in seq_along(x)) {
2021
subgroup_cm <- x[[i]]
2122

2223
tp <- subgroup_cm$tp
2324
tn <- subgroup_cm$tn
2425
fp <- subgroup_cm$fp
2526
fn <- subgroup_cm$fn
2627

27-
TPR <- tp/(tp + fn)
28-
TNR <- tn/(tn + fp)
29-
PPV <- tp/(tp + fp)
30-
NPV <- tn/(tn + fn)
31-
FNR <- fn/(fn + tp)
32-
FPR <- fp/(fp + tn)
33-
FDR <- fp/(fp + tp)
34-
FOR <- fn/(fn+ tn)
35-
TS <- tp/(tp + fn + fp)
36-
28+
TPR <- tp / (tp + fn)
29+
TNR <- tn / (tn + fp)
30+
PPV <- tp / (tp + fp)
31+
NPV <- tn / (tn + fn)
32+
FNR <- fn / (fn + tp)
33+
FPR <- fp / (fp + tn)
34+
FDR <- fp / (fp + tp)
35+
FOR <- fn / (fn + tn)
36+
TS <- tp / (tp + fn + fp)
37+
NEW_METRIC <- TPR / FNR
3738
# accumulated metrics
38-
STP <- (tp + fp) /(tp + fp + tn + fn)
39+
STP <- (tp + fp) / (tp + fp + tn + fn)
3940
ACC <- (tp + tn) / (tp + tn + fn + fp)
40-
F1 <- 2 * PPV*TPR/(PPV + TPR)
41-
42-
#m <- sqrt(tp+fp)*sqrt(tp+fn)*sqrt(tn+fp)*sqrt(tn+fn)
43-
#MCC <- (tp*tn - fp * fn)/m
41+
F1 <- 2 * PPV * TPR / (PPV + TPR)
4442

45-
group_metric_matrix[,i] <- c(TPR,TNR,PPV,NPV,FNR,FPR,FDR,FOR,TS,STP,ACC,F1)
43+
# m <- sqrt(tp+fp)*sqrt(tp+fn)*sqrt(tn+fp)*sqrt(tn+fn)
44+
# MCC <- (tp*tn - fp * fn)/m
4645

46+
group_metric_matrix[, i] <- c(TPR, TNR, PPV, NPV, FNR, FPR, FDR, FOR, TS, STP, ACC, F1, NEW_METRIC)
4747
}
4848

4949
# NA instead of NaN
50-
if (sum(is.nan(group_metric_matrix))){
50+
if (sum(is.nan(group_metric_matrix))) {
5151
group_metric_matrix[is.nan(group_metric_matrix)] <- NA
5252
}
5353

0 commit comments

Comments
 (0)