Skip to content

Commit db3d9a2

Browse files
committed
CN2 Rules: Add restriction to == for categorical variables
1 parent 5eb97b8 commit db3d9a2

File tree

3 files changed

+43
-15
lines changed

3 files changed

+43
-15
lines changed

Orange/classification/rules.py

Lines changed: 30 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -435,10 +435,12 @@ class TopDownSearchStrategy(SearchStrategy):
435435
instances is developed. The hypothesis space of possible rules is
436436
then searched repeatedly by specialising candidate rules.
437437
"""
438-
def __init__(self, constrain_continuous=True, evaluate=True):
438+
def __init__(self, constrain_continuous=True, evaluate=True,
439+
restrict_equality=False):
439440
self.constrain_continuous = constrain_continuous
440441
self.storage = None
441442
self.evaluate = evaluate
443+
self.restrict_equality = restrict_equality
442444

443445
def initialise_rule(self, X, Y, W, target_class, base_rules, domain,
444446
initial_class_dist, prior_class_dist,
@@ -531,14 +533,15 @@ def find_new_selectors(self, X, Y, W, domain, existing_selectors):
531533

532534
possible_selectors = []
533535
# examine covered examples, for each variable
536+
disc_operators = ["=="] if self.restrict_equality else ["==", "!="]
534537
for i, attribute in enumerate(domain.attributes):
535538
# if discrete variable
536539
if attribute.is_discrete:
537540
# for each unique value, generate all possible selectors
538541
for val in np.unique(X[:, i]):
539-
s1 = Selector(column=i, op="==", value=val)
540-
s2 = Selector(column=i, op="!=", value=val)
541-
possible_selectors.extend([s1, s2])
542+
possible_selectors += (
543+
Selector(column=i, op=op, value=val)
544+
for op in disc_operators)
542545
# if continuous variable
543546
elif attribute.is_continuous:
544547
if X.shape[0] == 1:
@@ -914,7 +917,8 @@ class _RuleLearner(Learner):
914917
"""
915918
preprocessors = [RemoveNaNColumns(), HasClass(), Impute()]
916919

917-
def __init__(self, preprocessors=None, base_rules=None):
920+
def __init__(self, preprocessors=None, base_rules=None,
921+
*, restrict_equality=False):
918922
"""
919923
Constrain the search algorithm with a list of base rules.
920924
@@ -940,6 +944,7 @@ def __init__(self, preprocessors=None, base_rules=None):
940944
super().__init__(preprocessors=preprocessors)
941945
self.base_rules = base_rules if base_rules is not None else []
942946
self.rule_finder = RuleHunter()
947+
self.rule_finder.search_strategy.restrict_equality = restrict_equality
943948

944949
self.data_stopping = self.positive_remaining_data_stopping
945950
self.cover_and_remove = self.exclusive_cover_and_remove
@@ -1247,8 +1252,10 @@ class _BaseCN2Learner(_RuleLearner):
12471252
"""
12481253
def __init__(self, preprocessors=None, base_rules=None, beam_width=5,
12491254
constrain_continuous=True, min_covered_examples=1,
1250-
max_rule_length=5, default_alpha=1.0, parent_alpha=1.0):
1251-
super().__init__(preprocessors, base_rules)
1255+
max_rule_length=5, default_alpha=1.0, parent_alpha=1.0, *,
1256+
restrict_equality=False):
1257+
super().__init__(preprocessors, base_rules,
1258+
restrict_equality=restrict_equality)
12521259
rf = self.rule_finder
12531260
rf.search_algorithm.beam_width = beam_width
12541261
rf.search_strategy.constrain_continuous = constrain_continuous
@@ -1272,8 +1279,10 @@ class CN2Learner(_RuleLearner):
12721279
"The CN2 Induction Algorithm", Peter Clark and Tim Niblett, Machine
12731280
Learning Journal, 3 (4), pp261-283, (1989)
12741281
"""
1275-
def __init__(self, preprocessors=None, base_rules=None):
1276-
super().__init__(preprocessors, base_rules)
1282+
def __init__(self, preprocessors=None, base_rules=None,
1283+
*, restrict_equality=False):
1284+
super().__init__(preprocessors, base_rules,
1285+
restrict_equality=restrict_equality)
12771286
self.rule_finder.quality_evaluator = EntropyEvaluator()
12781287

12791288
def fit_storage(self, data):
@@ -1326,8 +1335,10 @@ class CN2UnorderedLearner(_RuleLearner):
13261335
"""
13271336
name = 'CN2 unordered inducer'
13281337

1329-
def __init__(self, preprocessors=None, base_rules=None):
1330-
super().__init__(preprocessors, base_rules)
1338+
def __init__(self, preprocessors=None, base_rules=None,
1339+
*, restrict_equality=False):
1340+
super().__init__(preprocessors, base_rules,
1341+
restrict_equality=restrict_equality)
13311342
self.rule_finder.quality_evaluator = LaplaceAccuracyEvaluator()
13321343

13331344
def fit_storage(self, data):
@@ -1392,8 +1403,10 @@ class CN2SDLearner(_RuleLearner):
13921403
"""
13931404
name = 'CN2-SD inducer'
13941405

1395-
def __init__(self, preprocessors=None, base_rules=None):
1396-
super().__init__(preprocessors, base_rules)
1406+
def __init__(self, preprocessors=None, base_rules=None,
1407+
*, restrict_equality=False):
1408+
super().__init__(preprocessors, base_rules,
1409+
restrict_equality=restrict_equality)
13971410
self.rule_finder.quality_evaluator = WeightedRelativeAccuracyEvaluator()
13981411
self.cover_and_remove = self.weighted_cover_and_remove
13991412
self.gamma = 0.7
@@ -1461,8 +1474,10 @@ class CN2SDUnorderedLearner(_RuleLearner):
14611474
"""
14621475
name = 'CN2-SD unordered inducer'
14631476

1464-
def __init__(self, preprocessors=None, base_rules=None):
1465-
super().__init__(preprocessors, base_rules)
1477+
def __init__(self, preprocessors=None, base_rules=None,
1478+
*, restrict_equality=False):
1479+
super().__init__(preprocessors, base_rules,
1480+
restrict_equality=restrict_equality)
14661481
self.rule_finder.quality_evaluator = WeightedRelativeAccuracyEvaluator()
14671482
self.cover_and_remove = self.weighted_cover_and_remove
14681483
self.gamma = 0.7

Orange/widgets/model/owrules.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def __init__(self, preprocessors, base_rules, params):
7373

7474
# bottom-level search procedure (search strategy)
7575
self.rule_finder.search_strategy.constrain_continuous = True
76+
self.rule_finder.search_strategy.restrict_equality = params["Restrict to equality"]
7677

7778
# bottom-level search procedure (search heuristics)
7879
evaluation_measure = params["Evaluation measure"]
@@ -229,6 +230,7 @@ class OWRuleLearner(OWBaseLearner):
229230
covering_algorithm = Setting(0)
230231
gamma = Setting(0.7)
231232
evaluation_measure = Setting(0)
233+
restrict_equality = Setting(False)
232234
beam_width = Setting(5)
233235
min_covered_examples = Setting(1)
234236
max_rule_length = Setting(5)
@@ -312,6 +314,12 @@ def add_main_layout(self):
312314
alignment=Qt.AlignRight, controlWidth=80,
313315
checked="checked_parent_alpha")
314316

317+
gui.checkBox(
318+
widget=bottom_box, master=self, value="restrict_equality",
319+
label="Restrict operator for categorical values to equality",
320+
callback=self.settings_changed,
321+
)
322+
315323
def settings_changed(self, *args, **kwargs):
316324
self.gamma_spin.setDisabled(self.covering_algorithm == 0)
317325
super().settings_changed(*args, **kwargs)
@@ -349,6 +357,7 @@ def get_learner_parameters(self):
349357
("Covering algorithm", self.storage_covers[self.covering_algorithm]),
350358
("Gamma", self.gamma),
351359
("Evaluation measure", self.storage_measures[self.evaluation_measure]),
360+
("Restrict to equality", self.restrict_equality),
352361
("Beam width", self.beam_width),
353362
("Minimum rule coverage", self.min_covered_examples),
354363
("Maximum rule length", self.max_rule_length),

i18n/si/msgs.jaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10372,6 +10372,7 @@ widgets/model/owrules.py:
1037210372
weighted: utežena
1037310373
Gamma: Gama
1037410374
Beam width: Širina snopa
10375+
Restrict to equality: Omejitev na enakost
1037510376
Evaluation measure: Ocena kvalitete pravila
1037610377
entropy: entropija
1037710378
laplace: Laplacova natančnost
@@ -10427,11 +10428,14 @@ widgets/model/owrules.py:
1042710428
parent_alpha: false
1042810429
Relative significance (parent α):: Relativna značilnost (predniko α):
1042910430
checked_parent_alpha: false
10431+
restrict_equality: false
10432+
Restrict operator for categorical values to equality: Omeji operator za kategorične vrednosti na enakost
1043010433
def `get_learner_parameters`:
1043110434
Rule ordering: Urejenost pravil
1043210435
Covering algorithm: Način prekrivanja
1043310436
Gamma: Gama
1043410437
Evaluation measure: Ocena kvalitete pravila
10438+
Restrict to equality: Omejitev na enakost
1043510439
Beam width: Širina snopa
1043610440
Minimum rule coverage: Najmanjše število pokritih primerov
1043710441
Maximum rule length: Največja dolžina pravila

0 commit comments

Comments
 (0)