Skip to content

Commit 977dbcf

Browse files
wip wide_y
1 parent f2a0079 commit 977dbcf

File tree

3 files changed

+145
-75
lines changed

3 files changed

+145
-75
lines changed

packages/python/plotly/plotly/express/_core.py

Lines changed: 84 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
+ ["ids", "error_x", "error_x_minus", "error_y", "error_y_minus", "error_z"]
2424
+ ["error_z_minus", "lat", "lon", "locations", "animation_group"]
2525
)
26-
array_attrables = ["dimensions", "custom_data", "hover_data", "path", "wide_cols"]
26+
array_attrables = ["dimensions", "custom_data", "hover_data", "path", "_column_"]
2727
group_attrables = ["animation_frame", "facet_row", "facet_col", "line_group"]
2828
renameable_group_attrables = [
2929
"color", # renamed to marker.color or line.color in infer_config
@@ -913,6 +913,27 @@ def _get_reserved_col_names(args):
913913
return reserved_names
914914

915915

916+
def _is_col_list(df_input, arg):
    """
    Tell whether `arg` looks like a list of columns or of column references.

    Returns False when `arg` is None, a string, an integer, a `pd.MultiIndex`
    or a non-iterable object. Otherwise every element of `arg` is inspected:

    - a string or integer element must be one of `df_input.columns`, so it is
      rejected when `df_input` is None or does not have such a column;
    - any other element must itself be iterable.

    Returns True only if no element is rejected; an empty iterable therefore
    yields True.
    """

    def _can_iterate(obj):
        # `iter()` raises TypeError for objects that are not iterable
        try:
            iter(obj)
        except TypeError:
            return False
        return True

    # None and scalars are never lists of columns; MultiIndex is excluded
    # just to keep existing behaviour for now
    if arg is None or isinstance(arg, (str, int, pd.MultiIndex)):
        return False
    if not _can_iterate(arg):
        return False

    for item in arg:
        if isinstance(item, (str, int)):
            # scalar elements are only meaningful as references to columns
            if df_input is None or item not in df_input.columns:
                return False
            continue
        if not _can_iterate(item):
            return False
    return True
935+
936+
916937
def build_dataframe(args, constructor):
917938
"""
918939
Constructs a dataframe and modifies `args` in-place.
@@ -946,60 +967,60 @@ def build_dataframe(args, constructor):
946967

947968
no_x = args.get("x", None) is None
948969
no_y = args.get("y", None) is None
949-
wideable = [go.Scatter, go.Bar, go.Violin, go.Box, go.Histogram]
950-
wide_mode = df_provided and no_x and no_y and constructor in wideable
951-
wide_id_vars = set()
970+
wide_x = False if no_x else _is_col_list(df_input, args["x"])
971+
wide_y = False if no_y else _is_col_list(df_input, args["y"])
952972

953-
if wide_mode:
954-
# currently assuming that df_provided == True
955-
args["wide_cols"] = list(df_input.columns)
956-
args["wide_cross"] = df_input.index
957-
var_name = df_input.columns.name or "_column_"
958-
wide_orientation = args.get("orientation", None) or "v"
959-
args["orientation"] = wide_orientation
973+
wide_mode = False
974+
if constructor in [go.Scatter, go.Bar, go.Violin, go.Box, go.Histogram]:
975+
wide_cross_name = None
976+
if wide_x and wide_y:
977+
raise ValueError(
978+
"Cannot accept list of column references or list of columns for both `x` and `y`."
979+
)
980+
if df_provided and no_x and no_y:
981+
wide_mode = True
982+
args["_column_"] = list(df_input.columns)
983+
var_name = df_input.columns.name or "_column_"
984+
wide_orientation = args.get("orientation", None) or "v"
985+
args["orientation"] = wide_orientation
986+
args["wide_cross"] = None
987+
elif wide_x != wide_y:
988+
wide_mode = True
989+
args["_column_"] = args["y"] if wide_y else args["x"]
990+
var_name = "_column_"
991+
if constructor == go.Histogram:
992+
wide_orientation = "v" if wide_x else "h"
993+
else:
994+
wide_orientation = "v" if wide_y else "h"
995+
args["y" if wide_y else "x"] = None
996+
args["wide_cross"] = None
997+
if not no_x and not no_y:
998+
wide_cross_name = "__x__" if wide_y else "__y__"
960999

961-
"""
962-
wide_x detection
963-
- if scalar = False
964-
- else if list of lists = True
965-
- else if not df_provided = False
966-
- else if contents are unique and are contained in columns = True
967-
- else = False
968-
969-
970-
wide detection:
971-
- if no_x and no_y = wide mode
972-
- else if wide_x and wide_y = error
973-
- else if wide_x xor wide_y = wide mode
974-
- else = long mode
975-
976-
so what we want is:
977-
- y = [col col] -> melt just those, wide_orientation = 'v'/no override, cross_dim = index or range
978-
- y = [col col] / x=col -> wide_orientation = 'h'/no override, cross_dim = x
979-
- y = [col col] / x=[col col] -> error
980-
981-
need to merge wide logic into no_x/no_y logic below for range() etc
982-
"""
1000+
missing_bar_dim = None
1001+
if constructor in [go.Scatter, go.Bar]:
1002+
if not wide_mode and (no_x != no_y):
1003+
for ax in ["x", "y"]:
1004+
if args.get(ax, None) is None:
1005+
args[ax] = df_input.index if df_provided else Range()
1006+
if constructor == go.Scatter:
1007+
if args["orientation"] is None:
1008+
args["orientation"] = "v" if ax == "x" else "h"
1009+
if constructor == go.Bar:
1010+
missing_bar_dim = ax
1011+
if wide_mode and wide_cross_name is None:
1012+
if df_provided:
1013+
args["wide_cross"] = df_input.index
1014+
wide_cross_name = df_input.index.name or "index"
1015+
else:
1016+
args["wide_cross"] = Range(label="index")
1017+
wide_cross_name = "index"
9831018

9841019
df_output = pd.DataFrame()
985-
986-
missing_bar_dim = None
987-
if constructor in [go.Scatter, go.Bar] and (no_x != no_y):
988-
for ax in ["x", "y"]:
989-
if args.get(ax, None) is None:
990-
args[ax] = df_input.index if df_provided else Range()
991-
if constructor == go.Scatter:
992-
if args["orientation"] is None:
993-
args["orientation"] = "v" if ax == "x" else "h"
994-
if constructor == go.Bar:
995-
missing_bar_dim = ax
996-
997-
# Initialize set of column names
998-
# These are reserved names
999-
if df_provided:
1000-
reserved_names = _get_reserved_col_names(args)
1001-
else:
1002-
reserved_names = set()
1020+
constants = dict()
1021+
ranges = list()
1022+
wide_id_vars = set()
1023+
reserved_names = _get_reserved_col_names(args) if df_provided else set()
10031024

10041025
# Case of functions with a "dimensions" kw: scatter_matrix, parcats, parcoords
10051026
if "dimensions" in args and args["dimensions"] is None:
@@ -1010,8 +1031,6 @@ def build_dataframe(args, constructor):
10101031
else:
10111032
df_output[df_input.columns] = df_input[df_input.columns]
10121033

1013-
constants = dict()
1014-
ranges = list()
10151034

10161035
# Loop over possible arguments
10171036
for field_name in all_attrables:
@@ -1136,10 +1155,10 @@ def build_dataframe(args, constructor):
11361155
args[field_name] = str(col_name)
11371156
else:
11381157
args[field_name][i] = str(col_name)
1139-
if field_name != "wide_cols":
1158+
if field_name != "_column_":
11401159
wide_id_vars.add(str(col_name))
11411160

1142-
if missing_bar_dim and constructor == go.Bar:
1161+
if not wide_mode and missing_bar_dim and constructor == go.Bar:
11431162
# now that we've populated df_output, we check to see if the non-missing
11441163
# dimension is categorical: if so, then setting the missing dimension to a
11451164
# constant 1 is a less-insane thing to do than setting it to the index by
@@ -1161,9 +1180,8 @@ def build_dataframe(args, constructor):
11611180
df_output[col_name] = constants[col_name]
11621181

11631182
if wide_mode:
1164-
wide_value_vars = [c for c in args["wide_cols"] if c not in wide_id_vars]
1165-
del args["wide_cols"]
1166-
wide_cross = args["wide_cross"]
1183+
wide_value_vars = [c for c in args["_column_"] if c not in wide_id_vars]
1184+
del args["_column_"]
11671185
del args["wide_cross"]
11681186
df_output = df_output.melt(
11691187
id_vars=wide_id_vars,
@@ -1173,14 +1191,18 @@ def build_dataframe(args, constructor):
11731191
)
11741192
df_output[var_name] = df_output[var_name].astype(str)
11751193
orient_v = wide_orientation == "v"
1194+
if wide_cross_name == "__x__":
1195+
wide_cross_name = args["x"]
1196+
if wide_cross_name == "__y__":
1197+
wide_cross_name = args["y"]
11761198

11771199
if constructor == go.Scatter:
1178-
args["x" if orient_v else "y"] = wide_cross
1200+
args["x" if orient_v else "y"] = wide_cross_name
11791201
args["y" if orient_v else "x"] = "_value_"
11801202
args["color"] = args["color"] or var_name
11811203
if constructor == go.Bar:
11821204
if _is_continuous(df_output, "_value_"):
1183-
args["x" if orient_v else "y"] = wide_cross
1205+
args["x" if orient_v else "y"] = wide_cross_name
11841206
args["y" if orient_v else "x"] = "_value_"
11851207
args["color"] = args["color"] or var_name
11861208
else:
@@ -1189,10 +1211,11 @@ def build_dataframe(args, constructor):
11891211
df_output["_count_"] = 1
11901212
args["color"] = args["color"] or var_name
11911213
if constructor in [go.Violin, go.Box]:
1192-
args["x" if orient_v else "y"] = var_name
1214+
args["x" if orient_v else "y"] = wide_cross_name or var_name
11931215
args["y" if orient_v else "x"] = "_value_"
11941216
if constructor == go.Histogram:
11951217
args["x" if orient_v else "y"] = "_value_"
1218+
args["y" if orient_v else "x"] = wide_cross_name
11961219
args["color"] = args["color"] or var_name
11971220

11981221
args["data_frame"] = df_output

packages/python/plotly/plotly/tests/test_core/test_px/__init__.py

Whitespace-only changes.

packages/python/plotly/plotly/tests/test_core/test_px/test_px_wide.py

Lines changed: 61 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,50 @@
11
import plotly.express as px
22
import plotly.graph_objects as go
33
import pandas as pd
4-
from plotly.express._core import build_dataframe
4+
from plotly.express._core import build_dataframe, _is_col_list
55
from pandas.testing import assert_frame_equal
66
import pytest
77

88

9+
def test_is_col_list():
    """
    Check `_is_col_list` against a frame with string column names, a frame
    with default integer column names, and no frame at all.
    """
    # lists of lists are accepted whatever the data frame is
    nested = [[[3, 4]], [[3, 4], [3, 4]]]
    # non-iterables, scalars and references to missing columns are rejected
    # whatever the data frame is
    never = [pytest, False, ["a", 1], "a", 1]

    cases = [
        (
            pd.DataFrame(dict(a=[1, 2], b=[1, 2])),
            [["a"], ["a", "b"]] + nested,
            never + [["a", "b", "c"], [1, 2]],
        ),
        (
            pd.DataFrame([[1, 2], [1, 2]]),
            [[0], [0, 1]] + nested,
            never + [[0, 1, 2], ["a", "b"]],
        ),
        (
            None,
            nested,
            [[0], [0, 1]] + never + [[0, 1, 2], ["a", "b"]],
        ),
    ]

    for df_input, accepted, rejected in cases:
        for arg in accepted:
            assert _is_col_list(df_input, arg)
        for arg in rejected:
            assert not _is_col_list(df_input, arg)
46+
47+
948
def test_wide_mode_external():
1049
# here we test this feature "black box" style by calling actual PX functions and
1150
# inspecting the figure... this is important but clunky, and is mostly a smoke test
@@ -101,25 +140,35 @@ def test_wide_mode_internal(trace_type, x, y, color, orientation):
101140
args_in = dict(data_frame=df_in, color=None, orientation=orientation)
102141
args_out = build_dataframe(args_in, trace_type)
103142
df_out = args_out.pop("data_frame")
143+
expected = dict(
144+
_column_=["a", "a", "a", "b", "b", "b"], _value_=[1, 2, 3, 4, 5, 6],
145+
)
146+
if x == "index":
147+
expected["index"] = [11, 12, 13, 11, 12, 13]
104148
assert_frame_equal(
105-
df_out.sort_index(axis=1),
106-
pd.DataFrame(
107-
dict(
108-
index=[11, 12, 13, 11, 12, 13],
109-
_column_=["a", "a", "a", "b", "b", "b"],
110-
_value_=[1, 2, 3, 4, 5, 6],
111-
)
112-
).sort_index(axis=1),
149+
df_out.sort_index(axis=1), pd.DataFrame(expected).sort_index(axis=1),
113150
)
114-
for arg in ["x", "y"]:
115-
if arg not in args_out:
116-
args_out[arg] = None # so this doesn't fail for histogram
117151
if orientation is None or orientation == "v":
118152
assert args_out == dict(x=x, y=y, color=color, orientation="v")
119153
else:
120154
assert args_out == dict(x=y, y=x, color=color, orientation="h")
121155

122156

157+
def test_wide_x_or_y():
    """
    With no data frame and `y` given as a list of lists, `build_dataframe`
    is expected to return a long-form frame with `_column_`, `_value_` and
    `index` columns.
    """
    args_out = build_dataframe(
        dict(data_frame=None, y=[[1, 2], [3, 4]], color=None, orientation=None),
        go.Scatter,
    )
    df_out = args_out.pop("data_frame")
    df_expected = pd.DataFrame(
        {
            "_column_": ["_column__0", "_column__0", "_column__1", "_column__1"],
            "_value_": [1, 2, 3, 4],
            # "x": ["a", "b", "a", "b"],  # not exercised yet (wip)
            "index": [0, 1, 0, 1],
        }
    )
    # column order is irrelevant here, so both frames are sorted by column name
    assert_frame_equal(df_out.sort_index(axis=1), df_expected.sort_index(axis=1))
170+
171+
123172
@pytest.mark.parametrize(
124173
"orientation", [None, "v", "h"],
125174
)
@@ -159,8 +208,6 @@ def assert_df_and_args(df_in, args_in, args_expect, df_expect):
159208
args_in["data_frame"] = df_in
160209
args_out = build_dataframe(args_in, go.Scatter)
161210
df_out = args_out.pop("data_frame")
162-
# print(df_out.info())
163-
# print(df_expect.info())
164211
assert_frame_equal(
165212
df_out.sort_index(axis=1), df_expect.sort_index(axis=1),
166213
)

0 commit comments

Comments
 (0)