Skip to content

Commit c678871

Browse files
committed
[FIX] OWRuleLearner: Progress bar updates are now handled through a callback function.
Fixes a bug previously produced using test&score on windows machines. Other instances using the generated learner will no longer have affect on the widget's progress bar.
1 parent 037019f commit c678871

File tree

1 file changed

+51
-22
lines changed

1 file changed

+51
-22
lines changed

Orange/widgets/classify/owrules.py

Lines changed: 51 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,11 @@ class CustomRuleLearner(_RuleLearner):
4949
name = 'Custom rule inducer'
5050
__returns__ = CustomRuleClassifier
5151

52-
def __init__(self, preprocessors, base_rules, params, widget):
52+
def __init__(self, preprocessors, base_rules, params):
5353
super().__init__(preprocessors, base_rules)
54+
self.progress_advance_callback = None
5455
assert params is not None
56+
self.params = params
5557

5658
# top-level control procedure (rule ordering)
5759
self.rule_ordering = params["Rule ordering"]
@@ -92,8 +94,21 @@ def __init__(self, preprocessors, base_rules, params, widget):
9294
self.rule_finder.significance_validator.default_alpha = default_alpha
9395
self.rule_finder.significance_validator.parent_alpha = parent_alpha
9496

95-
self.params = params
96-
self.widget = widget
97+
def set_progress_advance_callback(self, f):
98+
"""
99+
Assign callback to update the corresponding widget's progress
100+
bar after each generated rule. Callback is used to ensure that
101+
the progress bar is always accessed correctly (additional
102+
widgets may however use the generated learner).
103+
"""
104+
self.progress_advance_callback = f
105+
106+
def clear_progress_advance_callback(self):
107+
"""
108+
Make sure to clear the callback function immediately after the
109+
classifier is trained.
110+
"""
111+
self.progress_advance_callback = None
97112

98113
def find_rules_and_measure_progress(self, X, Y, W, target_class,
99114
base_rules, domain, progress_amount):
@@ -119,7 +134,7 @@ def find_rules_and_measure_progress(self, X, Y, W, target_class,
119134
domain : Orange.data.domain.Domain
120135
Data domain, used to calculate class distributions.
121136
progress_amount: int, percentage
122-
Amount of the learning algorithm covered by this function
137+
Part of the learning algorithm covered by this function
123138
call.
124139
125140
Returns
@@ -152,26 +167,24 @@ def find_rules_and_measure_progress(self, X, Y, W, target_class,
152167
rule_list.append(new_rule)
153168

154169
# update progress
155-
progress = (((temp_class_dist[target_class] -
156-
get_dist(Y, W, domain)[target_class])
157-
/ initial_class_dist[target_class]
158-
* progress_amount) if target_class is not None else
159-
((temp_class_dist - get_dist(Y, W, domain)).sum()
160-
/ initial_class_dist.sum() * progress_amount))
161-
self.widget.progressBarAdvance(progress)
170+
if self.progress_advance_callback is not None:
171+
progress = (((temp_class_dist[target_class] -
172+
get_dist(Y, W, domain)[target_class])
173+
/ initial_class_dist[target_class]
174+
* progress_amount) if target_class is not None else
175+
((temp_class_dist - get_dist(Y, W, domain)).sum()
176+
/ initial_class_dist.sum() * progress_amount))
177+
self.progress_advance_callback(progress)
162178

163179
return rule_list
164180

165181
def fit(self, X, Y, W=None):
166-
# init & show progress bar
167-
self.widget.progressBarInit()
168-
169182
rule_list = []
170183
Y = Y.astype(dtype=int)
171184
if self.rule_ordering == "ordered":
172185
rule_list = self.find_rules_and_measure_progress(
173186
X, Y, np.copy(W) if W is not None else None, None,
174-
self.base_rules, self.domain, progress_amount=100)
187+
self.base_rules, self.domain, progress_amount=1)
175188
# add the default rule, if required
176189
if (not rule_list or rule_list and rule_list[-1].length > 0 or
177190
self.covering_algorithm == "weighted"):
@@ -180,13 +193,12 @@ def fit(self, X, Y, W=None):
180193
elif self.rule_ordering == "unordered":
181194
for curr_class in range(len(self.domain.class_var.values)):
182195
rule_list.extend(self.find_rules_and_measure_progress(
183-
X, Y, W, curr_class, self.base_rules, self.domain,
184-
progress_amount=100/len(self.domain.class_var.values)))
196+
X, Y, np.copy(W) if W is not None else None,
197+
curr_class, self.base_rules, self.domain,
198+
progress_amount=1/len(self.domain.class_var.values)))
185199
# add the default rule
186200
rule_list.append(self.generate_default_rule(X, Y, W, self.domain))
187201

188-
# hide progress bar
189-
self.widget.progressBarFinished()
190202
return CustomRuleClassifier(domain=self.domain, rule_list=rule_list,
191203
params=self.params)
192204

@@ -199,7 +211,7 @@ class OWRuleLearner(OWBaseLearner):
199211

200212
want_main_area = False
201213
resizing_enabled = False
202-
auto_apply = Setting(False)
214+
auto_apply = Setting(True)
203215

204216
LEARNER = CustomRuleLearner
205217

@@ -301,12 +313,29 @@ def settings_changed(self, *args, **kwargs):
301313
self.storage_covers[self.covering_algorithm] != "weighted")
302314
super().settings_changed(*args, **kwargs)
303315

316+
def update_model(self):
317+
"""
318+
Ensure that the progress bar is updated only if the generated
319+
learner is used within this widget (for example, it must not be
320+
accessed from test&score widget).
321+
"""
322+
if self.check_data():
323+
with self.progressBar() as progress:
324+
self.learner.set_progress_advance_callback(progress.advance)
325+
self.model = self.learner(self.data)
326+
self.learner.clear_progress_advance_callback()
327+
self.model.name = self.learner_name
328+
self.model.instances = self.data
329+
self.valid_data = True
330+
else:
331+
self.model = None
332+
self.send(self.OUTPUT_MODEL_NAME, self.model)
333+
304334
def create_learner(self):
305335
return self.LEARNER(
306336
preprocessors=self.preprocessors,
307337
base_rules=self.base_rules,
308-
params=self.get_learner_parameters(),
309-
widget=self
338+
params=self.get_learner_parameters()
310339
)
311340

312341
def get_learner_parameters(self):

0 commit comments

Comments
 (0)