|
8 | 8 |
|
9 | 9 | from fair_mango.dataset.dataset import Dataset |
10 | 10 | from fair_mango.typing import ( |
| 11 | + BaseMetricResult, |
| 12 | + CombinedPerformanceResult, |
11 | 13 | DisparityResultDict, |
12 | | - FairnessSummaryDifferenceResult, |
13 | 14 | FairnessSummaryDifferenceFairResult, |
14 | | - FairnessSummaryRatioResult, |
| 15 | + FairnessSummaryDifferenceResult, |
15 | 16 | FairnessSummaryRatioFairResult, |
16 | | - BaseMetricResult, |
17 | | - CombinedPerformanceResult, |
| 17 | + FairnessSummaryRatioResult, |
18 | 18 | RankResult, |
19 | | - SensitiveGroupTupleT, |
20 | 19 | SensitiveGroupOptionalT, |
| 20 | + SensitiveGroupTupleT, |
21 | 21 | ) |
22 | 22 |
|
23 | 23 | LabelT = TypeVar("LabelT", bound=str) |
@@ -270,11 +270,26 @@ def _to_float(result) -> float: |
270 | 270 | return float(data) |
271 | 271 |
|
272 | 272 | a, b = _to_float(rec_i), _to_float(rec_j) |
| 273 | + |
273 | 274 | if method == "difference": |
| 275 | + # # Ensure a is the larger value for consistent disparity calculation |
| 276 | + # # that way, the difference will always be >= 0 |
| 277 | + if b > a: |
| 278 | + a, b = b, a |
| 279 | + grp_i, grp_j = grp_j, grp_i |
| 280 | + |
274 | 281 | disp = a - b |
275 | 282 | else: |
276 | | - if b == 0: |
277 | | - disp = 1.0 if a == 0 else float("inf") |
| 283 | + # Ensure b is the larger value for consistent disparity calculation |
| 284 | + # that way, the ratio will always be <= 1 |
| 285 | + if a > b: |
| 286 | + a, b = b, a |
| 287 | + grp_i, grp_j = grp_j, grp_i |
| 288 | + |
| 289 | + if a == b: |
| 290 | + disp = 1.0 |
| 291 | + elif b == 0: |
| 292 | + disp = float("inf") |
278 | 293 | else: |
279 | 294 | disp = a / b |
280 | 295 |
|
@@ -394,14 +409,16 @@ def summary( |
394 | 409 | unprivileged_sensitive_group: SensitiveGroupOptionalT = None |
395 | 410 |
|
396 | 411 | for disparity_result in self.results: |
397 | | - abs_disparity = abs(disparity_result["disparity"]) |
| 412 | + disparity = disparity_result["disparity"] |
| 413 | + abs_disparity = abs(disparity) |
398 | 414 |
|
399 | 415 | if abs_disparity > max_disparity: |
400 | 416 | max_disparity = abs_disparity |
401 | | - if disparity_result["disparity"] > 0: |
| 417 | + if disparity > 0: |
402 | 418 | privileged_sensitive_group = disparity_result["group_1"] |
403 | 419 | unprivileged_sensitive_group = disparity_result["group_2"] |
404 | 420 | else: |
| 421 | + # in case disparity is negative, we swap the two groups to make it positive again |
405 | 422 | privileged_sensitive_group = disparity_result["group_2"] |
406 | 423 | unprivileged_sensitive_group = disparity_result["group_1"] |
407 | 424 |
|
|
0 commit comments