Skip to content

Commit f764074

Browse files
authored
Merge pull request #2395 from ales-erjavec/impute-widget-threading
[ENH] Impute widget: Parallel execution in the background
2 parents 6da33a0 + 0766f46 commit f764074

File tree

7 files changed

+444
-73
lines changed

7 files changed

+444
-73
lines changed

Orange/canvas/canvas/items/nodeitem.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -681,6 +681,7 @@ def __init__(self, *args, **kwargs):
681681
super(NameTextItem, self).__init__(*args, **kwargs)
682682
self.__selected = False
683683
self.__palette = None
684+
self.__content = ""
684685

685686
def paint(self, painter, option, widget=None):
686687
if self.__selected:
@@ -743,6 +744,11 @@ def __updateDefaultTextColor(self):
743744
role = QPalette.WindowText
744745
self.setDefaultTextColor(self.palette().color(role))
745746

747+
def setHtml(self, contents):
    """
    Set the item's HTML contents, skipping the update when unchanged.

    The value most recently passed on is remembered in ``self.__content``;
    the superclass ``setHtml`` is only invoked when `contents` differs
    from that remembered value.

    Parameters
    ----------
    contents : str
        The HTML text to display.
    """
    if contents != self.__content:
        self.__content = contents
        super().setHtml(contents)
751+
746752

747753
class NodeItem(QGraphicsObject):
748754
"""

Orange/preprocess/impute.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,8 @@ class DropInstances(BaseImputeMethod):
8282
description = ""
8383

8484
def __call__(self, data, variable):
85-
index = data.domain.index(variable)
86-
return np.isnan(data[:, index]).reshape(-1)
85+
col, _ = data.get_column_view(variable)
86+
return np.isnan(col)
8787

8888

8989
class Average(BaseImputeMethod):

Orange/widgets/data/owimpute.py

Lines changed: 162 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
import sys
22
import copy
3+
import logging
4+
import concurrent.futures
5+
from concurrent.futures import Future # pylint: disable=unused-import
6+
from collections import namedtuple
7+
from typing import List, Any # pylint: disable=unused-import
38

49
import numpy as np
510

@@ -8,13 +13,15 @@
813
QVBoxLayout, QStackedWidget, QComboBox,
914
QButtonGroup, QStyledItemDelegate, QListView, QDoubleSpinBox
1015
)
11-
from AnyQt.QtCore import Qt
16+
from AnyQt.QtCore import Qt, QThread
17+
from AnyQt.QtCore import pyqtSlot as Slot
1218

1319
import Orange.data
1420
from Orange.preprocess import impute
1521
from Orange.base import Learner
1622
from Orange.widgets import gui, settings
1723
from Orange.widgets.utils import itemmodels
24+
from Orange.widgets.utils import concurrent as qconcurrent
1825
from Orange.widgets.utils.sql import check_sql_input
1926
from Orange.widgets.widget import OWWidget, Msg, Input, Output
2027
from Orange.classification import SimpleTreeLearner
@@ -53,6 +60,32 @@ def __call__(self, *args, **kwargs):
5360
return self.method(*args, **kwargs)
5461

5562

63+
class SparseNotSupported(ValueError):
    """
    Raised when a model based imputation method is applied to sparse data.

    The widget reports it through its ``model_based_imputer_sparse`` error.
    """
65+
66+
67+
class VariableNotSupported(ValueError):
    """
    Raised when an imputation method does not support a variable.

    The offending variable is passed as the exception's only argument.
    """
69+
70+
71+
# Wrapper that marks a per-variable result as a boolean row mask (rows to be
# dropped) instead of an imputed variable. `mask` holds the value returned by
# an `impute.DropInstances` method; the masks of all variables are OR-ed
# together when the results are collected.
RowMask = namedtuple("RowMask", ["mask"])
72+
73+
74+
class Task:
    """
    Bookkeeping for one batch of submitted imputation jobs.

    Attributes
    ----------
    futures : List[Future]
        The futures of the submitted jobs.
    watcher : qconcurrent.FutureSetWatcher
        The watcher observing `futures`.
    cancelled : bool
        True once `cancel` has been called.
    """
    # Class level values only declare the attribute types/defaults;
    # `futures` and `watcher` are always rebound per instance in __init__.
    futures = []  # type: List[Future]
    watcher = ...  # type: qconcurrent.FutureSetWatcher
    cancelled = False

    def __init__(self, futures, watcher):
        self.futures = futures
        self.watcher = watcher

    def cancel(self):
        """
        Mark the task as cancelled and request cancellation of every future.

        This only calls ``cancel()`` on each future; it does not wait for
        them to finish.
        """
        self.cancelled = True
        for f in self.futures:
            f.cancel()
87+
88+
5689
class OWImpute(OWWidget):
5790
name = "Impute"
5891
description = "Impute missing values in the data table."
@@ -176,6 +209,8 @@ def __init__(self):
176209
self.learner = None
177210
self.modified = False
178211
self.default_method = self.methods[self.default_method_index]
212+
self.executor = qconcurrent.ThreadExecutor(self)
213+
self.__task = None
179214

180215
@property
181216
def default_method_index(self):
@@ -246,65 +281,137 @@ def get_method_for_column(self, column_index):
246281

247282
def _invalidate(self):
    """
    Mark the current output as out of date and recompute it.

    Sets ``self.modified``, cancels a task that is still registered and
    then calls `commit`.
    """
    self.modified = True
    if self.__task is not None:
        self.cancel()
    self.commit()
250287

251288
def commit(self):
    """
    Start imputing the input data in the background.

    Any task still in progress is cancelled and the imputation related
    warning/errors are cleared. When there is no data, no rows or no
    variables, the input is sent to the output unchanged and the method
    returns. Otherwise one job per variable is submitted to
    ``self.executor``; the results are collected in `__commit_finish`,
    which is connected to the watcher's ``doneAll`` signal.
    """
    self.cancel()
    self.warning()
    self.Error.imputation_failed.clear()
    self.Error.model_based_imputer_sparse.clear()

    if self.data is None or len(self.data) == 0 or len(self.varmodel) == 0:
        # Nothing to impute: pass the input through.
        self.Outputs.data.send(self.data)
        self.modified = False
        return

    data = self.data
    # Snapshot the (variable, method) pairs before any job is submitted.
    impute_state = [
        (var, self.variable_methods.get(i, self.default_method))
        for i, var in enumerate(self.varmodel)
    ]

    def impute_one(method, var, data):
        # type: (impute.BaseImputeMethod, Orange.data.Variable, Orange.data.Table) -> Any
        # Runs in the executor; failures are signalled with exceptions
        # that are retrieved from the future when results are collected.
        if isinstance(method, impute.Model) and data.is_sparse():
            raise SparseNotSupported()
        elif isinstance(method, impute.DropInstances):
            return RowMask(method(data, var))
        elif not method.supports_variable(var):
            raise VariableNotSupported(var)
        else:
            return method(data, var)

    futures = []
    for var, method in impute_state:
        # Each job works on its own deep copy of the method
        # (presumably so jobs never share a method instance -- confirm).
        future = self.executor.submit(
            impute_one, copy.deepcopy(method), var, data)
        futures.append(future)

    watcher = qconcurrent.FutureSetWatcher(futures)
    watcher.doneAll.connect(self.__commit_finish)
    watcher.progressChanged.connect(self.__progress_changed)
    self.__task = Task(futures, watcher)
    self.progressBarInit(processEvents=False)
    self.setBlocking(True)
328+
329+
@Slot()
330+
def __commit_finish(self):
331+
assert QThread.currentThread() is self.thread()
332+
assert self.__task is not None
333+
futures = self.__task.futures
334+
assert len(futures) == len(self.varmodel)
335+
assert self.data is not None
336+
337+
self.__task = None
338+
self.setBlocking(False)
339+
self.progressBarFinished()
253340

254-
if self.data is not None:
255-
if not len(self.data):
256-
self.Outputs.data.send(self.data)
257-
self.modified = False
258-
return
259-
260-
drop_mask = np.zeros(len(self.data), bool)
261-
262-
attributes = []
263-
class_vars = []
264-
265-
self.warning()
266-
self.Error.imputation_failed.clear()
267-
self.Error.model_based_imputer_sparse.clear()
268-
with self.progressBar(len(self.varmodel)) as progress:
269-
for i, var in enumerate(self.varmodel):
270-
method = self.variable_methods.get(i, self.default_method)
271-
if isinstance(method, impute.Model) and data.is_sparse():
272-
self.Error.model_based_imputer_sparse()
273-
continue
274-
275-
try:
276-
if not method.supports_variable(var):
277-
self.warning("Default method can not handle '{}'".
278-
format(var.name))
279-
elif isinstance(method, impute.DropInstances):
280-
drop_mask |= method(self.data, var)
281-
else:
282-
var = method(self.data, var)
283-
except Exception: # pylint: disable=broad-except
284-
self.Error.imputation_failed(var.name)
285-
attributes = class_vars = None
286-
break
287-
288-
if isinstance(var, Orange.data.Variable):
289-
var = [var]
290-
291-
if i < len(self.data.domain.attributes):
292-
attributes.extend(var)
293-
else:
294-
class_vars.extend(var)
295-
296-
progress.advance()
297-
298-
if attributes is None:
299-
data = None
341+
data = self.data
342+
attributes = []
343+
class_vars = []
344+
drop_mask = np.zeros(len(self.data), bool)
345+
346+
for i, (var, fut) in enumerate(zip(self.varmodel, futures)):
347+
assert fut.done()
348+
newvar = []
349+
try:
350+
res = fut.result()
351+
except SparseNotSupported:
352+
self.Error.model_based_imputer_sparse()
353+
# ?? break
354+
except VariableNotSupported:
355+
self.warning("Default method can not handle '{}'".
356+
format(var.name))
357+
except Exception: # pylint: disable=broad-except
358+
log = logging.getLogger(__name__)
359+
log.info("Error for %s", var, exc_info=True)
360+
self.Error.imputation_failed(var.name)
361+
attributes = class_vars = None
362+
break
300363
else:
301-
domain = Orange.data.Domain(attributes, class_vars,
302-
self.data.domain.metas)
303-
data = self.data.from_table(domain, self.data[~drop_mask])
364+
if isinstance(res, RowMask):
365+
drop_mask |= res.mask
366+
newvar = var
367+
else:
368+
newvar = res
369+
370+
if isinstance(newvar, Orange.data.Variable):
371+
newvar = [newvar]
372+
373+
if i < len(data.domain.attributes):
374+
attributes.extend(newvar)
375+
else:
376+
class_vars.extend(newvar)
377+
378+
if attributes is None:
379+
data = None
380+
else:
381+
domain = Orange.data.Domain(attributes, class_vars,
382+
data.domain.metas)
383+
try:
384+
data = self.data.from_table(domain, data[~drop_mask])
385+
except Exception: # pylint: disable=broad-except
386+
log = logging.getLogger(__name__)
387+
log.info("Error", exc_info=True)
388+
self.Error.imputation_failed("Unknown")
389+
data = None
304390

305391
self.Outputs.data.send(data)
306392
self.modified = False
307393

394+
@Slot(int, int)
def __progress_changed(self, n, d):
    """
    Update the progress bar to ``100 * n / d`` percent.

    Connected to the task watcher's ``progressChanged`` signal.

    Parameters
    ----------
    n : int
        Progress numerator (presumably the number of finished futures --
        confirm against `FutureSetWatcher`).
    d : int
        Progress denominator (presumably the total number of futures).
    """
    # Must run in the widget's own thread, and only while a task is active.
    assert QThread.currentThread() is self.thread()
    assert self.__task is not None
    self.progressBarSet(100. * n / d)
399+
400+
def cancel(self):
    """
    Cancel the imputation task in progress, if there is one.

    The task is unregistered, its futures are asked to cancel and the
    call then blocks until all of them have finished or been cancelled.
    Afterwards the progress bar is finished and the widget's blocking
    state is cleared. Does nothing when no task is registered.
    """
    task = self.__task
    if task is None:
        return

    self.__task = None
    task.cancel()
    # Disconnect the slots before waiting/flushing, so the abandoned
    # task's notifications no longer reach this widget.
    task.watcher.doneAll.disconnect(self.__commit_finish)
    task.watcher.progressChanged.disconnect(self.__progress_changed)
    # Blocks until every future has completed or was cancelled.
    concurrent.futures.wait(task.futures)
    # NOTE(review): presumably delivers the watcher's pending
    # notifications -- confirm against FutureSetWatcher.flush.
    task.watcher.flush()
    self.progressBarFinished()
    self.setBlocking(False)
410+
411+
def onDeleteWidget(self):
    """
    Cancel any task in progress, then defer to the base class handler.
    """
    self.cancel()
    super().onDeleteWidget()
414+
308415
def send_report(self):
309416
specific = []
310417
for i, var in enumerate(self.varmodel):
@@ -410,7 +517,7 @@ def set_method_for_indexes(self, indexes, method_index):
410517
method = impute.Default(default=value)
411518
self.variable_methods[index.row()] = method
412519
else:
413-
method = self.methods[method_index].copy()
520+
method = self.methods[method_index]
414521
for index in indexes:
415522
self.variable_methods[index.row()] = method
416523

@@ -435,9 +542,10 @@ def reset_variable_methods(self):
435542
self.variable_button_group.button(self.DEFAULT).setChecked(True)
436543

437544

438-
def main(argv=sys.argv):
545+
def main(argv=None):
439546
from AnyQt.QtWidgets import QApplication
440-
app = QApplication(list(argv))
547+
logging.basicConfig()
548+
app = QApplication(list(argv) if argv else [])
441549
argv = app.arguments()
442550
if len(argv) > 1:
443551
filename = argv[1]
@@ -459,4 +567,4 @@ def main(argv=sys.argv):
459567
return 0
460568

461569
if __name__ == "__main__":
462-
sys.exit(main())
570+
sys.exit(main(sys.argv))

0 commit comments

Comments
 (0)