Skip to content

Commit 7b022f1

Browse files
funnel is wideable
1 parent 9524d94 commit 7b022f1

File tree

2 files changed

+29
-21
lines changed

2 files changed

+29
-21
lines changed

packages/python/plotly/plotly/express/_core.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -424,14 +424,6 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
424424

425425
def configure_axes(args, constructor, fig, orders):
426426
configurators = {
427-
go.Scatter: configure_cartesian_axes,
428-
go.Scattergl: configure_cartesian_axes,
429-
go.Bar: configure_cartesian_axes,
430-
go.Box: configure_cartesian_axes,
431-
go.Violin: configure_cartesian_axes,
432-
go.Histogram: configure_cartesian_axes,
433-
go.Histogram2dContour: configure_cartesian_axes,
434-
go.Histogram2d: configure_cartesian_axes,
435427
go.Scatter3d: configure_3d_axes,
436428
go.Scatterternary: configure_ternary_axes,
437429
go.Scatterpolar: configure_polar_axes,
@@ -443,6 +435,10 @@ def configure_axes(args, constructor, fig, orders):
443435
go.Scattergeo: configure_geo,
444436
go.Choropleth: configure_geo,
445437
}
438+
cartesians = [go.Scatter, go.Scattergl, go.Bar, go.Funnel, go.Box, go.Violin]
439+
cartesians += [go.Histogram, go.Histogram2d, go.Histogram2dContour]
440+
for c in cartesians:
441+
configurators[c] = configure_cartesian_axes
446442
if constructor in configurators:
447443
configurators[constructor](args, fig, orders)
448444

@@ -1134,7 +1130,7 @@ def build_dataframe(args, constructor):
11341130

11351131
wide_mode = False
11361132
var_name = None
1137-
if constructor in [go.Scatter, go.Bar, go.Violin, go.Box, go.Histogram]:
1133+
if constructor in [go.Scatter, go.Bar, go.Violin, go.Box, go.Histogram, go.Funnel]:
11381134
wide_cross_name = None
11391135
if wide_x and wide_y:
11401136
raise ValueError(
@@ -1144,7 +1140,10 @@ def build_dataframe(args, constructor):
11441140
wide_mode = True
11451141
args["_column_"] = list(df_input.columns)
11461142
var_name = df_input.columns.name or "_column_"
1147-
wide_orientation = args.get("orientation", None) or "v"
1143+
if constructor == go.Funnel:
1144+
wide_orientation = args.get("orientation", None) or "h"
1145+
else:
1146+
wide_orientation = args.get("orientation", None) or "v"
11481147
args["orientation"] = wide_orientation
11491148
args["wide_cross"] = None
11501149
elif wide_x != wide_y:
@@ -1161,17 +1160,19 @@ def build_dataframe(args, constructor):
11611160
wide_cross_name = "__x__" if wide_y else "__y__"
11621161

11631162
missing_bar_dim = None
1164-
if constructor in [go.Scatter, go.Bar]:
1163+
if constructor in [go.Scatter, go.Bar, go.Funnel]:
11651164
if not wide_mode and (no_x != no_y):
11661165
for ax in ["x", "y"]:
11671166
if args.get(ax, None) is None:
11681167
args[ax] = df_input.index if df_provided else Range()
1169-
if constructor == go.Scatter:
1170-
if args["orientation"] is None:
1171-
args["orientation"] = "v" if ax == "x" else "h"
11721168
if constructor == go.Bar:
11731169
missing_bar_dim = ax
1170+
else:
1171+
if args["orientation"] is None:
1172+
args["orientation"] = "v" if ax == "x" else "h"
11741173
if wide_mode and wide_cross_name is None:
1174+
if no_x != no_y and args["orientation"] is None:
1175+
args["orientation"] = "v" if no_x else "h"
11751176
if df_provided:
11761177
args["wide_cross"] = df_input.index
11771178
wide_cross_name = df_input.index.name or "index"
@@ -1222,7 +1223,7 @@ def build_dataframe(args, constructor):
12221223
if wide_cross_name == "__y__":
12231224
wide_cross_name = args["y"]
12241225

1225-
if constructor == go.Scatter:
1226+
if constructor in [go.Scatter, go.Funnel]:
12261227
args["x" if orient_v else "y"] = wide_cross_name
12271228
args["y" if orient_v else "x"] = "_value_"
12281229
args["color"] = args["color"] or var_name

packages/python/plotly/plotly/tests/test_core/test_px/test_px_wide.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ def test_is_col_list():
4747

4848
@pytest.mark.parametrize(
4949
"px_fn",
50-
[px.scatter, px.line, px.area, px.bar, px.violin, px.box, px.strip, px.histogram],
50+
[px.scatter, px.line, px.area, px.bar, px.violin, px.box, px.strip]
51+
+ [px.histogram, px.funnel],
5152
)
5253
@pytest.mark.parametrize("orientation", [None, "v", "h"])
5354
@pytest.mark.parametrize("style", ["implicit", "explicit"])
@@ -56,14 +57,17 @@ def test_wide_mode_external(px_fn, orientation, style):
5657
# inspecting the figure... this is important but clunky, and is mostly a smoke test
5758
# allowing us to do more "white box" testing below
5859

59-
x, y = ("y", "x") if orientation == "h" else ("x", "y")
60+
if px_fn != px.funnel:
61+
x, y = ("y", "x") if orientation == "h" else ("x", "y")
62+
else:
63+
x, y = ("y", "x") if orientation != "v" else ("x", "y")
6064
xaxis, yaxis = x + "axis", y + "axis"
6165

6266
df = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6], c=[7, 8, 9]), index=[11, 12, 13])
6367
if style == "implicit":
6468
fig = px_fn(df, orientation=orientation)
6569

66-
if px_fn in [px.scatter, px.line, px.area, px.bar]:
70+
if px_fn in [px.scatter, px.line, px.area, px.bar, px.funnel]:
6771
if style == "explicit":
6872
fig = px_fn(**{"data_frame": df, y: list(df.columns), x: df.index})
6973
assert len(fig.data) == 3
@@ -149,7 +153,7 @@ def test_wide_mode_internal(trace_type, x, y, color, orientation):
149153

150154
cases = []
151155
for transpose in [True, False]:
152-
for tt in [go.Scatter, go.Bar]:
156+
for tt in [go.Scatter, go.Bar, go.Funnel]:
153157
df_in = dict(a=[1, 2], b=[3, 4])
154158
args = dict(x=None, y=["a", "b"], color=None, orientation=None)
155159
df_exp = dict(
@@ -238,10 +242,13 @@ def test_wide_x_or_y(tt, df_in, args_in, x, y, color, df_out_exp, transpose):
238242
args_out = build_dataframe(args_in, tt)
239243
df_out = args_out.pop("data_frame").sort_index(axis=1)
240244
assert_frame_equal(df_out, pd.DataFrame(df_out_exp).sort_index(axis=1))
245+
orientation_exp = args_in["orientation"]
246+
if (args_in["x"] is None) != (args_in["y"] is None) and tt != go.Histogram:
247+
orientation_exp = "h" if transpose else "v"
241248
if transpose:
242-
assert args_out == dict(x=y, y=x, color=color, orientation=None)
249+
assert args_out == dict(x=y, y=x, color=color, orientation=orientation_exp)
243250
else:
244-
assert args_out == dict(x=x, y=y, color=color, orientation=None)
251+
assert args_out == dict(x=x, y=y, color=color, orientation=orientation_exp)
245252

246253

247254
@pytest.mark.parametrize("orientation", [None, "v", "h"])

0 commit comments

Comments
 (0)