22import numpy as np
33import dice_ml
44from dice_ml .utils import helpers
5- from dice_ml .utils .exception import UserConfigValidationException
65from dice_ml .diverse_counterfactuals import CounterfactualExamples
76from dice_ml .counterfactual_explanations import CounterfactualExplanations
87
@@ -46,17 +45,6 @@ def _initiate_exp_object(self, KD_binary_classification_exp_object):
4645 self .exp = KD_binary_classification_exp_object # explainer object
4746 self .data_df_copy = self .exp .data_interface .data_df .copy ()
4847
49- # When no elements in the desired_class are present in the training data
50- @pytest .mark .parametrize ("desired_class, total_CFs" , [(1 , 3 ), ('a' , 3 )])
51- def test_unsupported_binary_class (self , desired_class , sample_custom_query_1 , total_CFs ):
52- with pytest .raises (UserConfigValidationException ) as ucve :
53- self .exp ._generate_counterfactuals (query_instance = sample_custom_query_1 , total_CFs = total_CFs ,
54- desired_class = desired_class )
55- if desired_class == 1 :
56- assert "Desired class not present in training data!" in str (ucve )
57- else :
58- assert "The target class for {0} could not be identified" .format (desired_class ) in str (ucve )
59-
6048 # When a query's feature value is not within the permitted range and the feature is not allowed to vary
6149 @pytest .mark .parametrize ("desired_range, desired_class, total_CFs, features_to_vary, permitted_range" ,
6250 [(None , 0 , 4 , ['Numerical' ], {'Categorical' : ['b' , 'c' ]})])
@@ -119,20 +107,6 @@ def test_permitted_range_categorical(self, desired_class, sample_custom_query_2,
119107 total_CFs = total_CFs , permitted_range = permitted_range )
120108 assert all (i in permitted_range ["Categorical" ] for i in self .exp .final_cfs_df .Categorical .values )
121109
122- # Testing if an error is thrown when the query instance has an unknown categorical variable
123- @pytest .mark .parametrize ("desired_class, total_CFs" , [(0 , 1 )])
124- def test_query_instance_outside_bounds (self , desired_class , sample_custom_query_3 , total_CFs ):
125- with pytest .raises (ValueError ):
126- self .exp ._generate_counterfactuals (query_instance = sample_custom_query_3 , total_CFs = total_CFs ,
127- desired_class = desired_class )
128-
129- # Testing if an error is thrown when the query instance has an unknown column
130- @pytest .mark .parametrize ("desired_class, total_CFs" , [(0 , 1 )])
131- def test_query_instance_unknown_column (self , desired_class , sample_custom_query_5 , total_CFs ):
132- with pytest .raises (ValueError ):
133- self .exp ._generate_counterfactuals (query_instance = sample_custom_query_5 , total_CFs = total_CFs ,
134- desired_class = desired_class )
135-
136110 # Ensuring that there are no duplicates in the resulting counterfactuals even if the dataset has duplicates
137111 @pytest .mark .parametrize ("desired_class, total_CFs" , [(0 , 2 )])
138112 def test_duplicates (self , desired_class , sample_custom_query_4 , total_CFs ):
@@ -147,12 +121,6 @@ def test_duplicates(self, desired_class, sample_custom_query_4, total_CFs):
147121
148122 assert all (self .exp .final_cfs_df == expected_output )
149123
150- # Testing for 0 CFs needed
151- @pytest .mark .parametrize ("desired_class, total_CFs" , [(0 , 0 )])
152- def test_zero_cfs (self , desired_class , sample_custom_query_4 , total_CFs ):
153- self .exp ._generate_counterfactuals (query_instance = sample_custom_query_4 , total_CFs = total_CFs ,
154- desired_class = desired_class )
155-
156124 # Testing for index returned
157125 @pytest .mark .parametrize ("desired_class, total_CFs" , [(0 , 1 )])
158126 @pytest .mark .parametrize ('posthoc_sparsity_algorithm' , ['linear' , 'binary' , None ])
@@ -179,33 +147,6 @@ def test_KD_tree_output(self, desired_class, sample_custom_query_2, total_CFs,
179147 posthoc_sparsity_algorithm = posthoc_sparsity_algorithm )
180148 assert all (i == desired_class for i in self .exp_multi .cfs_preds )
181149
182- # Testing that the output of multiclass classification lies in the desired_class
183- @pytest .mark .parametrize ("desired_class, total_CFs" , [(2 , 3 )])
184- def test_KD_tree_counterfactual_explanations_output (self , desired_class , sample_custom_query_2 , total_CFs ):
185- counterfactual_explanations = self .exp_multi .generate_counterfactuals (
186- query_instances = sample_custom_query_2 , total_CFs = total_CFs ,
187- desired_class = desired_class )
188- assert all (i == desired_class for i in self .exp_multi .cfs_preds )
189-
190- assert counterfactual_explanations is not None
191-
192- # Testing for 0 CFs needed
193- @pytest .mark .parametrize ("desired_class, total_CFs" , [(0 , 0 )])
194- def test_zero_cfs (self , desired_class , sample_custom_query_4 , total_CFs ):
195- self .exp_multi ._generate_counterfactuals (query_instance = sample_custom_query_4 , total_CFs = total_CFs ,
196- desired_class = desired_class )
197-
198- # When no elements in the desired_class are present in the training data
199- @pytest .mark .parametrize ("desired_class, total_CFs" , [(100 , 3 ), ('opposite' , 3 )])
200- def test_unsupported_multiclass (self , desired_class , sample_custom_query_4 , total_CFs ):
201- with pytest .raises (UserConfigValidationException ) as ucve :
202- self .exp_multi ._generate_counterfactuals (query_instance = sample_custom_query_4 , total_CFs = total_CFs ,
203- desired_class = desired_class )
204- if desired_class == 100 :
205- assert "Desired class not present in training data!" in str (ucve )
206- else :
207- assert "Desired class cannot be opposite if the number of classes is more than 2." in str (ucve )
208-
209150
210151class TestDiceKDRegressionMethods :
211152 @pytest .fixture (autouse = True )
0 commit comments