
Commit b8ca66d

ishevche authored and joaopfonseca committed
Code formating
1 parent c442f69 commit b8ca66d

8 files changed: +82 -45 lines changed


experiments/0.2-basic-experiment.py

Lines changed: 1 addition & 1 deletion
@@ -279,4 +279,4 @@
 
 results[dataset["name"]][xai_method["name"]].append(contributions)
 result_df = pd.DataFrame(contributions, columns=X.columns, index=X.index)
-result_df.to_csv(result_fname)
+result_df.to_csv(result_fname)
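Note: the removed and added lines read the same once leading whitespace is stripped, so this appears to be an indentation-only change. For context, the touched lines implement a simple persistence step: the per-method contribution matrix is wrapped in a DataFrame that reuses the feature matrix's columns and index and written to CSV. A minimal sketch of that pattern, using illustrative stand-in data and a hypothetical file name:

import numpy as np
import pandas as pd

# Illustrative stand-ins for the experiment's objects (not taken from the repo).
X = pd.DataFrame(np.random.rand(5, 3), columns=["n1", "n2", "n3"])
contributions = np.random.rand(5, 3)  # one attribution value per cell of X
result_fname = "contributions_example.csv"  # hypothetical output path

# Same pattern as the diffed lines: align contributions with X's schema, then persist.
result_df = pd.DataFrame(contributions, columns=X.columns, index=X.index)
result_df.to_csv(result_fname)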

experiments/0.3-time-experiment.py

Lines changed: 28 additions & 11 deletions
@@ -156,7 +156,6 @@
     + [f"fidelity_{i}" for i in range(N_RUNS)]
 )
 
-
 for dataset in datasets:
     result_df = []
     # Set up basic settings
@@ -174,8 +173,12 @@
 ranking = scores_to_ordering(scores)
 
 # Set experiment size if we deleted too many items
-dataset["n_observations"] = dataset["n_observations"] if X.shape[0] > dataset["n_observations"] else X.shape[0]
-
+dataset["n_observations"] = (
+    dataset["n_observations"]
+    if X.shape[0] > dataset["n_observations"]
+    else X.shape[0]
+)
+
 rng = check_random_state(RNG_SEED)
 
 # rank and score indexes
@@ -184,7 +187,7 @@
     size=dataset["n_observations"],
     replace=False,
 )
-
+
 # pairwise pairs
 combos = list(itertools.combinations(np.indices((X.shape[0],)).squeeze(), 2))
 pairs_indexes = rng.choice(
@@ -193,14 +196,25 @@
     replace=False,
 )
 pairs_sample = [combos[i] for i in pairs_indexes]
-pairs = [(pair[0], pair[1]) if np.random.choice([0,1]) else (pair[1], pair[0]) for pair in pairs_sample]
+pairs = [
+    (pair[0], pair[1]) if np.random.choice([0, 1]) else (pair[1], pair[0])
+    for pair in pairs_sample
+]
 
 for approach in approaches:
     iteration_qoi = approach
     if approach.startswith("pairwise"):
         iteration_qoi = approach.split("-")[1]
         approach = "pairwise"
-    print("----------------", dataset["name"], "|", approach, "|", iteration_qoi, "----------------")
+    print(
+        "----------------",
+        dataset["name"],
+        "|",
+        approach,
+        "|",
+        iteration_qoi,
+        "----------------",
+    )
 
     times = []
     kendall_cons = []
@@ -275,7 +289,7 @@
 [
     dataset["name"],
     dataset["n_observations"],
-    approach+"_"+iteration_qoi,
+    approach + "_" + iteration_qoi,
     np.nan,
     np.nan,
     np.mean(times),
@@ -368,10 +382,13 @@
         contr, baseline_contr, measure="jaccard", n_features=2
     )[0]
 )
-#Eulidean consistency
+# Eulidean consistency
 euclidean_cons.append(
     cross_method_explanation_consistency(
-        contr, baseline_contr, measure="euclidean", normalization=True
+        contr,
+        baseline_contr,
+        measure="euclidean",
+        normalization=True,
     )[0]
 )
 # Iniatialize normalizer
@@ -396,14 +413,14 @@
     target_pairs=target[sam_idx2],
     rank=True,
 )
-
+
 fidelity.append(res_)
 
 results_row = (
     [
         dataset["name"],
         dataset["n_observations"],
-        approach+"_"+iteration_qoi,
+        approach + "_" + iteration_qoi,
         parameter,
         parameter_value,
         np.mean(times),
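The reformatted list comprehension above samples unordered index pairs and then randomly flips each pair's orientation before the pairwise experiments run. A self-contained sketch of that sampling pattern; the constants and the use of len(combos) as the sampling range are illustrative assumptions, since the surrounding lines fall outside the hunk:

import itertools

import numpy as np
from sklearn.utils import check_random_state

N_OBSERVATIONS = 20  # illustrative
N_PAIRS = 10  # illustrative
rng = check_random_state(42)

# All unordered index pairs, as in the diff (np.indices((n,)).squeeze() yields 0..n-1).
combos = list(itertools.combinations(np.indices((N_OBSERVATIONS,)).squeeze(), 2))

# Sample a subset of pairs without replacement, then randomise each pair's order.
pairs_indexes = rng.choice(len(combos), size=N_PAIRS, replace=False)
pairs_sample = [combos[i] for i in pairs_indexes]
pairs = [
    (pair[0], pair[1]) if np.random.choice([0, 1]) else (pair[1], pair[0])
    for pair in pairs_sample
]
print(pairs)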

xai_ranking/_min_dependencies.py

Lines changed: 12 additions & 7 deletions
@@ -8,40 +8,45 @@
     "pandas": ("1.3.5", "metrics, datasets"),
     "scipy": ("1.14.1", "metrics"),
     "scikit-learn": ("1.2.0", "metrics"),
-
     "pytest-cov": ("3.0.0", "tests"),
     "flake8": ("3.8.2", "tests"),
     "black": ("22.3", "tests"),
     "pylint": ("2.12.2", "tests"),
     "mypy": ("1.6.1", "tests"),
     "sphinx": ("4.2.0", "docs"),
-
     # dev
     # "coverage": ("", "tests"),
     # "click": ("", "tests"),
-
     # nutrition labels
     # "matplotlib" : ("", "install"),
     # "seaborn" : ("", "install"),
-
     # L2R
     # "lightgbm" : ("", "install"),
-
     # general?
     # "xai-sharp": ("0.1.a1", "install"),
     # "shap" : ("", "install"),
     # "lime" : ("", "install"),
     # "statsmodels" : ("", "install"),
     # "ml-research" : ("", "install"),
-
     # dataset module
     # "openpyxl" : ("", "install"),
     # "" : ("", "install"),
 }
 
 # create inverse mapping for setuptools
 tag_to_packages: dict = {
-    extra: [] for extra in ["install", "optional", "docs", "examples", "tests", "all", "metrics", "datasets", "scores"]
+    extra: []
+    for extra in [
+        "install",
+        "optional",
+        "docs",
+        "examples",
+        "tests",
+        "all",
+        "metrics",
+        "datasets",
+        "scores",
+    ]
 }
 for package, (min_version, extras) in dependent_packages.items():
     for extra in extras.split(", "):
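The reformatting here splits the dict comprehension that seeds the inverse mapping: dependent_packages maps each package to a minimum version plus a comma-separated list of extras, and tag_to_packages inverts that into extra → list of requirements for setuptools. A reduced sketch of how the inversion behaves; the dict is trimmed and the requirement string format is an assumption, since the loop body sits outside the visible hunk:

# Trimmed-down version of the mapping shown in the diff.
dependent_packages = {
    "pandas": ("1.3.5", "metrics, datasets"),
    "scipy": ("1.14.1", "metrics"),
    "black": ("22.3", "tests"),
}

# One bucket per extra, filled from each package's extras list.
tag_to_packages: dict = {
    extra: [] for extra in ["install", "docs", "tests", "metrics", "datasets"]
}
for package, (min_version, extras) in dependent_packages.items():
    for extra in extras.split(", "):
        tag_to_packages[extra].append(f"{package}>={min_version}")  # assumed format

print(tag_to_packages["metrics"])  # ['pandas>=1.3.5', 'scipy>=1.14.1']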

xai_ranking/datasets/_make_synthetic.py

Lines changed: 11 additions & 7 deletions
@@ -8,9 +8,13 @@
 def fetch_synthetic_data(synth_dt_version=2, item_num=1000):
     # Feature names
     column_names = ["n1", "n2", "n3"]
-
+
     # Check if files exist, if not we will make them
-    filepath = join(dirname(abspath(__file__)), "files", f"Synthetic_{synth_dt_version}_{item_num}.txt")
+    filepath = join(
+        dirname(abspath(__file__)),
+        "files",
+        f"Synthetic_{synth_dt_version}_{item_num}.txt",
+    )
 
     if Path(filepath).is_file():
         df = pd.read_csv(
@@ -22,7 +26,7 @@ def fetch_synthetic_data(synth_dt_version=2, item_num=1000):
     else:
         # Make index names
        ind = range(0, item_num)
-
+
         # Make features based on synthetic data version passed
         if synth_dt_version == 0:
             # All features are independent
@@ -38,7 +42,7 @@ def fetch_synthetic_data(synth_dt_version=2, item_num=1000):
             corr = -0.8
             cov1_2 = math.sqrt(var[0]) * math.sqrt(var[1]) * corr
             covs = [[var[0], cov1_2, 0], [cov1_2, var[1], 0], [0, 0, var[2]]]
-            features = np.random.multivariate_normal(means, covs, item_num)
+            features = np.random.multivariate_normal(means, covs, item_num)
         elif synth_dt_version == 2:
             # Features 1 & 2 are negatively correlated
             # Feature 1 & 3 are positively correlated
@@ -57,15 +61,15 @@ def fetch_synthetic_data(synth_dt_version=2, item_num=1000):
             features = np.random.multivariate_normal(means, covs, item_num)
         else:
             return None
-
+
         # Make dataframe
         df = pd.DataFrame(features, columns=column_names, index=ind)
-
+
         # Normalize data
         for series_name, series in df.items():
             df[series_name] = (series - series.min()) / (series.max() - series.min())
 
         # Write to file
         df.to_csv(filepath, index=False, header=False)
 
-    return df
+    return df
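The re-wrapped calls above sit inside the generator for correlated synthetic features: a pairwise covariance is derived from the two standard deviations and a target correlation, assembled into a covariance matrix, sampled with np.random.multivariate_normal, and the resulting columns are min-max normalised. A minimal sketch of that recipe; the means and variances are illustrative, since the repo's actual values fall outside the visible hunks:

import math

import numpy as np
import pandas as pd

item_num = 1000
means = [0.0, 0.0, 0.0]  # illustrative
var = [1.0, 1.0, 1.0]  # illustrative
corr = -0.8  # features 1 & 2 negatively correlated, as in the diff

# cov(X1, X2) = sd1 * sd2 * corr; the third feature stays independent.
cov1_2 = math.sqrt(var[0]) * math.sqrt(var[1]) * corr
covs = [[var[0], cov1_2, 0], [cov1_2, var[1], 0], [0, 0, var[2]]]

features = np.random.multivariate_normal(means, covs, item_num)
df = pd.DataFrame(features, columns=["n1", "n2", "n3"])

# Min-max normalisation, column by column, mirroring the function's tail.
for series_name, series in df.items():
    df[series_name] = (series - series.min()) / (series.max() - series.min())

print(df.corr().round(2))  # the (n1, n2) entry should be close to -0.8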

xai_ranking/metrics/_base.py

Lines changed: 26 additions & 17 deletions
@@ -5,6 +5,7 @@
 from sharp.utils import scores_to_ordering
 import pandas as pd
 
+
 # Not reviewed
 # Returns neighbors that are either close or far ranking wise
 # AND subselects the top n neighbors in terms of feature similarity
@@ -23,7 +24,7 @@ def _find_neighbors(
         & (rankings <= max_ranking)
         & (rankings != row_rank)
     )
-    else: # Select neighbors that are far ranking wise
+    else:  # Select neighbors that are far ranking wise
         mask = (rankings < min_ranking) | (rankings > max_ranking)
     data_neighbors = np.array(original_data)[mask]
     cont_neighbors = np.array(contributions)[mask]
@@ -41,7 +42,9 @@
 # Not reviewed
 # Returns all neighbors that are similar feature wise
 # The Euclidean distance between items has to be under the threshold
-def _find_all_neighbors(original_data, rankings, contributions, row_idx, threshold=None):
+def _find_all_neighbors(
+    original_data, rankings, contributions, row_idx, threshold=None
+):
     row_data = np.array(original_data)[row_idx]
 
     data_neighbors = np.array(original_data)
@@ -66,11 +69,11 @@ def _find_all_neighbors(original_data, rankings, contributions, row_idx, thresho
     )
     # Or return distances from all items
     return (
-        data_neighbors,
-        cont_neighbors,
-        rank_neighbors,
-        distances,
-    )
+        data_neighbors,
+        cont_neighbors,
+        rank_neighbors,
+        distances,
+    )
 
 
 # Reviewed
@@ -79,10 +82,12 @@ def _get_importance_mask(row_cont, threshold):
         # Calculate order of absolute contributions
         row_abs = np.abs(row_cont)
         # Find n=threshold largest items
-        res = sorted(row_abs.index.values, key = lambda sub: row_abs[sub])[-threshold:]
+        res = sorted(row_abs.index.values, key=lambda sub: row_abs[sub])[-threshold:]
         # Set mask
-        mask = pd.Series(data=[True if i in res else False for i in row_cont.index.values],
-            index=row_cont.index.values)
+        mask = pd.Series(
+            data=[True if i in res else False for i in row_cont.index.values],
+            index=row_cont.index.values,
+        )
     else:
         # Calculate cumulative absolute contribution order
         total_contribution = np.sum(np.abs(row_cont))
@@ -128,10 +133,12 @@ def kendall_similarity(a, b):
     idx_pair = list(combinations(range(len(a)), 2))
     val_pair_a = [(a[i], a[j]) for i, j in idx_pair if a[i] != a[j]]
     val_pair_b = [(b[i], b[j]) for i, j in idx_pair if b[i] != b[j]]
-    inversions=0
+    inversions = 0
     for (val11, val12), (val21, val22) in zip(val_pair_a, val_pair_b):
-        if ((val11 > val12) and (val21 < val22)) or ((val11 < val12) and (val21 > val22)):
-            inversions = inversions+1
+        if ((val11 > val12) and (val21 < val22)) or (
+            (val11 < val12) and (val21 > val22)
+        ):
+            inversions = inversions + 1
     kt = 1 - (2 * inversions) / normalizer
     return (kt + 1) / 2
 
@@ -223,7 +230,7 @@ def row_wise_jaccard(results1, results2, n_features):
     >>> n_features = 2
     >>> row_wise_jaccard(results1, results2, n_features)
     """
-
+
     if n_features is None:
         n_features = results1.shape[1]
 
@@ -246,9 +253,9 @@ def row_wise_euclidean(results1, results2, normalization=True):
         # Make vectors into unit vectors
         v1 = normalize([results1])[0]
         v2 = normalize([results2])[0]
-        return euclidean(v1,v2)/2
+        return euclidean(v1, v2) / 2
     else:
-        return euclidean(results1,results2)
+        return euclidean(results1, results2)
 
 
 # Reviewed
@@ -279,7 +286,8 @@ def euclidean_agreement(results1, results2, normalization):
     vectors in `results1` and `results2` using the Euclidean distance.
     """
     return results1.reset_index(drop=True).apply(
-        lambda row: 1 - row_wise_euclidean(row, results2.iloc[row.name], normalization), axis=1
+        lambda row: 1 - row_wise_euclidean(row, results2.iloc[row.name], normalization),
+        axis=1,
    )
 
 
@@ -315,6 +323,7 @@ def kendall_agreement(results1, results2):
         lambda row: row_wise_kendall(row, results2.iloc[row.name]), axis=1
     )
 
+
 # Reviewed
 def jaccard_agreement(results1, results2, n_features=0.8):
     """

xai_ranking/metrics/_consistency.py

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ def bootstrapped_explanation_consistency(
     sem = np.std(batch_agreement) / np.sqrt(batch_agreement.size)
     return mean, sem
 
+
 # Reviewed
 def cross_method_explanation_consistency(
     results1, results2, measure="kendall", **kwargs
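The visible tail of bootstrapped_explanation_consistency summarises a vector of per-batch agreement scores by its mean and standard error of the mean. A tiny sketch of that summary step on made-up agreement values:

import numpy as np

batch_agreement = np.array([0.81, 0.78, 0.84, 0.80, 0.79])  # illustrative scores

mean = np.mean(batch_agreement)
# Standard error of the mean, exactly as in the diffed line.
sem = np.std(batch_agreement) / np.sqrt(batch_agreement.size)
print(mean, sem)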

xai_ranking/metrics/_fidelity.py

Lines changed: 2 additions & 1 deletion
@@ -1,5 +1,6 @@
 import numpy as np
 
+
 # Reviewed
 def outcome_fidelity(
     contributions, target, avg_target, target_max=1, target_pairs=None, rank=True
@@ -21,7 +22,7 @@ def outcome_fidelity(
             better_than = target < target_pairs
         else:
             better_than = target > target_pairs
-
+
         est_better_than = contributions.sum(axis=1) > 0
         avg_est_err = (better_than == est_better_than).mean()
         return avg_est_err
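The second hunk touches the pairwise branch of outcome_fidelity: each item is compared with a paired item, and the explanation counts as faithful when the sign of its summed contributions agrees with whether the item actually beats its pair. The branch using `<` presumably handles rank targets, where lower is better; the sketch below follows the score-based `>` branch with illustrative arrays:

import numpy as np

target = np.array([0.9, 0.4, 0.7])  # scores of the explained items (illustrative)
target_pairs = np.array([0.5, 0.6, 0.2])  # scores of the paired items (illustrative)
contributions = np.array(
    [[0.3, 0.1], [-0.2, -0.1], [0.4, -0.1]]  # per-feature contributions per item
)

# Same logic as the diffed lines: does the aggregated explanation point the same
# way as the actual pairwise outcome?
better_than = target > target_pairs
est_better_than = contributions.sum(axis=1) > 0
avg_est_err = (better_than == est_better_than).mean()
print(avg_est_err)  # fraction of pairs where explanation and outcome agree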

xai_ranking/metrics/_sensitivity.py

Lines changed: 1 addition & 1 deletion
@@ -190,7 +190,7 @@ def row_wise_explanation_sensitivity_all_neighbors(
 
 
 # Calculates the explanation sensitivity of every row of original data and its
-# closest neighbors,
+# closest neighbors,
 def explanation_sensitivity(
     original_data,
     contributions,
