Skip to content

Commit ac13544

Browse files
committed
ENH: make results more consistent by always putting the priledged group first to always have difference >= 0 and ratio <= 1
1 parent a24ca91 commit ac13544

File tree

1 file changed

+26
-9
lines changed

1 file changed

+26
-9
lines changed

fair_mango/metrics/base.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,16 @@
88

99
from fair_mango.dataset.dataset import Dataset
1010
from fair_mango.typing import (
11+
BaseMetricResult,
12+
CombinedPerformanceResult,
1113
DisparityResultDict,
12-
FairnessSummaryDifferenceResult,
1314
FairnessSummaryDifferenceFairResult,
14-
FairnessSummaryRatioResult,
15+
FairnessSummaryDifferenceResult,
1516
FairnessSummaryRatioFairResult,
16-
BaseMetricResult,
17-
CombinedPerformanceResult,
17+
FairnessSummaryRatioResult,
1818
RankResult,
19-
SensitiveGroupTupleT,
2019
SensitiveGroupOptionalT,
20+
SensitiveGroupTupleT,
2121
)
2222

2323
LabelT = TypeVar("LabelT", bound=str)
@@ -270,11 +270,26 @@ def _to_float(result) -> float:
270270
return float(data)
271271

272272
a, b = _to_float(rec_i), _to_float(rec_j)
273+
273274
if method == "difference":
275+
# # Ensure a is the larger value for consistent disparity calculation
276+
# # that way, the difference will always be >= 0
277+
if b > a:
278+
a, b = b, a
279+
grp_i, grp_j = grp_j, grp_i
280+
274281
disp = a - b
275282
else:
276-
if b == 0:
277-
disp = 1.0 if a == 0 else float("inf")
283+
# Ensure b is the larger value for consistent disparity calculation
284+
# that way, the ratio will always be <= 1
285+
if a > b:
286+
a, b = b, a
287+
grp_i, grp_j = grp_j, grp_i
288+
289+
if a == b:
290+
disp = 1.0
291+
elif b == 0:
292+
disp = float("inf")
278293
else:
279294
disp = a / b
280295

@@ -394,14 +409,16 @@ def summary(
394409
unprivileged_sensitive_group: SensitiveGroupOptionalT = None
395410

396411
for disparity_result in self.results:
397-
abs_disparity = abs(disparity_result["disparity"])
412+
disparity = disparity_result["disparity"]
413+
abs_disparity = abs(disparity)
398414

399415
if abs_disparity > max_disparity:
400416
max_disparity = abs_disparity
401-
if disparity_result["disparity"] > 0:
417+
if disparity > 0:
402418
privileged_sensitive_group = disparity_result["group_1"]
403419
unprivileged_sensitive_group = disparity_result["group_2"]
404420
else:
421+
# in case disparity is negative, we swap the two groups to make it positive again
405422
privileged_sensitive_group = disparity_result["group_2"]
406423
unprivileged_sensitive_group = disparity_result["group_1"]
407424

0 commit comments

Comments
 (0)