Skip to content

Commit a2f7158

Browse files
Krsto ProrokovićKrsto Proroković
authored andcommitted
Lint and add __init__.py to utils
1 parent be17bb7 commit a2f7158

File tree

4 files changed

+73
-3
lines changed

4 files changed

+73
-3
lines changed

tests/test_bahc.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44

55
def test_shapes():
6-
# Checks that labels and biases have the right shapes
6+
# Checks that labels and scores have the right shapes
77
rng = np.random.RandomState(12)
88
X = rng.rand(20, 10)
99
y = rng.rand(20)
@@ -23,11 +23,40 @@ def test_labels():
2323
assert np.array_equal(np.unique(bahc.labels_), np.arange(bahc.n_clusters_))
2424

2525

26-
def test_biases():
27-
# Checks that biases are sorted in descending order
26+
# def test_cluster_sizes():
27+
# Checks that cluster sizes are at least bahc_min_cluster_size
28+
29+
30+
def test_scores():
31+
# Checks that scores are computed correctly
32+
rng = np.random.RandomState(12)
33+
X = rng.rand(20, 10)
34+
y = rng.rand(20)
35+
bahc = BiasAwareHierarchicalKMeans(bahc_max_iter=5, bahc_min_cluster_size=2)
36+
bahc.fit(X, y)
37+
# TODO: Check this!!!
38+
for i in range(bahc.n_clusters_):
39+
cluster_indices = np.arange(20)[bahc.labels_ == i]
40+
complement_indices = np.arange(20)[bahc.labels_ != i]
41+
score = np.mean(y[complement_indices]) - np.mean(y[cluster_indices])
42+
assert bahc.scores_[i] == score
43+
44+
45+
def test_scores_are_sorted():
46+
# Checks that scores are sorted in descending order
2847
rng = np.random.RandomState(12)
2948
X = rng.rand(20, 10)
3049
y = rng.rand(20)
3150
bahc = BiasAwareHierarchicalKMeans(bahc_max_iter=5, bahc_min_cluster_size=2)
3251
bahc.fit(X, y)
3352
assert np.all(bahc.scores_[:-1] >= bahc.scores_[1:])
53+
54+
55+
def test_predict():
56+
# Checks that predict returns the same labels as fit
57+
rng = np.random.RandomState(12)
58+
X = rng.rand(20, 10)
59+
y = rng.rand(20)
60+
bahc = BiasAwareHierarchicalKMeans(bahc_max_iter=5, bahc_min_cluster_size=2)
61+
bahc.fit(X, y)
62+
assert np.array_equal(bahc.predict(X), bahc.labels_)

unsupervised_bias_detection/cluster/_bahc.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ def fit(self, X, y):
9797
# We calculate the discrimination scores using formula (1) in [1]
9898
# TODO: Move y[indices0] and y[indices1] into separate variables
9999
# to avoid recomputing them
100+
# Maybe create a function to compute the score
100101
mask0 = np.ones(n_samples, dtype=bool)
101102
mask0[indices0] = False
102103
score0 = np.mean(y[mask0]) - np.mean(y[indices0])
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
"""The :mod:`unsupervised_bias_detection.utils` module implements utility functions."""
2+
3+
from ._get_column_dtypes import get_column_dtypes
4+
5+
__all__ = [
6+
"get_column_dtypes",
7+
]
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import numpy as np
2+
import pandas as pd
3+
4+
5+
def get_column_dtypes(data) -> dict:
6+
"""
7+
Return a dictionary mapping column names to abstract data types that are compatible with the processor.
8+
9+
The mapping is as follows:
10+
- float64, float32, int64, int32 -> "numerical"
11+
- bool -> "boolean"
12+
- datetime64[...] -> "datetime"
13+
- timedelta64[...] -> "timedelta"
14+
- All others (e.g., object) -> "categorical"
15+
"""
16+
def map_dtype(dtype: str) -> str:
17+
if dtype in ['float64', 'float32', 'int64', 'int32']:
18+
return "numerical"
19+
elif dtype == 'bool':
20+
return "boolean"
21+
elif 'datetime' in dtype:
22+
return "datetime"
23+
elif 'timedelta' in dtype:
24+
return "timedelta"
25+
else:
26+
return "categorical"
27+
28+
if isinstance(data, pd.DataFrame):
29+
return {col: map_dtype(str(dtype)) for col, dtype in data.dtypes.items()}
30+
elif isinstance(data, np.ndarray) and data.dtype.names is not None:
31+
return {name: map_dtype(str(data.dtype.fields[name][0])) for name in data.dtype.names}
32+
else:
33+
raise TypeError("Data must be a pandas DataFrame or a structured numpy array.")

0 commit comments

Comments
 (0)