Skip to content

Commit 6c9039a

Browse files
authored
Merge pull request #4302 from janezd/save-addon-formats
[ENH] Allow add-ons to register file format for the Save widget
2 parents 48895d1 + 2177d80 commit 6c9039a

File tree

5 files changed

+85
-52
lines changed

5 files changed

+85
-52
lines changed

Orange/data/io.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,8 @@ def write_file(cls, filename, data):
395395
# the sort order in file open/save combo boxes. Lower is better.
396396
PRIORITY = 10000
397397
OPTIONAL_TYPE_ANNOTATIONS = False
398+
SUPPORT_COMPRESSED = False
399+
SUPPORT_SPARSE_DATA = False
398400

399401
def __init__(self, filename):
400402
"""

Orange/widgets/data/owsave.py

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import os.path
22

33
from Orange.data.table import Table
4-
from Orange.data.io import TabReader, CSVReader, PickleReader, ExcelReader, \
5-
XlsReader
4+
from Orange.data.io import \
5+
TabReader, CSVReader, PickleReader, ExcelReader, XlsReader, FileFormat
66
from Orange.widgets import gui, widget
77
from Orange.widgets.widget import Input
88
from Orange.widgets.settings import Setting
@@ -22,14 +22,6 @@ class OWSave(OWSaveBase):
2222

2323
settings_version = 2
2424

25-
writers = [TabReader, CSVReader, PickleReader, ExcelReader, XlsReader]
26-
filters = {
27-
**{f"{w.DESCRIPTION} (*{w.EXTENSIONS[0]})": w
28-
for w in writers},
29-
**{f"Compressed {w.DESCRIPTION} (*{w.EXTENSIONS[0]}.gz)": w
30-
for w in writers if w.SUPPORT_COMPRESSED}
31-
}
32-
3325
class Inputs:
3426
data = Input("Data", Table)
3527

@@ -38,6 +30,8 @@ class Error(OWSaveBase.Error):
3830

3931
add_type_annotations = Setting(True)
4032

33+
builtin_order = [TabReader, CSVReader, PickleReader, ExcelReader, XlsReader]
34+
4135
def __init__(self):
4236
super().__init__(2)
4337

@@ -53,6 +47,21 @@ def __init__(self):
5347
self.grid.setRowMinimumHeight(1, 8)
5448
self.adjustSize()
5549

50+
@classmethod
51+
def get_filters(cls):
52+
writers = [format for format in FileFormat.formats
53+
if getattr(format, 'write_file', None)
54+
and getattr(format, "EXTENSIONS", None)]
55+
writers.sort(key=lambda writer: cls.builtin_order.index(writer)
56+
if writer in cls.builtin_order else 99)
57+
58+
return {
59+
**{f"{w.DESCRIPTION} (*{w.EXTENSIONS[0]})": w
60+
for w in writers},
61+
**{f"Compressed {w.DESCRIPTION} (*{w.EXTENSIONS[0]}.gz)": w
62+
for w in writers if w.SUPPORT_COMPRESSED}
63+
}
64+
5665
@Inputs.data
5766
def dataset(self, data):
5867
self.data = data
@@ -95,21 +104,21 @@ def migrate_settings(cls, settings, version=0):
95104
def migrate_to_version_2():
96105
# Set the default; change later if possible
97106
settings.pop("compression", None)
98-
settings["filter"] = next(iter(cls.filters))
107+
settings["filter"] = next(iter(cls.get_filters()))
99108
filetype = settings.pop("filetype", None)
100109
if filetype is None:
101110
return
102111

103112
ext = cls._extension_from_filter(filetype)
104113
if settings.pop("compress", False):
105-
for afilter in cls.filters:
114+
for afilter in cls.get_filters():
106115
if ext + ".gz" in afilter:
107116
settings["filter"] = afilter
108117
return
109118
# If not found, uncompressed may have been erroneously set
110119
# for a writer that didn't support if (such as .xlsx), so
111120
# fall through to uncompressed
112-
for afilter in cls.filters:
121+
for afilter in cls.get_filters():
113122
if ext in afilter:
114123
settings["filter"] = afilter
115124
return
@@ -128,20 +137,18 @@ def initial_start_dir(self):
128137

129138
def valid_filters(self):
130139
if self.data is None or not self.data.is_sparse():
131-
return self.filters
140+
return self.get_filters()
132141
else:
133-
return {filt: writer for filt, writer in self.filters.items()
142+
return {filt: writer for filt, writer in self.get_filters().items()
134143
if writer.SUPPORT_SPARSE_DATA}
135144

136145
def default_valid_filter(self):
146+
valid = self.valid_filters()
137147
if self.data is None or not self.data.is_sparse() \
138-
or self.filters[self.filter].SUPPORT_SPARSE_DATA:
148+
or (self.filter in valid
149+
and valid[self.filter].SUPPORT_SPARSE_DATA):
139150
return self.filter
140-
for filt, writer in self.filters.items():
141-
if writer.SUPPORT_SPARSE_DATA:
142-
return filt
143-
# This shouldn't happen and it will trigger an error in tests
144-
return None # pragma: no cover
151+
return next(iter(valid))
145152

146153

147154
if __name__ == "__main__": # pragma: no cover

Orange/widgets/data/tests/test_owsave.py

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from AnyQt.QtWidgets import QFileDialog
99

1010
from Orange.data import Table
11-
from Orange.data.io import TabReader, PickleReader, ExcelReader
11+
from Orange.data.io import TabReader, PickleReader, ExcelReader, FileFormat
1212
from Orange.tests import named_file
1313
from Orange.widgets.data.owsave import OWSave, OWSaveBase
1414
from Orange.widgets.utils.save.tests.test_owsavebase import \
@@ -23,6 +23,15 @@ def _w(s): # pylint: disable=invalid-name
2323
return s.replace("/", os.sep)
2424

2525

26+
class MockFormat(FileFormat):
27+
EXTENSIONS = ('.mock',)
28+
DESCRIPTION = "Mock file format"
29+
30+
@staticmethod
31+
def write_file(filename, data):
32+
pass
33+
34+
2635
class OWSaveTestBase(WidgetTest, SaveWidgetsTestBaseMixin):
2736
def setUp(self):
2837
with open_widget_classes():
@@ -91,7 +100,7 @@ def test_initial_start_dir(self):
91100
@patch("Orange.widgets.utils.save.owsavebase.QFileDialog.getSaveFileName")
92101
def test_save_file_sets_name(self, _filedialog):
93102
widget = self.widget
94-
filters = iter(widget.filters)
103+
filters = iter(widget.get_filters())
95104
filter1 = next(filters)
96105
filter2 = next(filters)
97106

@@ -259,20 +268,20 @@ def test_valid_filters_for_sparse(self):
259268
widget = self.widget
260269

261270
widget.data = None
262-
self.assertEqual(widget.filters, widget.valid_filters())
271+
self.assertEqual(widget.get_filters(), widget.valid_filters())
263272

264273
widget.data = self.iris
265-
self.assertEqual(widget.filters, widget.valid_filters())
274+
self.assertEqual(widget.get_filters(), widget.valid_filters())
266275

267276
widget.data.X = sp.csr_matrix(widget.data.X)
268277
valid = widget.valid_filters()
269-
self.assertNotEqual(widget.filters, {})
278+
self.assertNotEqual(widget.get_filters(), {})
270279
# false positive, pylint: disable=no-member
271280
self.assertTrue(all(v.SUPPORT_SPARSE_DATA for v in valid.values()))
272281

273282
def test_valid_default_filter(self):
274283
widget = self.widget
275-
for widget.filter, writer in widget.filters.items():
284+
for widget.filter, writer in widget.get_filters().items():
276285
if not writer.SUPPORT_SPARSE_DATA:
277286
break
278287

@@ -284,13 +293,19 @@ def test_valid_default_filter(self):
284293

285294
widget.data.X = sp.csr_matrix(widget.data.X)
286295
self.assertTrue(
287-
widget.filters[widget.default_valid_filter()].SUPPORT_SPARSE_DATA)
296+
widget.get_filters()[widget.default_valid_filter()]
297+
.SUPPORT_SPARSE_DATA)
298+
299+
def test_add_on_writers(self):
300+
# test adding file formats after registering the widget
301+
self.assertIn(MockFormat, self.widget.valid_filters().values())
302+
# this test doesn't call it - test_save_uncompressed does
288303

289304
def test_send_report(self):
290305
widget = self.widget
291306

292307
widget.report_items = Mock()
293-
for writer in widget.filters.values():
308+
for writer in widget.get_filters().values():
294309
widget.writer = writer
295310
for widget.add_type_annotations in (False, True):
296311
widget.filename = f"foo.{writer.EXTENSIONS[0]}"
@@ -355,14 +370,14 @@ def test_migration_to_version_2(self):
355370
settings = {**const_settings,
356371
'compress': True, 'compression': 'lzma (.xz)'}
357372
OWSave.migrate_settings(settings)
358-
self.assertTrue(settings["filter"] in OWSave.filters)
373+
self.assertTrue(settings["filter"] in OWSave.get_filters())
359374

360375
# Unsupported file format (is this possible?)
361376
settings = {**const_settings,
362377
'compress': True, 'compression': 'lzma (.xz)',
363378
'filetype': 'Bar file (.bar)'}
364379
OWSave.migrate_settings(settings)
365-
self.assertTrue(settings["filter"] in OWSave.filters)
380+
self.assertTrue(settings["filter"] in OWSave.get_filters())
366381

367382

368383
class TestFunctionalOWSave(WidgetTest):
@@ -377,7 +392,7 @@ def test_save_uncompressed(self):
377392
spiris = Table("iris")
378393
spiris.X = sp.csr_matrix(spiris.X)
379394

380-
for selected_filter, writer in widget.filters.items():
395+
for selected_filter, writer in widget.get_filters().items():
381396
widget.write = writer
382397
ext = writer.EXTENSIONS[0]
383398
with named_file("", suffix=ext) as filename:
@@ -386,20 +401,22 @@ def test_save_uncompressed(self):
386401

387402
self.send_signal(widget.Inputs.data, self.iris)
388403
widget.save_file_as()
389-
self.assertEqual(len(Table(filename)), 150)
404+
if hasattr(writer, "read"):
405+
self.assertEqual(len(Table(filename)), 150)
390406

391407
if writer.SUPPORT_SPARSE_DATA:
392408
self.send_signal(widget.Inputs.data, spiris)
393409
widget.save_file()
394-
self.assertEqual(len(Table(filename)), 150)
410+
if hasattr(writer, "read"):
411+
self.assertEqual(len(Table(filename)), 150)
395412

396413

397414
@unittest.skipUnless(sys.platform == "linux", "Tests for dialog on Linux")
398415
class TestOWSaveLinuxDialog(OWSaveTestBase):
399416
def test_get_save_filename_linux(self):
400417
widget = self.widget
401418
widget.initial_start_dir = lambda: "baz"
402-
widget.filters = dict.fromkeys("abc")
419+
widget.get_filters = lambda: dict.fromkeys("abc")
403420
widget.filter = "b"
404421
dlg = widget.SaveFileDialog = Mock() # pylint: disable=invalid-name
405422
instance = dlg.return_value
@@ -413,7 +430,7 @@ def test_get_save_filename_linux(self):
413430
instance.exec.return_value = QFileDialog.Rejected
414431
self.assertEqual(widget.get_save_filename(), ("", ""))
415432

416-
@patch.object(OWSaveBase, "filters", OWSave.filters)
433+
@patch.object(OWSaveBase, "filters", OWSave.get_filters())
417434
def test_save_file_dialog_enforces_extension_linux(self):
418435
dialog = OWSave.SaveFileDialog(
419436
OWSave, None, "Save File", "foo.bar",
@@ -468,7 +485,8 @@ def remove_star(filt):
468485
def test_get_save_filename_darwin(self, dlg):
469486
widget = self.widget
470487
widget.initial_start_dir = lambda: "baz"
471-
widget.filters = dict.fromkeys(("aa (*.a)", "bb (*.b)", "cc (*.c)"))
488+
widget.get_filters = \
489+
lambda: dict.fromkeys(("aa (*.a)", "bb (*.b)", "cc (*.c)"))
472490
widget.filter = "bb (*.b)"
473491
instance = dlg.return_value
474492
instance.exec.return_value = dlg.Accepted = QFileDialog.Accepted
@@ -489,10 +507,10 @@ def test_get_save_filename_darwin(self, dlg):
489507
def test_save_file_dialog_enforces_extension_darwin(self, dlg):
490508
widget = self.widget
491509
filter1 = "" # prevent pylint warning 'undefined-loop-variable'
492-
for filter1 in widget.filters:
510+
for filter1 in widget.get_filters():
493511
if OWSaveBase._extension_from_filter(filter1) == ".tab":
494512
break
495-
for filter2 in widget.filters:
513+
for filter2 in widget.get_filters():
496514
if OWSaveBase._extension_from_filter(filter2) == ".csv.gz":
497515
break
498516

@@ -536,7 +554,7 @@ def selected_files():
536554
widget = self.widget
537555
widget.initial_start_dir = lambda: "baz"
538556
filter1 = "" # prevent pylint warning 'undefined-loop-variable'
539-
for filter1 in widget.filters:
557+
for filter1 in widget.get_filters():
540558
if OWSaveBase._extension_from_filter(filter1) == ".tab":
541559
break
542560

Orange/widgets/utils/save/owsavebase.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@ class OWSaveBase(widget.OWWidget, openclass=True):
2323
- calls `self.on_new_input`.
2424
2525
- a class attribute `filters` with a list of filters or a dictionary whose
26-
keys are filters
27-
- method `do_save` that saves `self.data` into `self.filename`.
26+
keys are filters OR a class method `get_filters` that returns such a
27+
list or dictionary
28+
- method `do_save` that saves `self.data` into `self.filename`
2829
2930
Alternatively, instead of defining `do_save` a derived class can make
3031
`filters` a dictionary whose keys are classes that define a method `write`
@@ -70,7 +71,7 @@ def __init__(self, start_row=0):
7071
self.data = None
7172
# This cannot be done outside because `filters` is defined by subclass
7273
if not self.filter:
73-
self.filter = next(iter(self.filters))
74+
self.filter = next(iter(self.get_filters()))
7475

7576
self.grid = grid = QGridLayout()
7677
gui.widgetBox(self.controlArea, orientation=grid)
@@ -89,6 +90,10 @@ def __init__(self, start_row=0):
8990
self.adjustSize()
9091
self.update_messages()
9192

93+
@classmethod
94+
def get_filters(cls):
95+
return cls.filters
96+
9297
@property
9398
def writer(self):
9499
"""
@@ -98,7 +103,7 @@ def writer(self):
98103
corresponding to the filter. Derived classes (e.g. OWSave) may also use
99104
it elsewhere.
100105
"""
101-
return self.filters[self.filter]
106+
return self.get_filters()[self.filter]
102107

103108
def on_new_input(self):
104109
"""
@@ -160,8 +165,8 @@ def do_save(self):
160165
Do the saving.
161166
162167
Default implementation calls the write method of the writer
163-
corresponding to the current filter. This requires that class attribute
164-
filters is a dictionary whose keys are classes.
168+
corresponding to the current filter. This requires that get_filters()
169+
returns is a dictionary whose keys are classes.
165170
166171
Derived classes may simplify this by providing a list of filters and
167172
override do_save. This is particularly handy if the widget supports only
@@ -217,7 +222,7 @@ def _replace_extension(cls, filename, extension):
217222
function removes anything that can appear anywhere.
218223
"""
219224
known_extensions = set()
220-
for filt in cls.filters:
225+
for filt in cls.get_filters():
221226
known_extensions |= set(cls._extension_from_filter(filt).split("."))
222227
if "" in known_extensions:
223228
known_extensions.remove("")
@@ -233,7 +238,7 @@ def _extension_from_filter(selected_filter):
233238
return re.search(r".*\(\*?(\..*)\)$", selected_filter).group(1)
234239

235240
def valid_filters(self):
236-
return self.filters
241+
return self.get_filters()
237242

238243
def default_valid_filter(self):
239244
return self.filter

Orange/widgets/utils/save/tests/test_owsavebase.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,11 @@ def test_input_handler(self):
3636
widget.on_new_input.assert_called()
3737

3838
def test_filters(self):
39-
self.assertGreaterEqual(len(self.widget.filters), 1,
39+
filters = self.widget.get_filters()
40+
self.assertGreaterEqual(len(filters), 1,
4041
msg="Widget defines no filters")
4142
if type(self.widget).do_save is OWSaveBase.do_save:
42-
self.assertIsInstance(self.widget.filters, collections.abc.Mapping)
43+
self.assertIsInstance(filters, collections.abc.Mapping)
4344

4445

4546
class TestOWSaveBaseWithWriters(WidgetTest):
@@ -92,7 +93,7 @@ def test_base_methods(self):
9293
self.assertEqual(widget.initial_start_dir(),
9394
os.path.expanduser(f"~{os.sep}"))
9495
self.assertEqual(widget.suggested_name(), "")
95-
self.assertIs(widget.valid_filters(), widget.filters)
96+
self.assertIs(widget.valid_filters(), widget.get_filters())
9697
self.assertIs(widget.default_valid_filter(), widget.filter)
9798

9899

@@ -133,7 +134,7 @@ def test_base_methods(self):
133134
self.assertEqual(widget.initial_start_dir(),
134135
os.path.expanduser(f"~{os.sep}"))
135136
self.assertEqual(widget.suggested_name(), "")
136-
self.assertIs(widget.valid_filters(), widget.filters)
137+
self.assertIs(widget.valid_filters(), widget.get_filters())
137138
self.assertIs(widget.default_valid_filter(), widget.filter)
138139

139140

0 commit comments

Comments
 (0)