Skip to content

Commit 74c3f8b

Browse files
authored
[WIP] Unify tests (#214)
* Converge test fixtures Signed-off-by: gaugup <gaugup@microsoft.com> * Unify tests Signed-off-by: gaugup <gaugup@microsoft.com> * Unify more tests Signed-off-by: gaugup <gaugup@microsoft.com> * Fix lint Signed-off-by: gaugup <gaugup@microsoft.com> * Migrate more tests Signed-off-by: gaugup <gaugup@microsoft.com> * Migrate few more tests to common tests Signed-off-by: gaugup <gaugup@microsoft.com> * Unify more tests Signed-off-by: gaugup <gaugup@microsoft.com>
1 parent ea97c91 commit 74c3f8b

File tree

5 files changed

+288
-254
lines changed

5 files changed

+288
-254
lines changed

tests/conftest.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,34 @@ def regression_exp_object(method="random"):
5252
return exp
5353

5454

55+
@pytest.fixture(scope='session')
56+
def custom_public_data_interface():
57+
dataset = helpers.load_custom_testing_dataset_regression()
58+
d = dice_ml.Data(dataframe=dataset, continuous_features=['Numerical'], outcome_name='Outcome')
59+
return d
60+
61+
62+
@pytest.fixture(scope='session')
63+
def sklearn_binary_classification_model_interface():
64+
ML_modelpath = helpers.get_custom_dataset_modelpath_pipeline_binary()
65+
m = dice_ml.Model(model_path=ML_modelpath, backend='sklearn', model_type='classifier')
66+
return m
67+
68+
69+
@pytest.fixture(scope='session')
70+
def sklearn_multiclass_classification_model_interface():
71+
ML_modelpath = helpers.get_custom_dataset_modelpath_pipeline_multiclass()
72+
m = dice_ml.Model(model_path=ML_modelpath, backend='sklearn', model_type='classifier')
73+
return m
74+
75+
76+
@pytest.fixture(scope='session')
77+
def sklearn_regression_model_interface():
78+
ML_modelpath = helpers.get_custom_dataset_modelpath_pipeline_regression()
79+
m = dice_ml.Model(model_path=ML_modelpath, backend='sklearn', model_type='regression')
80+
return m
81+
82+
5583
@pytest.fixture
5684
def public_data_object():
5785
"""

tests/test_dice_interface/test_dice_KD.py

Lines changed: 0 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import numpy as np
33
import dice_ml
44
from dice_ml.utils import helpers
5-
from dice_ml.utils.exception import UserConfigValidationException
65
from dice_ml.diverse_counterfactuals import CounterfactualExamples
76
from dice_ml.counterfactual_explanations import CounterfactualExplanations
87

@@ -46,17 +45,6 @@ def _initiate_exp_object(self, KD_binary_classification_exp_object):
4645
self.exp = KD_binary_classification_exp_object # explainer object
4746
self.data_df_copy = self.exp.data_interface.data_df.copy()
4847

49-
# When no elements in the desired_class are present in the training data
50-
@pytest.mark.parametrize("desired_class, total_CFs", [(1, 3), ('a', 3)])
51-
def test_unsupported_binary_class(self, desired_class, sample_custom_query_1, total_CFs):
52-
with pytest.raises(UserConfigValidationException) as ucve:
53-
self.exp._generate_counterfactuals(query_instance=sample_custom_query_1, total_CFs=total_CFs,
54-
desired_class=desired_class)
55-
if desired_class == 1:
56-
assert "Desired class not present in training data!" in str(ucve)
57-
else:
58-
assert "The target class for {0} could not be identified".format(desired_class) in str(ucve)
59-
6048
# When a query's feature value is not within the permitted range and the feature is not allowed to vary
6149
@pytest.mark.parametrize("desired_range, desired_class, total_CFs, features_to_vary, permitted_range",
6250
[(None, 0, 4, ['Numerical'], {'Categorical': ['b', 'c']})])
@@ -119,20 +107,6 @@ def test_permitted_range_categorical(self, desired_class, sample_custom_query_2,
119107
total_CFs=total_CFs, permitted_range=permitted_range)
120108
assert all(i in permitted_range["Categorical"] for i in self.exp.final_cfs_df.Categorical.values)
121109

122-
# Testing if an error is thrown when the query instance has an unknown categorical variable
123-
@pytest.mark.parametrize("desired_class, total_CFs", [(0, 1)])
124-
def test_query_instance_outside_bounds(self, desired_class, sample_custom_query_3, total_CFs):
125-
with pytest.raises(ValueError):
126-
self.exp._generate_counterfactuals(query_instance=sample_custom_query_3, total_CFs=total_CFs,
127-
desired_class=desired_class)
128-
129-
# Testing if an error is thrown when the query instance has an unknown column
130-
@pytest.mark.parametrize("desired_class, total_CFs", [(0, 1)])
131-
def test_query_instance_unknown_column(self, desired_class, sample_custom_query_5, total_CFs):
132-
with pytest.raises(ValueError):
133-
self.exp._generate_counterfactuals(query_instance=sample_custom_query_5, total_CFs=total_CFs,
134-
desired_class=desired_class)
135-
136110
# Ensuring that there are no duplicates in the resulting counterfactuals even if the dataset has duplicates
137111
@pytest.mark.parametrize("desired_class, total_CFs", [(0, 2)])
138112
def test_duplicates(self, desired_class, sample_custom_query_4, total_CFs):
@@ -147,12 +121,6 @@ def test_duplicates(self, desired_class, sample_custom_query_4, total_CFs):
147121

148122
assert all(self.exp.final_cfs_df == expected_output)
149123

150-
# Testing for 0 CFs needed
151-
@pytest.mark.parametrize("desired_class, total_CFs", [(0, 0)])
152-
def test_zero_cfs(self, desired_class, sample_custom_query_4, total_CFs):
153-
self.exp._generate_counterfactuals(query_instance=sample_custom_query_4, total_CFs=total_CFs,
154-
desired_class=desired_class)
155-
156124
# Testing for index returned
157125
@pytest.mark.parametrize("desired_class, total_CFs", [(0, 1)])
158126
@pytest.mark.parametrize('posthoc_sparsity_algorithm', ['linear', 'binary', None])
@@ -179,33 +147,6 @@ def test_KD_tree_output(self, desired_class, sample_custom_query_2, total_CFs,
179147
posthoc_sparsity_algorithm=posthoc_sparsity_algorithm)
180148
assert all(i == desired_class for i in self.exp_multi.cfs_preds)
181149

182-
# Testing that the output of multiclass classification lies in the desired_class
183-
@pytest.mark.parametrize("desired_class, total_CFs", [(2, 3)])
184-
def test_KD_tree_counterfactual_explanations_output(self, desired_class, sample_custom_query_2, total_CFs):
185-
counterfactual_explanations = self.exp_multi.generate_counterfactuals(
186-
query_instances=sample_custom_query_2, total_CFs=total_CFs,
187-
desired_class=desired_class)
188-
assert all(i == desired_class for i in self.exp_multi.cfs_preds)
189-
190-
assert counterfactual_explanations is not None
191-
192-
# Testing for 0 CFs needed
193-
@pytest.mark.parametrize("desired_class, total_CFs", [(0, 0)])
194-
def test_zero_cfs(self, desired_class, sample_custom_query_4, total_CFs):
195-
self.exp_multi._generate_counterfactuals(query_instance=sample_custom_query_4, total_CFs=total_CFs,
196-
desired_class=desired_class)
197-
198-
# When no elements in the desired_class are present in the training data
199-
@pytest.mark.parametrize("desired_class, total_CFs", [(100, 3), ('opposite', 3)])
200-
def test_unsupported_multiclass(self, desired_class, sample_custom_query_4, total_CFs):
201-
with pytest.raises(UserConfigValidationException) as ucve:
202-
self.exp_multi._generate_counterfactuals(query_instance=sample_custom_query_4, total_CFs=total_CFs,
203-
desired_class=desired_class)
204-
if desired_class == 100:
205-
assert "Desired class not present in training data!" in str(ucve)
206-
else:
207-
assert "Desired class cannot be opposite if the number of classes is more than 2." in str(ucve)
208-
209150

210151
class TestDiceKDRegressionMethods:
211152
@pytest.fixture(autouse=True)

tests/test_dice_interface/test_dice_genetic.py

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -56,18 +56,6 @@ def test_invalid_query_instance(self, sample_custom_query_1, features_to_vary, p
5656
with pytest.raises(ValueError):
5757
self.exp.setup(features_to_vary, permitted_range, sample_custom_query_1, feature_weights)
5858

59-
# # Testing that the counterfactuals are in the desired class
60-
@pytest.mark.parametrize("desired_class, total_CFs, features_to_vary, initialization",
61-
[(1, 2, "all", "kdtree"), (1, 2, "all", "random")])
62-
def test_desired_class(self, desired_class, sample_custom_query_2, total_CFs, features_to_vary, initialization):
63-
ans = self.exp.generate_counterfactuals(query_instances=sample_custom_query_2,
64-
features_to_vary=features_to_vary,
65-
total_CFs=total_CFs, desired_class=desired_class,
66-
initialization=initialization)
67-
for cfs_example in ans.cf_examples_list:
68-
assert all(
69-
cfs_example.final_cfs_df[self.exp.data_interface.outcome_name].values == [desired_class] * total_CFs)
70-
7159
# Testing that the features_to_vary argument actually varies only the features that you wish to vary
7260
@pytest.mark.parametrize("desired_class, total_CFs, features_to_vary, initialization",
7361
[(1, 2, ["Numerical"], "kdtree"), (1, 2, ["Numerical"], "random")])
@@ -121,18 +109,6 @@ def test_permitted_range_categorical(self, desired_class, total_CFs, features_to
121109
permitted_range[feature][1] for i
122110
in range(total_CFs))
123111

124-
# Testing if an error is thrown when the query instance has an unknown categorical variable
125-
@pytest.mark.parametrize("desired_class, total_CFs, features_to_vary", [(0, 1, "all")])
126-
def test_query_instance_outside_bounds(self, desired_class, sample_custom_query_3, total_CFs, features_to_vary):
127-
with pytest.raises(ValueError):
128-
self.exp.setup(features_to_vary, None, sample_custom_query_3, "inverse_mad")
129-
130-
# Testing if an error is thrown when the query instance has an unknown categorical variable
131-
@pytest.mark.parametrize("features_to_vary", [("all")])
132-
def test_query_instance_unknown_column(self, sample_custom_query_5, features_to_vary):
133-
with pytest.raises(ValueError):
134-
self.exp.setup(features_to_vary, None, sample_custom_query_5, "inverse_mad")
135-
136112
# Testing if an error is thrown when the query instance has outcome variable
137113
def test_query_instance_with_target_column(self, sample_custom_query_6):
138114
with pytest.raises(ValueError) as ve:
@@ -167,16 +143,6 @@ class TestDiceGeneticMultiClassificationMethods:
167143
def _initiate_exp_object(self, genetic_multi_classification_exp_object):
168144
self.exp = genetic_multi_classification_exp_object # explainer object
169145

170-
# Testing that the counterfactuals are in the desired class
171-
@pytest.mark.parametrize("desired_class, total_CFs, initialization", [(2, 2, "kdtree"), (2, 2, "random")])
172-
def test_desired_class(self, desired_class, sample_custom_query_2, total_CFs, initialization):
173-
ans = self.exp.generate_counterfactuals(query_instances=sample_custom_query_2,
174-
total_CFs=total_CFs, desired_class=desired_class,
175-
initialization=initialization)
176-
for cfs_example in ans.cf_examples_list:
177-
assert all(
178-
cfs_example.final_cfs_df[self.exp.data_interface.outcome_name].values == [desired_class] * total_CFs)
179-
180146
# Testing if only valid cfs are found after maxiterations
181147
@pytest.mark.parametrize("desired_class, total_CFs, initialization, maxiterations",
182148
[(2, 7, "kdtree", 0), (2, 7, "random", 0)])

0 commit comments

Comments
 (0)