Skip to content

Commit 690a402

Browse files
authored
Merge pull request #4977 from janezd/owcolor-save-schemata
[ENH] OWColor: Saving and loading color schemata
2 parents de7c4c4 + e018fc6 commit 690a402

File tree

2 files changed

+562
-14
lines changed

2 files changed

+562
-14
lines changed

Orange/widgets/data/owcolor.py

Lines changed: 216 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,36 @@
1+
import os
12
from itertools import chain
3+
import json
24

35
import numpy as np
46

5-
from AnyQt.QtCore import Qt, QSize, QAbstractTableModel, QModelIndex, QTimer
7+
from AnyQt.QtCore import Qt, QSize, QAbstractTableModel, QModelIndex, QTimer, \
8+
QSettings
69
from AnyQt.QtGui import QColor, QFont, QBrush
7-
from AnyQt.QtWidgets import QHeaderView, QColorDialog, QTableView, QComboBox
10+
from AnyQt.QtWidgets import QHeaderView, QColorDialog, QTableView, QComboBox, \
11+
QFileDialog, QMessageBox
12+
13+
from orangewidget.settings import IncompatibleContext
814

915
import Orange
1016
from Orange.preprocess.transformation import Identity
11-
from Orange.util import color_to_hex
17+
from Orange.util import color_to_hex, hex_to_color
1218
from Orange.widgets import widget, settings, gui
1319
from Orange.widgets.gui import HorizontalGridDelegate
1420
from Orange.widgets.utils import itemmodels, colorpalettes
1521
from Orange.widgets.utils.widgetpreview import WidgetPreview
1622
from Orange.widgets.utils.state_summary import format_summary_details
23+
from Orange.widgets.report import colored_square as square
1724
from Orange.widgets.widget import Input, Output
18-
from orangewidget.settings import IncompatibleContext
1925

2026
ColorRole = next(gui.OrangeUserRole)
2127
StripRole = next(gui.OrangeUserRole)
2228

2329

30+
class InvalidFileFormat(Exception):
31+
pass
32+
33+
2434
class AttrDesc:
2535
"""
2636
Describes modifications that will be applied to variable.
@@ -46,6 +56,24 @@ def name(self):
4656
def name(self, name):
4757
self.new_name = name
4858

59+
def to_dict(self):
60+
d = {}
61+
if self.new_name is not None:
62+
d["rename"] = self.new_name
63+
return d
64+
65+
@classmethod
66+
def from_dict(cls, var, data):
67+
desc = cls(var)
68+
if not isinstance(data, dict):
69+
raise InvalidFileFormat
70+
new_name = data.get("rename")
71+
if new_name is not None:
72+
if not isinstance(new_name, str):
73+
raise InvalidFileFormat
74+
desc.name = new_name
75+
return desc, []
76+
4977

5078
class DiscAttrDesc(AttrDesc):
5179
"""
@@ -96,6 +124,57 @@ def create_variable(self):
96124
new_var.colors = np.asarray(self.colors)
97125
return new_var
98126

127+
def to_dict(self):
128+
d = super().to_dict()
129+
if self.new_values is not None:
130+
d["renamed_values"] = \
131+
{k: v
132+
for k, v in zip(self.var.values, self.new_values)
133+
if k != v}
134+
if self.new_colors is not None:
135+
d["colors"] = {
136+
value: color_to_hex(color)
137+
for value, color in zip(self.var.values, self.colors)}
138+
return d
139+
140+
@classmethod
141+
def from_dict(cls, var, data):
142+
143+
def _check_dict_str_str(d):
144+
if not isinstance(d, dict) or \
145+
not all(isinstance(val, str)
146+
for val in chain(d, d.values())):
147+
raise InvalidFileFormat
148+
149+
obj, warnings = super().from_dict(var, data)
150+
151+
val_map = data.get("renamed_values")
152+
if val_map is not None:
153+
_check_dict_str_str(val_map)
154+
mapped_values = [val_map.get(value, value) for value in var.values]
155+
if len(set(mapped_values)) != len(mapped_values):
156+
warnings.append(
157+
f"{var.name}: "
158+
"renaming of values ignored due to duplicate names")
159+
else:
160+
obj.new_values = mapped_values
161+
162+
new_colors = data.get("colors")
163+
if new_colors is not None:
164+
_check_dict_str_str(new_colors)
165+
colors = []
166+
for value, def_color in zip(var.values, var.palette.palette):
167+
if value in new_colors:
168+
try:
169+
color = hex_to_color(new_colors[value])
170+
except ValueError as exc:
171+
raise InvalidFileFormat from exc
172+
colors.append(color)
173+
else:
174+
colors.append(def_color)
175+
obj.new_colors = colors
176+
return obj, warnings
177+
99178

100179
class ContAttrDesc(AttrDesc):
101180
"""
@@ -136,6 +215,22 @@ def create_variable(self):
136215
new_var.attributes["palette"] = self.palette_name
137216
return new_var
138217

218+
def to_dict(self):
219+
d = super().to_dict()
220+
if self.new_palette_name is not None:
221+
d["colors"] = self.palette_name
222+
return d
223+
224+
@classmethod
225+
def from_dict(cls, var, data):
226+
obj, warnings = super().from_dict(var, data)
227+
colors = data.get("colors")
228+
if colors is not None:
229+
if colors not in colorpalettes.ContinuousPalettes:
230+
raise InvalidFileFormat
231+
obj.palette_name = colors
232+
return obj, warnings
233+
139234

140235
class ColorTableModel(QAbstractTableModel):
141236
"""
@@ -312,7 +407,7 @@ def __init__(self, view):
312407
super().__init__()
313408
self.view = view
314409

315-
def createEditor(self, parent, option, index):
410+
def createEditor(self, parent, _, index):
316411
class Combo(QComboBox):
317412
def __init__(self, parent, initial_data, view):
318413
super().__init__(parent)
@@ -454,7 +549,6 @@ class Outputs:
454549
match_values=settings.PerfectDomainContextHandler.MATCH_VALUES_ALL)
455550
disc_descs = settings.ContextSetting([])
456551
cont_descs = settings.ContextSetting([])
457-
color_settings = settings.Setting(None)
458552
selected_schema_index = settings.Setting(0)
459553
auto_apply = settings.Setting(True)
460554

@@ -481,9 +575,13 @@ def __init__(self):
481575

482576
box = gui.auto_apply(self.controlArea, self, "auto_apply")
483577
box.button.setFixedWidth(180)
578+
save = gui.button(None, self, "Save", callback=self.save)
579+
load = gui.button(None, self, "Load", callback=self.load)
484580
reset = gui.button(None, self, "Reset", callback=self.reset)
485-
box.layout().insertWidget(0, reset)
486-
box.layout().insertStretch(1)
581+
box.layout().insertWidget(0, save)
582+
box.layout().insertWidget(0, load)
583+
box.layout().insertWidget(2, reset)
584+
box.layout().insertStretch(3)
487585

488586
self.info.set_input_summary(self.info.NoInput)
489587
self.info.set_output_summary(self.info.NoOutput)
@@ -524,6 +622,114 @@ def reset(self):
524622
self.cont_model.reset()
525623
self.commit()
526624

625+
def save(self):
626+
fname, _ = QFileDialog.getSaveFileName(
627+
self, "File name", self._start_dir(),
628+
"Variable definitions (*.colors)")
629+
if not fname:
630+
return
631+
QSettings().setValue("colorwidget/last-location",
632+
os.path.split(fname)[0])
633+
self._save_var_defs(fname)
634+
635+
def _save_var_defs(self, fname):
636+
with open(fname, "w") as f:
637+
json.dump(
638+
{vartype: {
639+
var.name: var_data
640+
for var, var_data in (
641+
(desc.var, desc.to_dict()) for desc in repo)
642+
if var_data}
643+
for vartype, repo in (("categorical", self.disc_descs),
644+
("numeric", self.cont_descs))
645+
},
646+
f,
647+
indent=4)
648+
649+
def load(self):
650+
fname, _ = QFileDialog.getOpenFileName(
651+
self, "File name", self._start_dir(),
652+
"Variable definitions (*.colors)")
653+
if not fname:
654+
return
655+
656+
try:
657+
f = open(fname)
658+
except IOError:
659+
QMessageBox.critical(self, "File error", "File cannot be opened.")
660+
return
661+
662+
try:
663+
js = json.load(f) #: dict
664+
self._parse_var_defs(js)
665+
except (json.JSONDecodeError, InvalidFileFormat):
666+
QMessageBox.critical(self, "File error", "Invalid file format.")
667+
668+
def _parse_var_defs(self, js):
669+
if not isinstance(js, dict) or set(js) != {"categorical", "numeric"}:
670+
raise InvalidFileFormat
671+
try:
672+
renames = {
673+
var_name: desc["rename"]
674+
for repo in js.values() for var_name, desc in repo.items()
675+
if "rename" in desc
676+
}
677+
# js is an object coming from json file that can be manipulated by
678+
# the user, so there are too many things that can go wrong.
679+
# Catch all exceptions, therefore.
680+
except Exception as exc:
681+
raise InvalidFileFormat from exc
682+
if not all(isinstance(val, str)
683+
for val in chain(renames, renames.values())):
684+
raise InvalidFileFormat
685+
renamed_vars = {
686+
renames.get(desc.var.name, desc.var.name)
687+
for desc in chain(self.disc_descs, self.cont_descs)
688+
}
689+
if len(renamed_vars) != len(self.disc_descs) + len(self.cont_descs):
690+
QMessageBox.warning(
691+
self,
692+
"Duplicated variable names",
693+
"Variables will not be renamed due to duplicated names.")
694+
for repo in js.values():
695+
for desc in repo.values():
696+
desc.pop("rename", None)
697+
698+
# First, construct all descriptions; assign later, after we know
699+
# there won't be exceptions due to invalid file format
700+
both_descs = []
701+
warnings = []
702+
for old_desc, repo, desc_type in (
703+
(self.disc_descs, "categorical", DiscAttrDesc),
704+
(self.cont_descs, "numeric", ContAttrDesc)):
705+
var_by_name = {desc.var.name: desc.var for desc in old_desc}
706+
new_descs = {}
707+
for var_name, var_data in js[repo].items():
708+
var = var_by_name.get(var_name)
709+
if var is None:
710+
continue
711+
# This can throw InvalidFileFormat
712+
new_descs[var_name], warn = desc_type.from_dict(var, var_data)
713+
warnings += warn
714+
both_descs.append(new_descs)
715+
716+
self.disc_descs = [both_descs[0].get(desc.var.name, desc)
717+
for desc in self.disc_descs]
718+
self.cont_descs = [both_descs[1].get(desc.var.name, desc)
719+
for desc in self.cont_descs]
720+
if warnings:
721+
QMessageBox.warning(
722+
self, "Invalid definitions", "\n".join(warnings))
723+
724+
self.disc_model.set_data(self.disc_descs)
725+
self.cont_model.set_data(self.cont_descs)
726+
self.unconditional_commit()
727+
728+
def _start_dir(self):
729+
return self.workflowEnv().get("basedir") \
730+
or QSettings().value("colorwidget/last-location") \
731+
or os.path.expanduser(f"~{os.sep}")
732+
527733
def commit(self):
528734
def make(variables):
529735
new_vars = []
@@ -552,8 +758,6 @@ def make(variables):
552758
def send_report(self):
553759
"""Send report"""
554760
def _report_variables(variables):
555-
from Orange.widgets.report import colored_square as square
556-
557761
def was(n, o):
558762
return n if n == o else f"{n} (was: {o})"
559763

@@ -597,10 +801,10 @@ def was(n, o):
597801
table = "".join(f"<tr><th>{name}</th></tr>{rows}"
598802
for name, rows in sections if rows)
599803
if table:
600-
self.report_raw(r"<table>{table}</table>")
804+
self.report_raw(f"<table>{table}</table>")
601805

602806
@classmethod
603-
def migrate_context(cls, context, version):
807+
def migrate_context(cls, _, version):
604808
if not version or version < 2:
605809
raise IncompatibleContext
606810

0 commit comments

Comments
 (0)