@@ -110,10 +110,25 @@ class DataWithOracleGivensLC(GivensLC):
110110 a predictor oracle.
111111 """
112112
113- def process (self , data , oracle , * args , target_names = None , prelabel_data = True , ** kwargs ):
113+ def process (self , data , oracle , * args , target_names = None , prelabel_data = True , feature_groups = None , ** kwargs ):
114+ """
115+ Processing input for the explanation setting:
116+
117+ :param: data - the n-by-d feature matrix
118+ :param: oracle - the oracle, can either be a SKLearn-like classifier, in which case the predict function is
119+ used, or a function that takes data and outputs a prediction (in matrix form).
120+ :param: target_names - The names of the target classes/components.
121+ :param: prelabel_data - Whether to run the oracle on the data first, before building the model
122+ :param: feature_groups - Used in some explanation configurations - A list of lists of feature indeces
123+ representing semantically meaningful feature groups
124+ """
114125
115126 if target_names is not None :
116127 self .target_names = target_names
128+
129+ if feature_groups is not None :
130+ # TODO: Validation
131+ self .feature_groups = feature_groups
117132
118133 # Parse data
119134 self .data_matrix , self .feature_names , self .feature_spec = parse_data (
@@ -127,7 +142,7 @@ def process(self, data, oracle, *args, target_names=None, prelabel_data=True, **
127142 logger .info (
128143 'Inferring that oracle is a Scikit-Learn-like classifier '
129144 'and using the "predict" method.' )
130- self .oracle = lambda x : np .eye (len (oracle .classes_ ))[oracle .predict (x ),]
145+ self .oracle = lambda x : np .eye (len (oracle .classes_ ))[oracle .predict (x ). astype ( 'intp' ) ,]
131146 else :
132147 logger .info ('Treating oracle as a function' )
133148 self .oracle = oracle
0 commit comments