1010from Orange .data import Table , Storage , Instance , Value
1111from Orange .data .filter import HasClass
1212from Orange .data .table import DomainTransformationError
13- from Orange .data .util import one_hot , progress_callback , dummy_callback
13+ from Orange .data .util import one_hot
1414from Orange .misc .wrapper_meta import WrapperMeta
1515from Orange .preprocess import Continuize , RemoveNaNColumns , SklImpute , Normalize
1616from Orange .statistics .util import all_nan
17- from Orange .util import Reprable , OrangeDeprecationWarning
17+ from Orange .util import Reprable , OrangeDeprecationWarning , wrap_callback , \
18+ dummy_callback
1819
1920__all__ = ["Learner" , "Model" , "SklLearner" , "SklModel" ,
2021 "ReprableWithPreprocessors" ]
@@ -102,7 +103,7 @@ def fit_storage(self, data):
102103 X , Y , W = data .X , data .Y , data .W if data .has_weights () else None
103104 return self .fit (X , Y , W )
104105
105- def __call__ (self , data , callback = None ):
106+ def __call__ (self , data , progress_callback = None ):
106107 if not self .check_learner_adequacy (data .domain ):
107108 raise ValueError (self .learner_adequacy_err_msg )
108109
@@ -112,33 +113,33 @@ def __call__(self, data, callback=None):
112113 data = Table (data .domain , [data ])
113114 origdata = data
114115
115- if callback is None :
116- callback = dummy_callback
117- callback (0 , "Preprocessing..." )
116+ if progress_callback is None :
117+ progress_callback = dummy_callback
118+ progress_callback (0 , "Preprocessing..." )
118119 try :
119- cb = progress_callback ( callback , end = 0.1 )
120- data = self .preprocess (data , callback = cb )
120+ cb = wrap_callback ( progress_callback , end = 0.1 )
121+ data = self .preprocess (data , progress_callback = cb )
121122 except TypeError :
122123 data = self .preprocess (data )
123- warnings .warn ("A keyword argument 'callback ' has been added to the "
124- " preprocess() signature. Implementing the method "
125- "without the argument is deprecated and will result "
126- "in an error in the future." ,
124+ warnings .warn ("A keyword argument 'progress_callback ' has been "
125+ "added to the preprocess() signature. Implementing "
126+ "the method without the argument is deprecated and "
127+ "will result in an error in the future." ,
127128 OrangeDeprecationWarning )
128129
129130 if len (data .domain .class_vars ) > 1 and not self .supports_multiclass :
130131 raise TypeError ("%s doesn't support multiple class variables" %
131132 self .__class__ .__name__ )
132133
133- callback (0.1 , "Fitting..." )
134+ progress_callback (0.1 , "Fitting..." )
134135 model = self ._fit_model (data )
135136 model .used_vals = [np .unique (y ).astype (int ) for y in data .Y [:, None ].T ]
136137 model .domain = data .domain
137138 model .supports_multiclass = self .supports_multiclass
138139 model .name = self .name
139140 model .original_domain = origdomain
140141 model .original_data = origdata
141- callback (1 )
142+ progress_callback (1 )
142143 return model
143144
144145 def _fit_model (self , data ):
@@ -148,15 +149,15 @@ def _fit_model(self, data):
148149 X , Y , W = data .X , data .Y , data .W if data .has_weights () else None
149150 return self .fit (X , Y , W )
150151
151- def preprocess (self , data , callback = None ):
152+ def preprocess (self , data , progress_callback = None ):
152153 """Apply the `preprocessors` to the data"""
153- if callback is None :
154- callback = dummy_callback
154+ if progress_callback is None :
155+ progress_callback = dummy_callback
155156 n_pps = len (list (self .active_preprocessors ))
156157 for i , pp in enumerate (self .active_preprocessors ):
157- callback (i / n_pps )
158+ progress_callback (i / n_pps )
158159 data = pp (data )
159- callback (1 )
160+ progress_callback (1 )
160161 return data
161162
162163 @property
@@ -489,8 +490,8 @@ def _get_sklparams(self, values):
489490 raise TypeError ("Wrapper does not define '__wraps__'" )
490491 return params
491492
492- def preprocess (self , data , callback = None ):
493- data = super ().preprocess (data , callback )
493+ def preprocess (self , data , progress_callback = None ):
494+ data = super ().preprocess (data , progress_callback )
494495
495496 if any (v .is_discrete and len (v .values ) > 2
496497 for v in data .domain .attributes ):
@@ -499,8 +500,8 @@ def preprocess(self, data, callback=None):
499500
500501 return data
501502
502- def __call__ (self , data , callback = None ):
503- m = super ().__call__ (data , callback )
503+ def __call__ (self , data , progress_callback = None ):
504+ m = super ().__call__ (data , progress_callback )
504505 m .params = self .params
505506 return m
506507
0 commit comments