Skip to content

Commit f9d1419

Browse files
committed
proper checkColors
1 parent 6708a3b commit f9d1419

File tree

1 file changed

+47
-54
lines changed

1 file changed

+47
-54
lines changed

prody/dynamics/plotting.py

Lines changed: 47 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
and keyword arguments are passed to the Matplotlib functions."""
88

99
from collections import defaultdict
10+
from matplotlib.colors import is_color_like
11+
1012
from numbers import Number
1113
import numpy as np
1214

@@ -218,7 +220,7 @@ def showProjection(ensemble=None, modes=None, projection=None, *args, **kwargs):
218220
Default is to use ensemble.getData('size')
219221
:type weights: int, list, :class:`~numpy.ndarray`
220222
221-
:keyword color: a color name or value or a list of length ensemble.numConfs() of these,
223+
:keyword color: a color name or value or a list of length ensemble.numConfs() or projection.shape[0] of these,
222224
or a dictionary with these with keys corresponding to labels provided by keyword label
223225
default is ``'blue'``
224226
Color values can have 1 element to be mapped with cmap or 3 as RGB or 4 as RGBA.
@@ -295,21 +297,7 @@ def showProjection(ensemble=None, modes=None, projection=None, *args, **kwargs):
295297

296298
c = kwargs.pop('c', 'b')
297299
colors = kwargs.pop('color', c)
298-
colors_dict = {}
299-
if isinstance(colors, np.ndarray):
300-
colors = tuple(colors)
301-
if isinstance(colors, (str, tuple)) or colors is None:
302-
colors = [colors] * num
303-
elif isinstance(colors, list):
304-
if len(colors) != num:
305-
raise ValueError('length of color must be {0}'.format(num))
306-
elif isinstance(colors, dict):
307-
if labels is None:
308-
raise TypeError('color must be a string or a list unless labels are provided')
309-
colors_dict = colors
310-
colors = [colors_dict[label] for label in labels]
311-
else:
312-
raise TypeError('color must be a string or a list or a dict if labels are provided')
300+
colors, colors_dict = checkColors(colors, num, labels, allowNumbers=True)
313301

314302
if labels is not None and len(colors_dict) == 0:
315303
cycle_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
@@ -364,21 +352,6 @@ def showProjection(ensemble=None, modes=None, projection=None, *args, **kwargs):
364352
else:
365353
raise TypeError('marker must be a string or a list')
366354

367-
c = kwargs.pop('c', 'blue')
368-
colors = kwargs.pop('color', c)
369-
370-
if isListLike(colors) and len(colors) == 2:
371-
raise ValueError('each entry of color should have 1, 3 or 4 values not 2')
372-
elif isListLike(colors) and not len(colors) in [3, 4]:
373-
colors = list(colors)
374-
elif isinstance(colors, str) or colors is None or isListLike(colors):
375-
colors = [colors] * num
376-
else:
377-
raise TypeError('color must be string or list-like or None')
378-
379-
if len(colors) != num:
380-
raise ValueError('final length of color must be {0}'.format(num))
381-
382355
color_norm = None
383356
if isinstance(colors[0], Number):
384357
color_norm = matplotlib.colors.Normalize(vmin=min(colors), vmax=max(colors))
@@ -529,7 +502,7 @@ def showCrossProjection(ensemble, mode_x, mode_y, scale=None, *args, **kwargs):
529502
:keyword scalar: scalar factor for projection onto selected mode
530503
:type scalar: float
531504
532-
:keyword color: a color name or a list of color name, default is ``'blue'``
505+
:keyword color: a color spec or a list of color specs, default is ``'blue'``
533506
:type color: str, list
534507
535508
:keyword label: label or a list of labels
@@ -578,13 +551,6 @@ def showCrossProjection(ensemble, mode_x, mode_y, scale=None, *args, **kwargs):
578551
raise TypeError('marker must be a string or a list')
579552

580553
colors = kwargs.pop('color', 'blue')
581-
if isinstance(colors, str) or colors is None:
582-
colors = [colors] * num
583-
elif isinstance(colors, list):
584-
if len(colors) != num:
585-
raise ValueError('length of color must be {0}'.format(num))
586-
else:
587-
raise TypeError('color must be a string or a list')
588554

589555
labels = kwargs.pop('label', None)
590556
if isinstance(labels, str) or labels is None:
@@ -597,21 +563,7 @@ def showCrossProjection(ensemble, mode_x, mode_y, scale=None, *args, **kwargs):
597563

598564
kwargs['ls'] = kwargs.pop('linestyle', None) or kwargs.pop('ls', 'None')
599565

600-
colors_dict = {}
601-
if isinstance(colors, np.ndarray):
602-
colors = tuple(colors)
603-
if isinstance(colors, (str, tuple)) or colors is None:
604-
colors = [colors] * num
605-
elif isinstance(colors, list):
606-
if len(colors) != num:
607-
raise ValueError('length of color must be {0}'.format(num))
608-
elif isinstance(colors, dict):
609-
if labels is None:
610-
raise TypeError('color must be a string or a list unless labels are provided')
611-
colors_dict = colors
612-
colors = [colors_dict[label] for label in labels]
613-
else:
614-
raise TypeError('color must be a string or a list or a dict if labels are provided')
566+
colors, colors_dict = checkColors(colors, num, labels)
615567

616568
if labels is not None and len(colors_dict) == 0:
617569
cycle_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
@@ -2403,3 +2355,44 @@ def showTree_networkx(tree, node_size=20, node_color='red', node_shape='o',
24032355
showFigure()
24042356

24052357
return mpl.gca()
2358+
2359+
2360+
def checkColors(colors, num, labels, allowNumbers=False):
2361+
"""Check colors and process them if needed"""
2362+
2363+
colors_dict = {}
2364+
2365+
if isinstance(colors, np.ndarray):
2366+
colors = tuple(colors)
2367+
2368+
if is_color_like(colors) or colors is None:
2369+
colors = [colors] * num
2370+
elif isListLike(colors):
2371+
colors = list(colors)
2372+
2373+
if isinstance(colors, list):
2374+
if len(colors) != num and not is_color_like(colors):
2375+
raise ValueError('colors should have the length of the set to be colored or satisfy matplotlib color rules')
2376+
2377+
if np.any([not is_color_like(color) for color in colors]):
2378+
if not allowNumbers:
2379+
raise ValueError('each element of colors should satisfy matplotlib color rules')
2380+
elif np.any([not isinstance(color, Number) for color in colors]):
2381+
raise ValueError('each element of colors should be a number or satisfy matplotlib color rules')
2382+
2383+
elif isinstance(colors, dict):
2384+
if labels is None:
2385+
raise TypeError('color must be a string or a list unless labels are provided')
2386+
colors_dict = colors
2387+
colors = [colors_dict[label] for label in labels]
2388+
2389+
if np.any([not is_color_like(color) for color in colors]):
2390+
if not allowNumbers:
2391+
raise ValueError('each element of colors should satisfy matplotlib color rules')
2392+
elif np.any([not isinstance(color, Number) for color in colors]):
2393+
raise ValueError('each element of colors should be a number or satisfy matplotlib color rules')
2394+
2395+
elif not (isinstance(colors, Number) or is_color_like(colors)):
2396+
raise TypeError('color must be a number, string, list, matplotlib color spec, or a dict if labels are provided')
2397+
2398+
return colors, colors_dict

0 commit comments

Comments
 (0)