Skip to content

Commit ea14fc9

Browse files
smarter x or y behaviour
1 parent e4071a9 commit ea14fc9

File tree

4 files changed

+171
-46
lines changed

4 files changed

+171
-46
lines changed

packages/python/plotly/plotly/express/_chart_types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def scatter(
2525
animation_group=None,
2626
category_orders={},
2727
labels={},
28+
orientation=None,
2829
color_discrete_sequence=None,
2930
color_discrete_map={},
3031
color_continuous_scale=None,
@@ -192,6 +193,7 @@ def line(
192193
animation_group=None,
193194
category_orders={},
194195
labels={},
196+
orientation=None,
195197
color_discrete_sequence=None,
196198
color_discrete_map={},
197199
line_dash_sequence=None,

packages/python/plotly/plotly/express/_core.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -930,12 +930,10 @@ def build_dataframe(args, constructor):
930930

931931
df_input = args["data_frame"]
932932

933-
wide_mode = (
934-
df_provided
935-
and args.get("x", None) is None
936-
and args.get("y", None) is None
937-
and constructor in [go.Scatter, go.Bar, go.Violin, go.Box, go.Histogram]
938-
)
933+
no_x = args.get("x", None) is None
934+
no_y = args.get("y", None) is None
935+
wideable = [go.Scatter, go.Bar, go.Violin, go.Box, go.Histogram]
936+
wide_mode = df_provided and no_x and no_y and constructor in wideable
939937
wide_id_vars = set()
940938

941939
if wide_mode:
@@ -944,6 +942,17 @@ def build_dataframe(args, constructor):
944942
else:
945943
df_output = pd.DataFrame()
946944

945+
missing_bar_dim = None
946+
if constructor in [go.Scatter, go.Bar] and (no_x != no_y):
947+
for ax in ["x", "y"]:
948+
if args.get(ax, None) is None:
949+
args[ax] = df_input.index if df_provided else Range()
950+
if constructor == go.Scatter:
951+
if args["orientation"] is None:
952+
args["orientation"] = "v" if ax == "x" else "h"
953+
if constructor == go.Bar:
954+
missing_bar_dim = ax
955+
947956
# Initialize set of column names
948957
# These are reserved names
949958
if df_provided:
@@ -1088,12 +1097,27 @@ def build_dataframe(args, constructor):
10881097
args[field_name][i] = str(col_name)
10891098
wide_id_vars.add(str(col_name))
10901099

1091-
for col_name in constants:
1092-
df_output[col_name] = constants[col_name]
1100+
if missing_bar_dim and constructor == go.Bar:
1101+
# now that we've populated df_output, we check to see if the non-missing
1102+
# dimensio is categorical: if so, then setting the missing dimension to a
1103+
# constant 1 is a less-insane thing to do than setting it to the index by
1104+
# default and we let the normal auto-orientation-code do its thing later
1105+
other_dim = "x" if missing_bar_dim == "y" else "y"
1106+
if not _is_continuous(df_output, args[other_dim]):
1107+
args[missing_bar_dim] = missing_bar_dim
1108+
constants[missing_bar_dim] = 1
1109+
else:
1110+
# on the other hand, if the non-missing dimension is continuous, then we
1111+
# can use this information to override the normal auto-orientation code
1112+
if args["orientation"] is None:
1113+
args["orientation"] = "v" if missing_bar_dim == "x" else "h"
10931114

10941115
for col_name in ranges:
10951116
df_output[col_name] = range(len(df_output))
10961117

1118+
for col_name in constants:
1119+
df_output[col_name] = constants[col_name]
1120+
10971121
if wide_mode:
10981122
# TODO multi-level index
10991123
# TODO multi-level columns
@@ -1105,9 +1129,8 @@ def build_dataframe(args, constructor):
11051129
id_vars=wide_id_vars, var_name=var_name, value_name="_value_"
11061130
)
11071131
df_output[var_name] = df_output[var_name].astype(str)
1108-
orient_v = "v" == (args.get("orientation", None) or "v")
1109-
if "orientation" in args:
1110-
args["orientation"] = "v" if orient_v else "h"
1132+
args["orientation"] = args.get("orientation", None) or "v"
1133+
orient_v = args["orientation"] == "v"
11111134
if constructor in [go.Scatter, go.Bar]:
11121135
args["x" if orient_v else "y"] = index_name
11131136
args["y" if orient_v else "x"] = "_value_"

packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py

Lines changed: 82 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -393,33 +393,32 @@ def test_auto_orient():
393393
categorical = ["a", "a", "b", "b"]
394394
numerical = [1, 2, 3, 4]
395395

396-
pattern_x_or_y = [
397-
(numerical, None, "h"), # auto
398-
(categorical, None, "h"), # auto
399-
(None, categorical, "v"), # auto/default
400-
(None, numerical, "v"), # auto/default
401-
]
396+
auto_orientable = [px.scatter, px.line, px.area, px.violin, px.box, px.strip]
397+
auto_orientable += [px.bar, px.funnel, px.histogram]
402398

403399
pattern_x_and_y = [
404400
(numerical, categorical, "h"), # auto
405401
(categorical, numerical, "v"), # auto/default
406402
(categorical, categorical, "v"), # default
407403
(numerical, numerical, "v"), # default
408404
]
409-
410-
for fn in [px.violin, px.box, px.strip, px.bar, px.funnel]:
411-
for x, y, result in pattern_x_or_y:
405+
for fn in auto_orientable:
406+
for x, y, result in pattern_x_and_y:
412407
assert fn(x=x, y=y).data[0].orientation == result
413408

414-
# these ones are the opposite of the ones above in the "or" cases
415-
for fn in [px.area, px.histogram]:
416-
for x, y, result in pattern_x_or_y:
417-
assert fn(x=x, y=y).data[0].orientation != result
409+
pattern_x_or_y = [
410+
(numerical, None, "h"), # auto
411+
(categorical, None, "h"), # auto
412+
(None, categorical, "v"), # auto/default
413+
(None, numerical, "v"), # auto/default
414+
]
418415

419-
# all behave the same for the "and" cases
420-
for fn in [px.violin, px.box, px.strip, px.bar, px.funnel, px.area, px.histogram]:
421-
for x, y, result in pattern_x_and_y:
422-
assert fn(x=x, y=y).data[0].orientation == result
416+
for fn in auto_orientable:
417+
for x, y, result in pattern_x_or_y:
418+
if fn == px.histogram or (fn == px.bar and categorical in [x, y]):
419+
assert fn(x=x, y=y).data[0].orientation != result
420+
else:
421+
assert fn(x=x, y=y).data[0].orientation == result
423422

424423
assert px.histogram(x=numerical, nbins=5).data[0].nbinsx == 5
425424
assert px.histogram(y=numerical, nbins=5).data[0].nbinsy == 5
@@ -465,3 +464,69 @@ def test_auto_boxlike_overlay():
465464
for fn, mode in fn_and_mode:
466465
for x, y, color, result in pattern:
467466
assert fn(df, x=x, y=y, color=color).layout[mode] == result
467+
468+
469+
def test_x_or_y():
470+
categorical = ["a", "a", "b", "b"]
471+
numerical = [1, 2, 3, 4]
472+
constant = [1, 1, 1, 1]
473+
range_4 = [0, 1, 2, 3]
474+
index = [11, 12, 13, 14]
475+
numerical_df = pd.DataFrame(dict(col=numerical), index=index)
476+
categorical_df = pd.DataFrame(dict(col=categorical), index=index)
477+
scatter_like = [px.scatter, px.line, px.area]
478+
bar_like = [px.bar]
479+
480+
for fn in scatter_like + bar_like:
481+
fig = fn(x=numerical)
482+
assert list(fig.data[0].x) == numerical
483+
assert list(fig.data[0].y) == range_4
484+
assert fig.data[0].orientation == "h"
485+
fig = fn(y=numerical)
486+
assert list(fig.data[0].x) == range_4
487+
assert list(fig.data[0].y) == numerical
488+
assert fig.data[0].orientation == "v"
489+
fig = fn(numerical_df, x="col")
490+
assert list(fig.data[0].x) == numerical
491+
assert list(fig.data[0].y) == index
492+
assert fig.data[0].orientation == "h"
493+
fig = fn(numerical_df, y="col")
494+
assert list(fig.data[0].x) == index
495+
assert list(fig.data[0].y) == numerical
496+
assert fig.data[0].orientation == "v"
497+
498+
for fn in scatter_like:
499+
fig = fn(x=categorical)
500+
assert list(fig.data[0].x) == categorical
501+
assert list(fig.data[0].y) == range_4
502+
assert fig.data[0].orientation == "h"
503+
fig = fn(y=categorical)
504+
assert list(fig.data[0].x) == range_4
505+
assert list(fig.data[0].y) == categorical
506+
assert fig.data[0].orientation == "v"
507+
fig = fn(categorical_df, x="col")
508+
assert list(fig.data[0].x) == categorical
509+
assert list(fig.data[0].y) == index
510+
assert fig.data[0].orientation == "h"
511+
fig = fn(categorical_df, y="col")
512+
assert list(fig.data[0].x) == index
513+
assert list(fig.data[0].y) == categorical
514+
assert fig.data[0].orientation == "v"
515+
516+
for fn in bar_like:
517+
fig = fn(x=categorical)
518+
assert list(fig.data[0].x) == categorical
519+
assert list(fig.data[0].y) == constant
520+
assert fig.data[0].orientation == "v"
521+
fig = fn(y=categorical)
522+
assert list(fig.data[0].x) == constant
523+
assert list(fig.data[0].y) == categorical
524+
assert fig.data[0].orientation == "h"
525+
fig = fn(categorical_df, x="col")
526+
assert list(fig.data[0].x) == categorical
527+
assert list(fig.data[0].y) == constant
528+
assert fig.data[0].orientation == "v"
529+
fig = fn(categorical_df, y="col")
530+
assert list(fig.data[0].x) == constant
531+
assert list(fig.data[0].y) == categorical
532+
assert fig.data[0].orientation == "h"

0 commit comments

Comments
 (0)