Skip to content

Commit ae6221d

Browse files
committed
color scales for treemap and sunburst
1 parent 2e2444f commit ae6221d

File tree

2 files changed

+75
-9
lines changed

2 files changed

+75
-9
lines changed

packages/python/plotly/plotly/express/_chart_types.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1122,9 +1122,8 @@ def pie(
11221122
names=None,
11231123
values=None,
11241124
color=None,
1125-
color_continuous_scale=None,
1126-
range_color=None,
1127-
color_continuous_midpoint=None,
1125+
color_discrete_sequence=None,
1126+
color_discrete_map={},
11281127
textinfo=None,
11291128
hover_name=None,
11301129
hover_data=None,
@@ -1168,6 +1167,8 @@ def sunburst(
11681167
color_continuous_scale=None,
11691168
range_color=None,
11701169
color_continuous_midpoint=None,
1170+
color_discrete_sequence=None,
1171+
color_discrete_map={},
11711172
hover_name=None,
11721173
hover_data=None,
11731174
custom_data=None,
@@ -1203,6 +1204,8 @@ def treemap(
12031204
color_continuous_scale=None,
12041205
range_color=None,
12051206
color_continuous_midpoint=None,
1207+
color_discrete_sequence=None,
1208+
color_discrete_map={},
12061209
hover_name=None,
12071210
hover_data=None,
12081211
custom_data=None,

packages/python/plotly/plotly/express/_core.py

Lines changed: 69 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,19 @@ def get_trendline_results(fig):
5959
return fig._px_trendlines
6060

6161

62+
def make_color_mapping(cat_list, discrete_colorscale):
63+
mapping = {}
64+
colors = []
65+
taken = 0
66+
length = len(discrete_colorscale)
67+
for cat in cat_list:
68+
if mapping.get(cat) is None:
69+
mapping[cat] = discrete_colorscale[taken % length]
70+
taken += 1
71+
colors.append(mapping[cat])
72+
return colors
73+
74+
6275
Mapping = namedtuple(
6376
"Mapping",
6477
[
@@ -295,9 +308,31 @@ def make_trace_kwargs(args, trace_spec, g, mapping_labels, sizeref):
295308
colorable = "marker"
296309
if colorable not in result:
297310
result[colorable] = dict()
298-
result[colorable]["colors"] = g[v]
299-
result[colorable]["coloraxis"] = "coloraxis1"
300-
mapping_labels[v_label] = "%{color}"
311+
print("ok")
312+
if args.get("color_is_continuous"):
313+
print(
314+
"continuous scale", args["color_continuous_scale"],
315+
)
316+
result[colorable]["colors"] = g[v]
317+
result[colorable]["colorscale"] = args["color_continuous_scale"]
318+
# result[colorable]["coloraxis"] = "coloraxis1"
319+
mapping_labels[v_label] = "%{color}"
320+
else:
321+
print(
322+
"discrete",
323+
args["color_discrete_sequence"],
324+
args.get("color_is_continuous"),
325+
)
326+
result[colorable]["colors"] = make_color_mapping(
327+
g[v], args["color_discrete_sequence"]
328+
)
329+
elif trace_spec.constructor == go.Pie:
330+
colorable = "marker"
331+
if colorable not in result:
332+
result[colorable] = dict()
333+
result[colorable]["colors"] = make_color_mapping(
334+
g[v], args["color_discrete_sequence"]
335+
)
301336
else:
302337
colorable = "marker"
303338
if trace_spec.constructor in [go.Parcats, go.Parcoords]:
@@ -708,6 +743,16 @@ def one_group(x):
708743

709744
def apply_default_cascade(args):
710745
# first we apply px.defaults to unspecified args
746+
# If a discrete or a continuous colorscale is given then we do not set the other type
747+
# This is used for Sunburst and Treemap which accept the two
748+
# if ("color_discrete_sequence" in args and "color_continuous_scale" in args):
749+
# if args["color_discrete_sequence"] is None and args["color_continuous_scale"] is None:
750+
# for param in ["color_discrete_sequence", "color_continuous_scale"]:
751+
# args[param] = getattr(defaults, param)
752+
# else:
753+
# if param in args and args[param] is None:
754+
# args[param] = getattr(defaults, param)
755+
711756
for param in (
712757
["color_discrete_sequence", "color_continuous_scale"]
713758
+ ["symbol_sequence", "line_dash_sequence", "template"]
@@ -733,6 +778,9 @@ def apply_default_cascade(args):
733778
# if colors not set explicitly or in px.defaults, defer to a template
734779
# if the template doesn't have one, we set some final fallback defaults
735780
if "color_continuous_scale" in args:
781+
if args["color_continuous_scale"] is not None:
782+
print("True in cascade")
783+
args["color_is_continuous"] = True
736784
if (
737785
args["color_continuous_scale"] is None
738786
and args["template"].layout.colorscale.sequential
@@ -744,6 +792,9 @@ def apply_default_cascade(args):
744792
args["color_continuous_scale"] = sequential.Viridis
745793

746794
if "color_discrete_sequence" in args:
795+
if args["color_discrete_sequence"] is not None:
796+
print("False in cascade")
797+
args["color_is_continuous"] = False
747798
if args["color_discrete_sequence"] is None and args["template"].layout.colorway:
748799
args["color_discrete_sequence"] = args["template"].layout.colorway
749800
if args["color_discrete_sequence"] is None:
@@ -1024,14 +1075,26 @@ def infer_config(args, constructor, trace_patch):
10241075
and args["data_frame"][args["color"]].dtype.kind in "bifc"
10251076
):
10261077
attrs.append("color")
1078+
if not "color_is_continuous" in args:
1079+
print("True in infer 2")
1080+
args["color_is_continuous"] = True
1081+
elif constructor in [go.Sunburst, go.Treemap]:
1082+
attrs.append("color")
10271083
else:
1028-
grouped_attrs.append("marker.color")
1084+
if constructor not in [go.Pie]:
1085+
grouped_attrs.append("marker.color")
10291086
elif "line_group" in args or constructor == go.Histogram2dContour:
10301087
grouped_attrs.append("line.color")
1031-
else:
1088+
elif constructor not in [go.Pie, go.Sunburst, go.Treemap]:
10321089
grouped_attrs.append("marker.color")
1090+
else:
1091+
attrs.append("color")
10331092

1034-
show_colorbar = bool("color" in attrs and args["color"])
1093+
show_colorbar = bool(
1094+
"color" in attrs
1095+
and args["color"]
1096+
and constructor not in [go.Pie, go.Sunburst, go.Treemap]
1097+
)
10351098
else:
10361099
show_colorbar = False
10371100

0 commit comments

Comments
 (0)