Skip to content

Commit a8e251f

Browse files
committed
OWSaveBase: Add method get_filters()
1 parent 0b41b58 commit a8e251f

File tree

4 files changed

+45
-39
lines changed

4 files changed

+45
-39
lines changed

Orange/widgets/data/owsave.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -95,21 +95,21 @@ def migrate_settings(cls, settings, version=0):
9595
def migrate_to_version_2():
9696
# Set the default; change later if possible
9797
settings.pop("compression", None)
98-
settings["filter"] = next(iter(cls.filters))
98+
settings["filter"] = next(iter(cls.get_filters()))
9999
filetype = settings.pop("filetype", None)
100100
if filetype is None:
101101
return
102102

103103
ext = cls._extension_from_filter(filetype)
104104
if settings.pop("compress", False):
105-
for afilter in cls.filters:
105+
for afilter in cls.get_filters():
106106
if ext + ".gz" in afilter:
107107
settings["filter"] = afilter
108108
return
109109
# If not found, uncompressed may have been erroneously set
110110
# for a writer that didn't support if (such as .xlsx), so
111111
# fall through to uncompressed
112-
for afilter in cls.filters:
112+
for afilter in cls.get_filters():
113113
if ext in afilter:
114114
settings["filter"] = afilter
115115
return
@@ -128,20 +128,18 @@ def initial_start_dir(self):
128128

129129
def valid_filters(self):
130130
if self.data is None or not self.data.is_sparse():
131-
return self.filters
131+
return self.get_filters()
132132
else:
133-
return {filt: writer for filt, writer in self.filters.items()
133+
return {filt: writer for filt, writer in self.get_filters().items()
134134
if writer.SUPPORT_SPARSE_DATA}
135135

136136
def default_valid_filter(self):
137+
valid = self.valid_filters()
137138
if self.data is None or not self.data.is_sparse() \
138-
or self.filters[self.filter].SUPPORT_SPARSE_DATA:
139+
or (self.filter in valid
140+
and valid[self.filter].SUPPORT_SPARSE_DATA):
139141
return self.filter
140-
for filt, writer in self.filters.items():
141-
if writer.SUPPORT_SPARSE_DATA:
142-
return filt
143-
# This shouldn't happen and it will trigger an error in tests
144-
return None # pragma: no cover
142+
return next(iter(valid))
145143

146144

147145
if __name__ == "__main__": # pragma: no cover

Orange/widgets/data/tests/test_owsave.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def test_initial_start_dir(self):
9191
@patch("Orange.widgets.utils.save.owsavebase.QFileDialog.getSaveFileName")
9292
def test_save_file_sets_name(self, _filedialog):
9393
widget = self.widget
94-
filters = iter(widget.filters)
94+
filters = iter(widget.get_filters())
9595
filter1 = next(filters)
9696
filter2 = next(filters)
9797

@@ -259,20 +259,20 @@ def test_valid_filters_for_sparse(self):
259259
widget = self.widget
260260

261261
widget.data = None
262-
self.assertEqual(widget.filters, widget.valid_filters())
262+
self.assertEqual(widget.get_filters(), widget.valid_filters())
263263

264264
widget.data = self.iris
265-
self.assertEqual(widget.filters, widget.valid_filters())
265+
self.assertEqual(widget.get_filters(), widget.valid_filters())
266266

267267
widget.data.X = sp.csr_matrix(widget.data.X)
268268
valid = widget.valid_filters()
269-
self.assertNotEqual(widget.filters, {})
269+
self.assertNotEqual(widget.get_filters(), {})
270270
# false positive, pylint: disable=no-member
271271
self.assertTrue(all(v.SUPPORT_SPARSE_DATA for v in valid.values()))
272272

273273
def test_valid_default_filter(self):
274274
widget = self.widget
275-
for widget.filter, writer in widget.filters.items():
275+
for widget.filter, writer in widget.get_filters().items():
276276
if not writer.SUPPORT_SPARSE_DATA:
277277
break
278278

@@ -284,13 +284,14 @@ def test_valid_default_filter(self):
284284

285285
widget.data.X = sp.csr_matrix(widget.data.X)
286286
self.assertTrue(
287-
widget.filters[widget.default_valid_filter()].SUPPORT_SPARSE_DATA)
287+
widget.get_filters()[widget.default_valid_filter()]
288+
.SUPPORT_SPARSE_DATA)
288289

289290
def test_send_report(self):
290291
widget = self.widget
291292

292293
widget.report_items = Mock()
293-
for writer in widget.filters.values():
294+
for writer in widget.get_filters().values():
294295
widget.writer = writer
295296
for widget.add_type_annotations in (False, True):
296297
widget.filename = f"foo.{writer.EXTENSIONS[0]}"
@@ -355,14 +356,14 @@ def test_migration_to_version_2(self):
355356
settings = {**const_settings,
356357
'compress': True, 'compression': 'lzma (.xz)'}
357358
OWSave.migrate_settings(settings)
358-
self.assertTrue(settings["filter"] in OWSave.filters)
359+
self.assertTrue(settings["filter"] in OWSave.get_filters())
359360

360361
# Unsupported file format (is this possible?)
361362
settings = {**const_settings,
362363
'compress': True, 'compression': 'lzma (.xz)',
363364
'filetype': 'Bar file (.bar)'}
364365
OWSave.migrate_settings(settings)
365-
self.assertTrue(settings["filter"] in OWSave.filters)
366+
self.assertTrue(settings["filter"] in OWSave.get_filters())
366367

367368

368369
class TestFunctionalOWSave(WidgetTest):
@@ -377,7 +378,7 @@ def test_save_uncompressed(self):
377378
spiris = Table("iris")
378379
spiris.X = sp.csr_matrix(spiris.X)
379380

380-
for selected_filter, writer in widget.filters.items():
381+
for selected_filter, writer in widget.get_filters().items():
381382
widget.write = writer
382383
ext = writer.EXTENSIONS[0]
383384
with named_file("", suffix=ext) as filename:
@@ -399,7 +400,7 @@ class TestOWSaveLinuxDialog(OWSaveTestBase):
399400
def test_get_save_filename_linux(self):
400401
widget = self.widget
401402
widget.initial_start_dir = lambda: "baz"
402-
widget.filters = dict.fromkeys("abc")
403+
widget.get_filters = lambda: dict.fromkeys("abc")
403404
widget.filter = "b"
404405
dlg = widget.SaveFileDialog = Mock() # pylint: disable=invalid-name
405406
instance = dlg.return_value
@@ -413,7 +414,7 @@ def test_get_save_filename_linux(self):
413414
instance.exec.return_value = QFileDialog.Rejected
414415
self.assertEqual(widget.get_save_filename(), ("", ""))
415416

416-
@patch.object(OWSaveBase, "filters", OWSave.filters)
417+
@patch.object(OWSaveBase, "filters", OWSave.get_filters())
417418
def test_save_file_dialog_enforces_extension_linux(self):
418419
dialog = OWSave.SaveFileDialog(
419420
OWSave, None, "Save File", "foo.bar",
@@ -468,7 +469,8 @@ def remove_star(filt):
468469
def test_get_save_filename_darwin(self, dlg):
469470
widget = self.widget
470471
widget.initial_start_dir = lambda: "baz"
471-
widget.filters = dict.fromkeys(("aa (*.a)", "bb (*.b)", "cc (*.c)"))
472+
widget.get_filters = \
473+
lambda: dict.fromkeys(("aa (*.a)", "bb (*.b)", "cc (*.c)"))
472474
widget.filter = "bb (*.b)"
473475
instance = dlg.return_value
474476
instance.exec.return_value = dlg.Accepted = QFileDialog.Accepted
@@ -489,10 +491,10 @@ def test_get_save_filename_darwin(self, dlg):
489491
def test_save_file_dialog_enforces_extension_darwin(self, dlg):
490492
widget = self.widget
491493
filter1 = "" # prevent pylint warning 'undefined-loop-variable'
492-
for filter1 in widget.filters:
494+
for filter1 in widget.get_filters():
493495
if OWSaveBase._extension_from_filter(filter1) == ".tab":
494496
break
495-
for filter2 in widget.filters:
497+
for filter2 in widget.get_filters():
496498
if OWSaveBase._extension_from_filter(filter2) == ".csv.gz":
497499
break
498500

@@ -536,7 +538,7 @@ def selected_files():
536538
widget = self.widget
537539
widget.initial_start_dir = lambda: "baz"
538540
filter1 = "" # prevent pylint warning 'undefined-loop-variable'
539-
for filter1 in widget.filters:
541+
for filter1 in widget.get_filters():
540542
if OWSaveBase._extension_from_filter(filter1) == ".tab":
541543
break
542544

Orange/widgets/utils/save/owsavebase.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@ class OWSaveBase(widget.OWWidget, openclass=True):
2323
- calls `self.on_new_input`.
2424
2525
- a class attribute `filters` with a list of filters or a dictionary whose
26-
keys are filters
27-
- method `do_save` that saves `self.data` into `self.filename`.
26+
keys are filters OR a class method `get_filters` that returns such a
27+
list or dictionary
28+
- method `do_save` that saves `self.data` into `self.filename`
2829
2930
Alternatively, instead of defining `do_save` a derived class can make
3031
`filters` a dictionary whose keys are classes that define a method `write`
@@ -70,7 +71,7 @@ def __init__(self, start_row=0):
7071
self.data = None
7172
# This cannot be done outside because `filters` is defined by subclass
7273
if not self.filter:
73-
self.filter = next(iter(self.filters))
74+
self.filter = next(iter(self.get_filters()))
7475

7576
self.grid = grid = QGridLayout()
7677
gui.widgetBox(self.controlArea, orientation=grid)
@@ -89,6 +90,10 @@ def __init__(self, start_row=0):
8990
self.adjustSize()
9091
self.update_messages()
9192

93+
@classmethod
94+
def get_filters(cls):
95+
return cls.filters
96+
9297
@property
9398
def writer(self):
9499
"""
@@ -98,7 +103,7 @@ def writer(self):
98103
corresponding to the filter. Derived classes (e.g. OWSave) may also use
99104
it elsewhere.
100105
"""
101-
return self.filters[self.filter]
106+
return self.get_filters()[self.filter]
102107

103108
def on_new_input(self):
104109
"""
@@ -160,8 +165,8 @@ def do_save(self):
160165
Do the saving.
161166
162167
Default implementation calls the write method of the writer
163-
corresponding to the current filter. This requires that class attribute
164-
filters is a dictionary whose keys are classes.
168+
corresponding to the current filter. This requires that get_filters()
169+
returns is a dictionary whose keys are classes.
165170
166171
Derived classes may simplify this by providing a list of filters and
167172
override do_save. This is particularly handy if the widget supports only
@@ -217,7 +222,7 @@ def _replace_extension(cls, filename, extension):
217222
function removes anything that can appear anywhere.
218223
"""
219224
known_extensions = set()
220-
for filt in cls.filters:
225+
for filt in cls.get_filters():
221226
known_extensions |= set(cls._extension_from_filter(filt).split("."))
222227
if "" in known_extensions:
223228
known_extensions.remove("")
@@ -233,7 +238,7 @@ def _extension_from_filter(selected_filter):
233238
return re.search(r".*\(\*?(\..*)\)$", selected_filter).group(1)
234239

235240
def valid_filters(self):
236-
return self.filters
241+
return self.get_filters()
237242

238243
def default_valid_filter(self):
239244
return self.filter

Orange/widgets/utils/save/tests/test_owsavebase.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,11 @@ def test_input_handler(self):
3636
widget.on_new_input.assert_called()
3737

3838
def test_filters(self):
39-
self.assertGreaterEqual(len(self.widget.filters), 1,
39+
filters = self.widget.get_filters()
40+
self.assertGreaterEqual(len(filters), 1,
4041
msg="Widget defines no filters")
4142
if type(self.widget).do_save is OWSaveBase.do_save:
42-
self.assertIsInstance(self.widget.filters, collections.abc.Mapping)
43+
self.assertIsInstance(filters, collections.abc.Mapping)
4344

4445

4546
class TestOWSaveBaseWithWriters(WidgetTest):
@@ -92,7 +93,7 @@ def test_base_methods(self):
9293
self.assertEqual(widget.initial_start_dir(),
9394
os.path.expanduser(f"~{os.sep}"))
9495
self.assertEqual(widget.suggested_name(), "")
95-
self.assertIs(widget.valid_filters(), widget.filters)
96+
self.assertIs(widget.valid_filters(), widget.get_filters())
9697
self.assertIs(widget.default_valid_filter(), widget.filter)
9798

9899

@@ -133,7 +134,7 @@ def test_base_methods(self):
133134
self.assertEqual(widget.initial_start_dir(),
134135
os.path.expanduser(f"~{os.sep}"))
135136
self.assertEqual(widget.suggested_name(), "")
136-
self.assertIs(widget.valid_filters(), widget.filters)
137+
self.assertIs(widget.valid_filters(), widget.get_filters())
137138
self.assertIs(widget.default_valid_filter(), widget.filter)
138139

139140

0 commit comments

Comments
 (0)