Skip to content

Commit e8027f0

Browse files
all cartesians now support wide mode
1 parent 7b022f1 commit e8027f0

File tree

3 files changed

+49
-24
lines changed

3 files changed

+49
-24
lines changed

packages/python/plotly/plotly/express/_chart_types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def density_contour(
7474
animation_group=None,
7575
category_orders={},
7676
labels={},
77+
orientation=None,
7778
color_discrete_sequence=None,
7879
color_discrete_map={},
7980
marginal_x=None,
@@ -130,6 +131,7 @@ def density_heatmap(
130131
animation_group=None,
131132
category_orders={},
132133
labels={},
134+
orientation=None,
133135
color_continuous_scale=None,
134136
range_color=None,
135137
color_continuous_midpoint=None,

packages/python/plotly/plotly/express/_core.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@
3434
direct_attrables + array_attrables + group_attrables + renameable_group_attrables
3535
)
3636

37+
cartesians = [go.Scatter, go.Scattergl, go.Bar, go.Funnel, go.Box, go.Violin]
38+
cartesians += [go.Histogram, go.Histogram2d, go.Histogram2dContour]
39+
3740

3841
class PxDefaults(object):
3942
__slots__ = [
@@ -435,8 +438,6 @@ def configure_axes(args, constructor, fig, orders):
435438
go.Scattergeo: configure_geo,
436439
go.Choropleth: configure_geo,
437440
}
438-
cartesians = [go.Scatter, go.Scattergl, go.Bar, go.Funnel, go.Box, go.Violin]
439-
cartesians += [go.Histogram, go.Histogram2d, go.Histogram2dContour]
440441
for c in cartesians:
441442
configurators[c] = configure_cartesian_axes
442443
if constructor in configurators:
@@ -1130,7 +1131,8 @@ def build_dataframe(args, constructor):
11301131

11311132
wide_mode = False
11321133
var_name = None
1133-
if constructor in [go.Scatter, go.Bar, go.Violin, go.Box, go.Histogram, go.Funnel]:
1134+
hist2d_types = [go.Histogram2d, go.Histogram2dContour]
1135+
if constructor in cartesians:
11341136
wide_cross_name = None
11351137
if wide_x and wide_y:
11361138
raise ValueError(
@@ -1160,7 +1162,7 @@ def build_dataframe(args, constructor):
11601162
wide_cross_name = "__x__" if wide_y else "__y__"
11611163

11621164
missing_bar_dim = None
1163-
if constructor in [go.Scatter, go.Bar, go.Funnel]:
1165+
if constructor in [go.Scatter, go.Bar, go.Funnel] + hist2d_types:
11641166
if not wide_mode and (no_x != no_y):
11651167
for ax in ["x", "y"]:
11661168
if args.get(ax, None) is None:
@@ -1203,6 +1205,9 @@ def build_dataframe(args, constructor):
12031205
if args["orientation"] is None:
12041206
args["orientation"] = "v" if missing_bar_dim == "x" else "h"
12051207

1208+
if constructor in hist2d_types:
1209+
del args["orientation"]
1210+
12061211
if wide_mode:
12071212
# at this point, `df_output` is semi-long/semi-wide, but we know which columns
12081213
# are which, so we melt it and reassign `args` to refer to the newly-tidy
@@ -1223,10 +1228,11 @@ def build_dataframe(args, constructor):
12231228
if wide_cross_name == "__y__":
12241229
wide_cross_name = args["y"]
12251230

1226-
if constructor in [go.Scatter, go.Funnel]:
1231+
if constructor in [go.Scatter, go.Funnel] + hist2d_types:
12271232
args["x" if orient_v else "y"] = wide_cross_name
12281233
args["y" if orient_v else "x"] = "_value_"
1229-
args["color"] = args["color"] or var_name
1234+
if constructor != go.Histogram2d:
1235+
args["color"] = args["color"] or var_name
12301236
if constructor == go.Bar:
12311237
if _is_continuous(df_output, "_value_"):
12321238
args["x" if orient_v else "y"] = wide_cross_name

packages/python/plotly/plotly/tests/test_core/test_px/test_px_wide.py

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def test_is_col_list():
4848
@pytest.mark.parametrize(
4949
"px_fn",
5050
[px.scatter, px.line, px.area, px.bar, px.violin, px.box, px.strip]
51-
+ [px.histogram, px.funnel],
51+
+ [px.histogram, px.funnel, px.density_contour, px.density_heatmap],
5252
)
5353
@pytest.mark.parametrize("orientation", [None, "v", "h"])
5454
@pytest.mark.parametrize("style", ["implicit", "explicit"])
@@ -67,7 +67,7 @@ def test_wide_mode_external(px_fn, orientation, style):
6767
if style == "implicit":
6868
fig = px_fn(df, orientation=orientation)
6969

70-
if px_fn in [px.scatter, px.line, px.area, px.bar, px.funnel]:
70+
if px_fn in [px.scatter, px.line, px.area, px.bar, px.funnel, px.density_contour]:
7171
if style == "explicit":
7272
fig = px_fn(**{"data_frame": df, y: list(df.columns), x: df.index})
7373
assert len(fig.data) == 3
@@ -78,6 +78,14 @@ def test_wide_mode_external(px_fn, orientation, style):
7878
assert fig.layout[xaxis].title.text == "index"
7979
assert fig.layout[yaxis].title.text == "_value_"
8080
assert fig.layout.legend.title.text == "_column_"
81+
if px_fn in [px.density_heatmap]:
82+
if style == "explicit":
83+
fig = px_fn(**{"data_frame": df, y: list(df.columns), x: df.index})
84+
assert len(fig.data) == 1
85+
assert list(fig.data[0][x]) == [11, 12, 13, 11, 12, 13, 11, 12, 13]
86+
assert list(fig.data[0][y]) == [1, 2, 3, 4, 5, 6, 7, 8, 9]
87+
assert fig.layout[xaxis].title.text == "index"
88+
assert fig.layout[yaxis].title.text == "_value_"
8189
if px_fn in [px.violin, px.box, px.strip]:
8290
if style == "explicit":
8391
fig = px_fn(**{"data_frame": df, y: list(df.columns)})
@@ -125,7 +133,10 @@ def test_wide_mode_labels_external():
125133
"trace_type,x,y,color",
126134
[
127135
(go.Scatter, "index", "_value_", "_column_"),
136+
(go.Histogram2dContour, "index", "_value_", "_column_"),
137+
(go.Histogram2d, "index", "_value_", None),
128138
(go.Bar, "index", "_value_", "_column_"),
139+
(go.Funnel, "index", "_value_", "_column_"),
129140
(go.Box, "_column_", "_value_", None),
130141
(go.Violin, "_column_", "_value_", None),
131142
(go.Histogram, "_value_", None, "_column_"),
@@ -145,40 +156,43 @@ def test_wide_mode_internal(trace_type, x, y, color, orientation):
145156
assert_frame_equal(
146157
df_out.sort_index(axis=1), pd.DataFrame(expected).sort_index(axis=1),
147158
)
148-
if orientation is None or orientation == "v":
149-
assert args_out == dict(x=x, y=y, color=color, orientation="v")
159+
if trace_type in [go.Histogram2dContour, go.Histogram2d]:
160+
if orientation is None or orientation == "v":
161+
assert args_out == dict(x=x, y=y, color=color)
162+
else:
163+
assert args_out == dict(x=y, y=x, color=color)
150164
else:
151-
assert args_out == dict(x=y, y=x, color=color, orientation="h")
165+
if (orientation is None and trace_type != go.Funnel) or orientation == "v":
166+
assert args_out == dict(x=x, y=y, color=color, orientation="v")
167+
else:
168+
assert args_out == dict(x=y, y=x, color=color, orientation="h")
152169

153170

154171
cases = []
155172
for transpose in [True, False]:
156-
for tt in [go.Scatter, go.Bar, go.Funnel]:
173+
for tt in [go.Scatter, go.Bar, go.Funnel, go.Histogram2dContour, go.Histogram2d]:
174+
color = None if tt == go.Histogram2d else "_column_"
157175
df_in = dict(a=[1, 2], b=[3, 4])
158176
args = dict(x=None, y=["a", "b"], color=None, orientation=None)
159177
df_exp = dict(
160178
_column_=["a", "a", "b", "b"], _value_=[1, 2, 3, 4], index=[0, 1, 0, 1],
161179
)
162-
cases.append(
163-
(tt, df_in, args, "index", "_value_", "_column_", df_exp, transpose)
164-
)
180+
cases.append((tt, df_in, args, "index", "_value_", color, df_exp, transpose))
165181

166182
df_in = dict(a=[1, 2], b=[3, 4], c=[5, 6])
167183
args = dict(x="c", y=["a", "b"], color=None, orientation=None)
168184
df_exp = dict(
169185
_column_=["a", "a", "b", "b"], _value_=[1, 2, 3, 4], c=[5, 6, 5, 6],
170186
)
171-
cases.append((tt, df_in, args, "c", "_value_", "_column_", df_exp, transpose))
187+
cases.append((tt, df_in, args, "c", "_value_", color, df_exp, transpose))
172188

173189
args = dict(x=None, y=[[1, 2], [3, 4]], color=None, orientation=None)
174190
df_exp = dict(
175191
_column_=["_column__0", "_column__0", "_column__1", "_column__1"],
176192
_value_=[1, 2, 3, 4],
177193
index=[0, 1, 0, 1],
178194
)
179-
cases.append(
180-
(tt, None, args, "index", "_value_", "_column_", df_exp, transpose)
181-
)
195+
cases.append((tt, None, args, "index", "_value_", color, df_exp, transpose))
182196

183197
for tt in [go.Bar]: # bar categorical exception
184198
df_in = dict(a=["q", "r"], b=["s", "t"])
@@ -242,13 +256,16 @@ def test_wide_x_or_y(tt, df_in, args_in, x, y, color, df_out_exp, transpose):
242256
args_out = build_dataframe(args_in, tt)
243257
df_out = args_out.pop("data_frame").sort_index(axis=1)
244258
assert_frame_equal(df_out, pd.DataFrame(df_out_exp).sort_index(axis=1))
245-
orientation_exp = args_in["orientation"]
246-
if (args_in["x"] is None) != (args_in["y"] is None) and tt != go.Histogram:
247-
orientation_exp = "h" if transpose else "v"
248259
if transpose:
249-
assert args_out == dict(x=y, y=x, color=color, orientation=orientation_exp)
260+
args_exp = dict(x=y, y=x, color=color)
250261
else:
251-
assert args_out == dict(x=x, y=y, color=color, orientation=orientation_exp)
262+
args_exp = dict(x=x, y=y, color=color)
263+
if tt not in [go.Histogram2dContour, go.Histogram2d]:
264+
orientation_exp = args_in["orientation"]
265+
if (args_in["x"] is None) != (args_in["y"] is None) and tt != go.Histogram:
266+
orientation_exp = "h" if transpose else "v"
267+
args_exp["orientation"] = orientation_exp
268+
assert args_out == args_exp
252269

253270

254271
@pytest.mark.parametrize("orientation", [None, "v", "h"])

0 commit comments

Comments
 (0)