@@ -49,9 +49,11 @@ class CustomRuleLearner(_RuleLearner):
4949 name = 'Custom rule inducer'
5050 __returns__ = CustomRuleClassifier
5151
52- def __init__ (self , preprocessors , base_rules , params , widget ):
52+ def __init__ (self , preprocessors , base_rules , params ):
5353 super ().__init__ (preprocessors , base_rules )
54+ self .progress_advance_callback = None
5455 assert params is not None
56+ self .params = params
5557
5658 # top-level control procedure (rule ordering)
5759 self .rule_ordering = params ["Rule ordering" ]
@@ -92,8 +94,21 @@ def __init__(self, preprocessors, base_rules, params, widget):
9294 self .rule_finder .significance_validator .default_alpha = default_alpha
9395 self .rule_finder .significance_validator .parent_alpha = parent_alpha
9496
95- self .params = params
96- self .widget = widget
97+ def set_progress_advance_callback (self , f ):
98+ """
99+ Assign callback to update the corresponding widget's progress
100+ bar after each generated rule. Callback is used to ensure that
101+ the progress bar is always accessed correctly (additional
102+ widgets may however use the generated learner).
103+ """
104+ self .progress_advance_callback = f
105+
106+ def clear_progress_advance_callback (self ):
107+ """
108+ Make sure to clear the callback function immediately after the
109+ classifier is trained.
110+ """
111+ self .progress_advance_callback = None
97112
98113 def find_rules_and_measure_progress (self , X , Y , W , target_class ,
99114 base_rules , domain , progress_amount ):
@@ -119,7 +134,7 @@ def find_rules_and_measure_progress(self, X, Y, W, target_class,
119134 domain : Orange.data.domain.Domain
120135 Data domain, used to calculate class distributions.
121136 progress_amount: int, percentage
122- Amount of the learning algorithm covered by this function
137+ Part of the learning algorithm covered by this function
123138 call.
124139
125140 Returns
@@ -152,26 +167,24 @@ def find_rules_and_measure_progress(self, X, Y, W, target_class,
152167 rule_list .append (new_rule )
153168
154169 # update progress
155- progress = (((temp_class_dist [target_class ] -
156- get_dist (Y , W , domain )[target_class ])
157- / initial_class_dist [target_class ]
158- * progress_amount ) if target_class is not None else
159- ((temp_class_dist - get_dist (Y , W , domain )).sum ()
160- / initial_class_dist .sum () * progress_amount ))
161- self .widget .progressBarAdvance (progress )
170+ if self .progress_advance_callback is not None :
171+ progress = (((temp_class_dist [target_class ] -
172+ get_dist (Y , W , domain )[target_class ])
173+ / initial_class_dist [target_class ]
174+ * progress_amount ) if target_class is not None else
175+ ((temp_class_dist - get_dist (Y , W , domain )).sum ()
176+ / initial_class_dist .sum () * progress_amount ))
177+ self .progress_advance_callback (progress )
162178
163179 return rule_list
164180
165181 def fit (self , X , Y , W = None ):
166- # init & show progress bar
167- self .widget .progressBarInit ()
168-
169182 rule_list = []
170183 Y = Y .astype (dtype = int )
171184 if self .rule_ordering == "ordered" :
172185 rule_list = self .find_rules_and_measure_progress (
173186 X , Y , np .copy (W ) if W is not None else None , None ,
174- self .base_rules , self .domain , progress_amount = 100 )
187+ self .base_rules , self .domain , progress_amount = 1 )
175188 # add the default rule, if required
176189 if (not rule_list or rule_list and rule_list [- 1 ].length > 0 or
177190 self .covering_algorithm == "weighted" ):
@@ -180,13 +193,12 @@ def fit(self, X, Y, W=None):
180193 elif self .rule_ordering == "unordered" :
181194 for curr_class in range (len (self .domain .class_var .values )):
182195 rule_list .extend (self .find_rules_and_measure_progress (
183- X , Y , W , curr_class , self .base_rules , self .domain ,
184- progress_amount = 100 / len (self .domain .class_var .values )))
196+ X , Y , np .copy (W ) if W is not None else None ,
197+ curr_class , self .base_rules , self .domain ,
198+ progress_amount = 1 / len (self .domain .class_var .values )))
185199 # add the default rule
186200 rule_list .append (self .generate_default_rule (X , Y , W , self .domain ))
187201
188- # hide progress bar
189- self .widget .progressBarFinished ()
190202 return CustomRuleClassifier (domain = self .domain , rule_list = rule_list ,
191203 params = self .params )
192204
@@ -199,7 +211,7 @@ class OWRuleLearner(OWBaseLearner):
199211
200212 want_main_area = False
201213 resizing_enabled = False
202- auto_apply = Setting (False )
214+ auto_apply = Setting (True )
203215
204216 LEARNER = CustomRuleLearner
205217
@@ -301,12 +313,29 @@ def settings_changed(self, *args, **kwargs):
301313 self .storage_covers [self .covering_algorithm ] != "weighted" )
302314 super ().settings_changed (* args , ** kwargs )
303315
316+ def update_model (self ):
317+ """
318+ Ensure that the progress bar is updated only if the generated
319+ learner is used within this widget (for example, it must not be
320+ accessed from test&score widget).
321+ """
322+ if self .check_data ():
323+ with self .progressBar () as progress :
324+ self .learner .set_progress_advance_callback (progress .advance )
325+ self .model = self .learner (self .data )
326+ self .learner .clear_progress_advance_callback ()
327+ self .model .name = self .learner_name
328+ self .model .instances = self .data
329+ self .valid_data = True
330+ else :
331+ self .model = None
332+ self .send (self .OUTPUT_MODEL_NAME , self .model )
333+
304334 def create_learner (self ):
305335 return self .LEARNER (
306336 preprocessors = self .preprocessors ,
307337 base_rules = self .base_rules ,
308- params = self .get_learner_parameters (),
309- widget = self
338+ params = self .get_learner_parameters ()
310339 )
311340
312341 def get_learner_parameters (self ):
0 commit comments