@@ -435,10 +435,12 @@ class TopDownSearchStrategy(SearchStrategy):
435435 instances is developed. The hypothesis space of possible rules is
436436 then searched repeatedly by specialising candidate rules.
437437 """
438- def __init__ (self , constrain_continuous = True , evaluate = True ):
438+ def __init__ (self , constrain_continuous = True , evaluate = True ,
439+ restrict_equality = False ):
439440 self .constrain_continuous = constrain_continuous
440441 self .storage = None
441442 self .evaluate = evaluate
443+ self .restrict_equality = restrict_equality
442444
443445 def initialise_rule (self , X , Y , W , target_class , base_rules , domain ,
444446 initial_class_dist , prior_class_dist ,
@@ -531,14 +533,15 @@ def find_new_selectors(self, X, Y, W, domain, existing_selectors):
531533
532534 possible_selectors = []
533535 # examine covered examples, for each variable
536+ disc_operators = ["==" ] if self .restrict_equality else ["==" , "!=" ]
534537 for i , attribute in enumerate (domain .attributes ):
535538 # if discrete variable
536539 if attribute .is_discrete :
537540 # for each unique value, generate all possible selectors
538541 for val in np .unique (X [:, i ]):
539- s1 = Selector ( column = i , op = "==" , value = val )
540- s2 = Selector (column = i , op = "!=" , value = val )
541- possible_selectors . extend ([ s1 , s2 ] )
542+ possible_selectors += (
543+ Selector (column = i , op = op , value = val )
544+ for op in disc_operators )
542545 # if continuous variable
543546 elif attribute .is_continuous :
544547 if X .shape [0 ] == 1 :
@@ -914,7 +917,8 @@ class _RuleLearner(Learner):
914917 """
915918 preprocessors = [RemoveNaNColumns (), HasClass (), Impute ()]
916919
917- def __init__ (self , preprocessors = None , base_rules = None ):
920+ def __init__ (self , preprocessors = None , base_rules = None ,
921+ * , restrict_equality = False ):
918922 """
919923 Constrain the search algorithm with a list of base rules.
920924
@@ -940,6 +944,7 @@ def __init__(self, preprocessors=None, base_rules=None):
940944 super ().__init__ (preprocessors = preprocessors )
941945 self .base_rules = base_rules if base_rules is not None else []
942946 self .rule_finder = RuleHunter ()
947+ self .rule_finder .search_strategy .restrict_equality = restrict_equality
943948
944949 self .data_stopping = self .positive_remaining_data_stopping
945950 self .cover_and_remove = self .exclusive_cover_and_remove
@@ -1247,8 +1252,10 @@ class _BaseCN2Learner(_RuleLearner):
12471252 """
12481253 def __init__ (self , preprocessors = None , base_rules = None , beam_width = 5 ,
12491254 constrain_continuous = True , min_covered_examples = 1 ,
1250- max_rule_length = 5 , default_alpha = 1.0 , parent_alpha = 1.0 ):
1251- super ().__init__ (preprocessors , base_rules )
1255+ max_rule_length = 5 , default_alpha = 1.0 , parent_alpha = 1.0 , * ,
1256+ restrict_equality = False ):
1257+ super ().__init__ (preprocessors , base_rules ,
1258+ restrict_equality = restrict_equality )
12521259 rf = self .rule_finder
12531260 rf .search_algorithm .beam_width = beam_width
12541261 rf .search_strategy .constrain_continuous = constrain_continuous
@@ -1272,8 +1279,10 @@ class CN2Learner(_RuleLearner):
12721279 "The CN2 Induction Algorithm", Peter Clark and Tim Niblett, Machine
12731280 Learning Journal, 3 (4), pp261-283, (1989)
12741281 """
1275- def __init__ (self , preprocessors = None , base_rules = None ):
1276- super ().__init__ (preprocessors , base_rules )
1282+ def __init__ (self , preprocessors = None , base_rules = None ,
1283+ * , restrict_equality = False ):
1284+ super ().__init__ (preprocessors , base_rules ,
1285+ restrict_equality = restrict_equality )
12771286 self .rule_finder .quality_evaluator = EntropyEvaluator ()
12781287
12791288 def fit_storage (self , data ):
@@ -1326,8 +1335,10 @@ class CN2UnorderedLearner(_RuleLearner):
13261335 """
13271336 name = 'CN2 unordered inducer'
13281337
1329- def __init__ (self , preprocessors = None , base_rules = None ):
1330- super ().__init__ (preprocessors , base_rules )
1338+ def __init__ (self , preprocessors = None , base_rules = None ,
1339+ * , restrict_equality = False ):
1340+ super ().__init__ (preprocessors , base_rules ,
1341+ restrict_equality = restrict_equality )
13311342 self .rule_finder .quality_evaluator = LaplaceAccuracyEvaluator ()
13321343
13331344 def fit_storage (self , data ):
@@ -1392,8 +1403,10 @@ class CN2SDLearner(_RuleLearner):
13921403 """
13931404 name = 'CN2-SD inducer'
13941405
1395- def __init__ (self , preprocessors = None , base_rules = None ):
1396- super ().__init__ (preprocessors , base_rules )
1406+ def __init__ (self , preprocessors = None , base_rules = None ,
1407+ * , restrict_equality = False ):
1408+ super ().__init__ (preprocessors , base_rules ,
1409+ restrict_equality = restrict_equality )
13971410 self .rule_finder .quality_evaluator = WeightedRelativeAccuracyEvaluator ()
13981411 self .cover_and_remove = self .weighted_cover_and_remove
13991412 self .gamma = 0.7
@@ -1461,8 +1474,10 @@ class CN2SDUnorderedLearner(_RuleLearner):
14611474 """
14621475 name = 'CN2-SD unordered inducer'
14631476
1464- def __init__ (self , preprocessors = None , base_rules = None ):
1465- super ().__init__ (preprocessors , base_rules )
1477+ def __init__ (self , preprocessors = None , base_rules = None ,
1478+ * , restrict_equality = False ):
1479+ super ().__init__ (preprocessors , base_rules ,
1480+ restrict_equality = restrict_equality )
14661481 self .rule_finder .quality_evaluator = WeightedRelativeAccuracyEvaluator ()
14671482 self .cover_and_remove = self .weighted_cover_and_remove
14681483 self .gamma = 0.7
0 commit comments