@@ -82,21 +82,22 @@ def generate_counterfactuals(self, query_instances, total_CFs,
8282 raise UserConfigValidationException (
8383 "The number of counterfactuals generated per query instance (total_CFs) should be a positive integer." )
8484 if total_CFs > 10 :
85- if posthoc_sparsity_algorithm == None :
85+ if posthoc_sparsity_algorithm is None :
8686 posthoc_sparsity_algorithm = 'binary'
87- elif total_CFs > 50 and posthoc_sparsity_algorithm == 'linear' :
87+ elif total_CFs > 50 and posthoc_sparsity_algorithm == 'linear' :
8888 import warnings
89- warnings .warn ("The number of counterfactuals (total_CFs={}) generated per query instance could take much time; "
90- "if too slow try to change the parameter 'posthoc_sparsity_algorithm' from 'linear' to "
91- "'binary' search!" .format (total_CFs ))
92- elif posthoc_sparsity_algorithm == None :
89+ warnings .warn (
90+ "The number of counterfactuals (total_CFs={}) generated per query instance could take much time; "
91+ "if too slow try to change the parameter 'posthoc_sparsity_algorithm' from 'linear' to "
92+ "'binary' search!" .format (total_CFs ))
93+ elif posthoc_sparsity_algorithm is None :
9394 posthoc_sparsity_algorithm = 'linear'
9495
9596 cf_examples_arr = []
9697 query_instances_list = []
9798 if isinstance (query_instances , pd .DataFrame ):
9899 for ix in range (query_instances .shape [0 ]):
99- query_instances_list .append (query_instances [ix :(ix + 1 )])
100+ query_instances_list .append (query_instances [ix :(ix + 1 )])
100101 elif isinstance (query_instances , Iterable ):
101102 query_instances_list = query_instances
102103
@@ -190,11 +191,14 @@ def check_query_instance_validity(self, features_to_vary, permitted_range, query
190191
191192 if feature not in features_to_vary and permitted_range is not None :
192193 if feature in permitted_range and feature in self .data_interface .continuous_feature_names :
193- if not permitted_range [feature ][0 ] <= query_instance [feature ].values [0 ] <= permitted_range [feature ][1 ]:
194- raise ValueError ("Feature:" , feature , "is outside the permitted range and isn't allowed to vary." )
194+ if not permitted_range [feature ][0 ] <= query_instance [feature ].values [0 ] <= permitted_range [feature ][
195+ 1 ]:
196+ raise ValueError ("Feature:" , feature ,
197+ "is outside the permitted range and isn't allowed to vary." )
195198 elif feature in permitted_range and feature in self .data_interface .categorical_feature_names :
196199 if query_instance [feature ].values [0 ] not in self .feature_range [feature ]:
197- raise ValueError ("Feature:" , feature , "is outside the permitted range and isn't allowed to vary." )
200+ raise ValueError ("Feature:" , feature ,
201+ "is outside the permitted range and isn't allowed to vary." )
198202
199203 def local_feature_importance (self , query_instances , cf_examples_list = None ,
200204 total_CFs = 10 ,
@@ -440,12 +444,13 @@ def do_posthoc_sparsity_enhancement(self, final_cfs_sparse, query_instance, post
440444 cfs_preds_sparse = []
441445
442446 for cf_ix in list (final_cfs_sparse .index ):
443- current_pred = self .predict_fn_for_sparsity (final_cfs_sparse .loc [[cf_ix ]][self .data_interface .feature_names ])
447+ current_pred = self .predict_fn_for_sparsity (
448+ final_cfs_sparse .loc [[cf_ix ]][self .data_interface .feature_names ])
444449 for feature in features_sorted :
445450 # current_pred = self.predict_fn_for_sparsity(final_cfs_sparse.iat[[cf_ix]][self.data_interface.feature_names])
446451 # feat_ix = self.data_interface.continuous_feature_names.index(feature)
447452 diff = query_instance [feature ].iat [0 ] - int (final_cfs_sparse .at [cf_ix , feature ])
448- if (abs (diff ) <= quantiles [feature ]):
453+ if (abs (diff ) <= quantiles [feature ]):
449454 if posthoc_sparsity_algorithm == "linear" :
450455 final_cfs_sparse = self .do_linear_search (diff , decimal_prec , query_instance , cf_ix ,
451456 feature , final_cfs_sparse , current_pred )
@@ -466,13 +471,14 @@ def do_linear_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f
466471 query_instance greedily until the prediction class changes."""
467472
468473 old_diff = diff
469- change = (10 ** - decimal_prec [feature ]) # the minimal possible change for a feature
474+ change = (10 ** - decimal_prec [feature ]) # the minimal possible change for a feature
470475 current_pred = current_pred_orig
471476 if self .model .model_type == ModelTypes .Classifier :
472- while ((abs (diff ) > 10e-4 ) and (np .sign (diff * old_diff ) > 0 ) and self .is_cf_valid (current_pred )):
477+ while ((abs (diff ) > 10e-4 ) and (np .sign (diff * old_diff ) > 0 ) and self .is_cf_valid (current_pred )):
473478 old_val = int (final_cfs_sparse .at [cf_ix , feature ])
474- final_cfs_sparse .at [cf_ix , feature ] += np .sign (diff )* change
475- current_pred = self .predict_fn_for_sparsity (final_cfs_sparse .loc [[cf_ix ]][self .data_interface .feature_names ])
479+ final_cfs_sparse .at [cf_ix , feature ] += np .sign (diff ) * change
480+ current_pred = self .predict_fn_for_sparsity (
481+ final_cfs_sparse .loc [[cf_ix ]][self .data_interface .feature_names ])
476482 old_diff = diff
477483
478484 if not self .is_cf_valid (current_pred ):
@@ -505,11 +511,12 @@ def do_binary_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f
505511 right = query_instance [feature ].iat [0 ]
506512
507513 while left <= right :
508- current_val = left + ((right - left )/ 2 )
514+ current_val = left + ((right - left ) / 2 )
509515 current_val = round (current_val , decimal_prec [feature ])
510516
511517 final_cfs_sparse .at [cf_ix , feature ] = current_val
512- current_pred = self .predict_fn_for_sparsity (final_cfs_sparse .loc [[cf_ix ]][self .data_interface .feature_names ])
518+ current_pred = self .predict_fn_for_sparsity (
519+ final_cfs_sparse .loc [[cf_ix ]][self .data_interface .feature_names ])
513520
514521 if current_val == right or current_val == left :
515522 break
@@ -524,19 +531,20 @@ def do_binary_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f
524531 right = int (final_cfs_sparse .at [cf_ix , feature ])
525532
526533 while right >= left :
527- current_val = right - ((right - left )/ 2 )
534+ current_val = right - ((right - left ) / 2 )
528535 current_val = round (current_val , decimal_prec [feature ])
529536
530537 final_cfs_sparse .at [cf_ix , feature ] = current_val
531- current_pred = self .predict_fn_for_sparsity (final_cfs_sparse .loc [[cf_ix ]][self .data_interface .feature_names ])
538+ current_pred = self .predict_fn_for_sparsity (
539+ final_cfs_sparse .loc [[cf_ix ]][self .data_interface .feature_names ])
532540
533541 if current_val == right or current_val == left :
534542 break
535543
536544 if self .is_cf_valid (current_pred ):
537- right = current_val - (10 ** - decimal_prec [feature ])
545+ right = current_val - (10 ** - decimal_prec [feature ])
538546 else :
539- left = current_val + (10 ** - decimal_prec [feature ])
547+ left = current_val + (10 ** - decimal_prec [feature ])
540548
541549 return final_cfs_sparse
542550
@@ -578,7 +586,7 @@ def infer_target_cfs_class(self, desired_class_input, original_pred, num_output_
578586 raise UserConfigValidationException ("Desired class not present in training data!" )
579587 else :
580588 raise UserConfigValidationException ("The target class for {0} could not be identified" .format (
581- desired_class_input ))
589+ desired_class_input ))
582590
583591 def infer_target_cfs_range (self , desired_range_input ):
584592 target_range = None
@@ -597,7 +605,7 @@ def decide_cf_validity(self, model_outputs):
597605 pred = model_outputs [i ]
598606 if self .model .model_type == ModelTypes .Classifier :
599607 if self .num_output_nodes == 2 : # binary
600- pred_1 = pred [self .num_output_nodes - 1 ]
608+ pred_1 = pred [self .num_output_nodes - 1 ]
601609 validity [i ] = 1 if \
602610 ((self .target_cf_class == 0 and pred_1 <= self .stopping_threshold ) or
603611 (self .target_cf_class == 1 and pred_1 >= self .stopping_threshold )) else 0
@@ -634,7 +642,7 @@ def is_cf_valid(self, model_score):
634642 (target_cf_class == 1 and pred_1 >= self .stopping_threshold )) else False
635643 return validity
636644 if self .num_output_nodes == 2 : # binary
637- pred_1 = model_score [self .num_output_nodes - 1 ]
645+ pred_1 = model_score [self .num_output_nodes - 1 ]
638646 validity = True if \
639647 ((target_cf_class == 0 and pred_1 <= self .stopping_threshold ) or
640648 (target_cf_class == 1 and pred_1 >= self .stopping_threshold )) else False
@@ -710,7 +718,8 @@ def round_to_precision(self):
710718 for ix , feature in enumerate (self .data_interface .continuous_feature_names ):
711719 self .final_cfs_df [feature ] = self .final_cfs_df [feature ].astype (float ).round (precisions [ix ])
712720 if self .final_cfs_df_sparse is not None :
713- self .final_cfs_df_sparse [feature ] = self .final_cfs_df_sparse [feature ].astype (float ).round (precisions [ix ])
721+ self .final_cfs_df_sparse [feature ] = self .final_cfs_df_sparse [feature ].astype (float ).round (
722+ precisions [ix ])
714723
715724 def _check_any_counterfactuals_computed (self , cf_examples_arr ):
716725 """Check if any counterfactuals were generated for any query point."""
0 commit comments