@@ -197,6 +197,8 @@ class TextRule(Rule):
197197 consequent (list[str]): A list of consequent terms of the text rule.
198198 fitness (Optional[float]): Fitness value of the text rule.
199199 transactions (Optional[pandas.DataFrame]): The tf-idf matrix as a pandas DataFrame.
200+ threshold (Optional[float]): Threshold of tf-idf weights. If a weight is less than or equal to the
201+ threshold, the term is not included in the transaction. Default: 0.
200202
201203 Attributes:
202204 aws: The sum of tf-idf values for all the terms in the rule.
@@ -216,11 +218,19 @@ class TextRule(Rule):
216218 'comprehensibility' , 'netconf' , 'yulesq' , 'aws'
217219 )
218220
219- def __post_init__ (self , transactions ):
def __init__(self, antecedent, consequent, fitness=0.0, transactions=None, threshold=0):
    """Create a text rule and, when transactions are given, compute its statistics.

    Args:
        antecedent (list[str]): Antecedent terms of the rule.
        consequent (list[str]): Consequent terms of the rule.
        fitness (Optional[float]): Fitness value of the rule. Default: 0.0.
        transactions (Optional[pandas.DataFrame]): The tf-idf matrix. When ``None``,
            no transaction-dependent attributes are computed.
        threshold (Optional[float]): Threshold of tf-idf weights. A term counts as
            present in a transaction only if its weight is strictly greater than
            the threshold. Default: 0.

    """
    # The base initializer is always given ``transactions=None``; the
    # transaction-dependent work is triggered below so that ``threshold`` can be
    # forwarded. NOTE(review): presumably this keeps the base class from running
    # its own post-init with the one-argument signature - confirm against Rule.
    super().__init__(antecedent, consequent, fitness, transactions=None)

    if transactions is not None:
        self.num_transactions = len(transactions)
        # __post_init__ sets the inclusion value itself as its first step, so it
        # is not precomputed here.
        self.__post_init__(transactions, threshold)
228+
229+ def __post_init__ (self , transactions , threshold = 0 ):
220230 self .__inclusion = (len (self .antecedent ) + len (self .consequent )) / len (transactions .columns )
221231 self .__aws = transactions [self .antecedent + self .consequent ].values .sum ()
222- contains_antecedent = (transactions [self .antecedent ] > 0 ).all (axis = 1 )
223- contains_consequent = (transactions [self .consequent ] > 0 ).all (axis = 1 )
232+ contains_antecedent = (transactions [self .antecedent ] > threshold ).all (axis = 1 )
233+ contains_consequent = (transactions [self .consequent ] > threshold ).all (axis = 1 )
224234 self .antecedent_count = contains_antecedent .sum ()
225235 self .consequent_count = contains_consequent .sum ()
226236 self .full_count = (contains_antecedent & contains_consequent ).sum ()
@@ -232,6 +242,10 @@ def __post_init__(self, transactions):
232242 def amplitude (self ):
233243 return np .nan
234244
@property
def inclusion(self):
    """Ratio of the number of terms in the rule to the number of transaction columns."""
    return self.__inclusion
248+
@property
def aws(self):
    """Sum of the tf-idf values of all the terms in the rule."""
    return self.__aws
@@ -253,6 +267,8 @@ class NiaARTM(NiaARM):
253267 metrics (Union[Dict[str, float], Sequence[str]]): Metrics to take into account when computing the fitness.
254268 Metrics can either be passed as a Dict of pairs {'metric_name': <weight of metric>} or
255269 a sequence of metrics as strings, in which case, the weights of the metrics will be set to 1.
270+ threshold (Optional[float]): Threshold of tf-idf weights. If a weight is less than or equal to the
271+ threshold, the term is not included in the transaction. Default: 0.
256272 logging (bool): Enable logging of fitness improvements. Default: ``False``.
257273
258274 Attributes:
@@ -264,27 +280,36 @@ class NiaARTM(NiaARM):
264280 'support' , 'confidence' , 'coverage' , 'interestingness' , 'comprehensibility' , 'inclusion' , 'rhs_support' , 'aws'
265281 )
266282
def __init__(self, max_terms, terms, transactions, metrics, threshold=0, logging=False):
    """Set up the text-mining problem.

    Args:
        max_terms (int): Maximum number of terms in a rule.
        terms: Terms handed to the base class.
        transactions: Transaction data handed to the base class.
        metrics: Metrics (and optionally their weights) used for the fitness.
        threshold (Optional[float]): Threshold of tf-idf weights. If a weight is
            less than or equal to the threshold, the term is not included in the
            transaction. Default: 0.
        logging (bool): Enable logging of fitness improvements. Default: ``False``.

    """
    # One slot more than max_terms is requested from the base class.
    # NOTE(review): presumably the extra slot is the trailing solution component
    # that _evaluate reads as the cut value - confirm against NiaARM.
    dimension = max_terms + 1
    super().__init__(dimension, terms, transactions, metrics, logging)
    self.max_terms = max_terms
    self.threshold = threshold
270287
def build_rule(self, vector):
    """Translate a solution vector into an ordered list of distinct terms.

    Every component of ``vector`` is scaled by ``self.num_features - 1`` and
    truncated to an integer, which is used as an index into ``self.features``.
    Terms selected more than once are kept only at their first position, so
    the result preserves the order in which terms first appear in ``vector``.

    Args:
        vector: Iterable of numeric components of the solution.

    Returns:
        list: The selected terms, without duplicates, in first-seen order.

    """
    last_index = self.num_features - 1
    selected = (self.features[int(val * last_index)] for val in vector)
    # dict keys keep insertion order, so fromkeys drops repeats while keeping
    # the first occurrence of each term.
    return list(dict.fromkeys(selected))
275300
276- def _evaluate (self , sol ):
277- cut_value = sol [self .dimension - 1 ]
278- solution = sol [:- 1 ]
279- cut = _cut_point (cut_value , self .max_terms )
301+ def _evaluate (self , x ):
302+ cut_value = x [self .dimension - 1 ]
303+ solution = x [:- 1 ]
280304
281305 rule = self .build_rule (solution )
306+ cut = _cut_point (cut_value , len (rule ))
282307
283308 antecedent = rule [:cut ]
284309 consequent = rule [cut :]
285310
286311 if antecedent and consequent :
287- rule = TextRule (antecedent , consequent , transactions = self .transactions )
312+ rule = TextRule (antecedent , consequent , transactions = self .transactions , threshold = self . threshold )
288313 metrics = [getattr (rule , metric ) for metric in self .metrics ]
289314 fitness = np .dot (self .weights , metrics ) / self .sum_weights
290315 rule .fitness = fitness
0 commit comments