Skip to content

Commit e78cb50

Browse files
lock down edge cases around name collisions
1 parent f733831 commit e78cb50

File tree

3 files changed

+101
-53
lines changed

3 files changed

+101
-53
lines changed

packages/python/plotly/plotly/express/_core.py

Lines changed: 53 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -899,8 +899,8 @@ def _check_name_not_reserved(field_name, reserved_names):
899899
return field_name
900900
else:
901901
raise NameError(
902-
"A name conflict was encountered for argument %s. "
903-
"A column with name %s is already used." % (field_name, field_name)
902+
"A name conflict was encountered for argument '%s'. "
903+
"A column or index with name '%s' is ambiguous." % (field_name, field_name)
904904
)
905905

906906

@@ -929,6 +929,8 @@ def _get_reserved_col_names(args):
929929
in_df = arg is df[arg_name]
930930
if in_df:
931931
reserved_names.add(arg_name)
932+
elif arg is df.index and arg.name is not None:
933+
reserved_names.add(arg.name)
932934

933935
return reserved_names
934936

@@ -970,8 +972,8 @@ def _isinstance_listlike(x):
970972
return True
971973

972974

973-
def _escape_col_name(df_input, col_name):
974-
while df_input is not None and col_name in df_input.columns:
975+
def _escape_col_name(df_input, col_name, extra):
976+
while df_input is not None and (col_name in df_input.columns or col_name in extra):
975977
col_name = "_" + col_name
976978
return col_name
977979

@@ -1040,6 +1042,7 @@ def process_args_into_dataframe(args, wide_mode, var_name, value_name):
10401042
length = len(df_output)
10411043
if argument is None:
10421044
continue
1045+
col_name = None
10431046
# Case of multiindex
10441047
if isinstance(argument, pd.MultiIndex):
10451048
raise TypeError(
@@ -1107,31 +1110,25 @@ def process_args_into_dataframe(args, wide_mode, var_name, value_name):
11071110
df_output[col_name] = df_input[argument].values
11081111
# ----------------- argument is a column / array / list.... -------
11091112
else:
1110-
is_index = isinstance(argument, pd.Index)
1111-
# First pandas
1112-
# pandas series have a name but it's None
1113-
if (
1114-
hasattr(argument, "name") and argument.name is not None
1115-
) or is_index:
1116-
col_name = argument.name # pandas df
1117-
if col_name is None and is_index:
1118-
col_name = "index"
1119-
if not df_provided:
1120-
col_name = field
1121-
else:
1122-
if is_index:
1123-
keep_name = df_provided and argument is df_input.index
1113+
if df_provided and hasattr(argument, "name"):
1114+
if argument is df_input.index:
1115+
if argument.name is None or argument.name in df_input:
1116+
col_name = "index"
11241117
else:
1125-
keep_name = (
1126-
col_name in df_input and argument is df_input[col_name]
1127-
)
1128-
col_name = (
1129-
col_name
1130-
if keep_name
1131-
else _check_name_not_reserved(field, reserved_names)
1118+
col_name = argument.name
1119+
col_name = _escape_col_name(
1120+
df_input, col_name, [var_name, value_name]
11321121
)
1133-
else: # numpy array, list...
1122+
else:
1123+
if (
1124+
argument.name is not None
1125+
and argument.name in df_input
1126+
and argument is df_input[argument.name]
1127+
):
1128+
col_name = argument.name
1129+
if col_name is None: # numpy array, list...
11341130
col_name = _check_name_not_reserved(field, reserved_names)
1131+
11351132
if length and len(argument) != length:
11361133
raise ValueError(
11371134
"All arguments should have the same length. "
@@ -1145,6 +1142,12 @@ def process_args_into_dataframe(args, wide_mode, var_name, value_name):
11451142
df_output[str(col_name)] = np.array(argument)
11461143

11471144
# Finally, update argument with column name now that column exists
1145+
assert col_name is not None, (
1146+
"Data-frame processing failure, likely due to a internal bug. "
1147+
"Please report this to "
1148+
"https://github.com/plotly/plotly.py/issues/new and we will try to "
1149+
"replicate and fix it."
1150+
)
11481151
if field_name not in array_attrables:
11491152
args[field_name] = str(col_name)
11501153
elif isinstance(args[field_name], dict):
@@ -1204,7 +1207,7 @@ def build_dataframe(args, constructor):
12041207
wide_mode = False
12051208
var_name = None # will likely be "variable" in wide_mode
12061209
wide_cross_name = None # will likely be "index" in wide_mode
1207-
value_name = "value"
1210+
value_name = None # will likely be "value" in wide_mode
12081211
hist2d_types = [go.Histogram2d, go.Histogram2dContour]
12091212
if constructor in cartesians:
12101213
if wide_x and wide_y:
@@ -1220,7 +1223,9 @@ def build_dataframe(args, constructor):
12201223
"at the moment."
12211224
)
12221225
args["wide_variable"] = list(df_input.columns)
1223-
var_name = df_input.columns.name or "variable"
1226+
var_name = df_input.columns.name
1227+
if var_name in [None, "value", "index"] or var_name in df_input:
1228+
var_name = "variable"
12241229
if constructor == go.Funnel:
12251230
wide_orientation = args.get("orientation", None) or "h"
12261231
else:
@@ -1240,6 +1245,10 @@ def build_dataframe(args, constructor):
12401245
if not no_x and not no_y:
12411246
wide_cross_name = "__x__" if wide_y else "__y__"
12421247

1248+
if wide_mode:
1249+
value_name = _escape_col_name(df_input, "value", [])
1250+
var_name = _escape_col_name(df_input, var_name, [])
1251+
12431252
missing_bar_dim = None
12441253
if constructor in [go.Scatter, go.Bar, go.Funnel] + hist2d_types:
12451254
if not wide_mode and (no_x != no_y):
@@ -1262,14 +1271,10 @@ def build_dataframe(args, constructor):
12621271
"at the moment."
12631272
)
12641273
args["wide_cross"] = df_input.index
1265-
wide_cross_name = df_input.index.name or "index"
12661274
else:
1267-
wide_cross_name = _escape_col_name(df_input, "index")
1268-
args["wide_cross"] = Range(label=wide_cross_name)
1269-
1270-
if wide_mode:
1271-
var_name = _escape_col_name(df_input, var_name)
1272-
value_name = _escape_col_name(df_input, value_name)
1275+
args["wide_cross"] = Range(
1276+
label=_escape_col_name(df_input, "index", [var_name, value_name])
1277+
)
12731278

12741279
# now that things have been prepped, we do the systematic rewriting of `args`
12751280

@@ -1281,7 +1286,7 @@ def build_dataframe(args, constructor):
12811286
# the special-case and wide-mode handling by further rewriting args and/or mutating
12821287
# df_output
12831288

1284-
count_name = _escape_col_name(df_output, "count")
1289+
count_name = _escape_col_name(df_output, "count", [var_name, value_name])
12851290
if not wide_mode and missing_bar_dim and constructor == go.Bar:
12861291
# now that we've populated df_output, we check to see if the non-missing
12871292
# dimension is categorical: if so, then setting the missing dimension to a
@@ -1306,19 +1311,27 @@ def build_dataframe(args, constructor):
13061311
# columns, keeping track of various names and manglings set up above
13071312
wide_value_vars = [c for c in args["wide_variable"] if c not in wide_id_vars]
13081313
del args["wide_variable"]
1314+
if wide_cross_name == "__x__":
1315+
wide_cross_name = args["x"]
1316+
elif wide_cross_name == "__y__":
1317+
wide_cross_name = args["y"]
1318+
else:
1319+
wide_cross_name = args["wide_cross"]
13091320
del args["wide_cross"]
13101321
df_output = df_output.melt(
13111322
id_vars=wide_id_vars,
13121323
value_vars=wide_value_vars,
13131324
var_name=var_name,
13141325
value_name=value_name,
13151326
)
1327+
assert len(df_output.columns) == len(set(df_output.columns)), (
1328+
"Wide-mode name-inference failure, likely due to a internal bug. "
1329+
"Please report this to "
1330+
"https://github.com/plotly/plotly.py/issues/new and we will try to "
1331+
"replicate and fix it."
1332+
)
13161333
df_output[var_name] = df_output[var_name].astype(str)
13171334
orient_v = wide_orientation == "v"
1318-
if wide_cross_name == "__x__":
1319-
wide_cross_name = args["x"]
1320-
if wide_cross_name == "__y__":
1321-
wide_cross_name = args["y"]
13221335

13231336
if constructor in [go.Scatter, go.Funnel] + hist2d_types:
13241337
args["x" if orient_v else "y"] = wide_cross_name

packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,10 @@ def test_several_dataframes():
7373
# Name conflict
7474
with pytest.raises(NameError) as err_msg:
7575
fig = px.scatter(df, x="z", y=df2.money, size="y")
76-
assert "A name conflict was encountered for argument y" in str(err_msg.value)
76+
assert "A name conflict was encountered for argument 'y'" in str(err_msg.value)
7777
with pytest.raises(NameError) as err_msg:
7878
fig = px.scatter(df, x="z", y=df2.money, size=df.y)
79-
assert "A name conflict was encountered for argument y" in str(err_msg.value)
79+
assert "A name conflict was encountered for argument 'y'" in str(err_msg.value)
8080

8181
# No conflict when the dataframe is not given, fields are used
8282
df = pd.DataFrame(dict(x=[0, 1], y=[3, 4]))

packages/python/plotly/plotly/tests/test_core/test_px/test_px_wide.py

Lines changed: 46 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -618,16 +618,61 @@ def append_special_case(df_in, args_in, args_expect, df_expect):
618618
),
619619
)
620620

621+
# df has columns named after every special string
622+
df = pd.DataFrame(dict(variable=[1, 2], index=[3, 4], value=[5, 6]), index=[7, 8])
623+
append_special_case(
624+
df_in=df,
625+
args_in=dict(x=None, y=None, color=None),
626+
args_expect=dict(x="_index", y="_value", color="_variable", orientation="v",),
627+
df_expect=pd.DataFrame(
628+
dict(
629+
_index=[7, 8, 7, 8, 7, 8],
630+
_value=[1, 2, 3, 4, 5, 6],
631+
_variable=["variable", "variable", "index", "index", "value", "value"],
632+
)
633+
),
634+
)
635+
636+
# df has columns with name collisions with indexes
637+
df = pd.DataFrame(dict(a=[1, 2], b=[3, 4]), index=[7, 8])
638+
df.index.name = "a"
639+
df.columns.name = "b"
640+
append_special_case(
641+
df_in=df,
642+
args_in=dict(x=None, y=None, color=None),
643+
args_expect=dict(x="index", y="value", color="variable", orientation="v",),
644+
df_expect=pd.DataFrame(
645+
dict(index=[7, 8, 7, 8], value=[1, 2, 3, 4], variable=["a", "a", "b", "b"],)
646+
),
647+
)
648+
649+
# everything is called value, OMG
650+
df = pd.DataFrame(dict(value=[1, 2], b=[3, 4]), index=[7, 8])
651+
df.index.name = "value"
652+
df.columns.name = "value"
653+
append_special_case(
654+
df_in=df,
655+
args_in=dict(x=None, y=None, color=None),
656+
args_expect=dict(x="index", y="_value", color="variable", orientation="v",),
657+
df_expect=pd.DataFrame(
658+
dict(
659+
index=[7, 8, 7, 8],
660+
_value=[1, 2, 3, 4],
661+
variable=["value", "value", "b", "b"],
662+
)
663+
),
664+
)
665+
621666

622667
@pytest.mark.parametrize("df_in, args_in, args_expect, df_expect", special_cases)
623668
def test_wide_mode_internal_special_cases(df_in, args_in, args_expect, df_expect):
624669
args_in["data_frame"] = df_in
625670
args_out = build_dataframe(args_in, go.Scatter)
626671
df_out = args_out.pop("data_frame")
672+
assert args_out == args_expect
627673
assert_frame_equal(
628674
df_out.sort_index(axis=1), df_expect.sort_index(axis=1),
629675
)
630-
assert args_out == args_expect
631676

632677

633678
def test_multi_index():
@@ -642,13 +687,3 @@ def test_multi_index():
642687
with pytest.raises(TypeError) as err_msg:
643688
px.scatter(df)
644689
assert "pandas MultiIndex is not supported by plotly express" in str(err_msg.value)
645-
646-
647-
def test_special_name_collisions():
648-
df = pd.DataFrame(
649-
dict(a=range(10), b=range(10), value=range(10), variable=range(10))
650-
)
651-
args_in = dict(data_frame=df, color="value", symbol="variable")
652-
args_out = build_dataframe(args_in, go.Scatter)
653-
df_out = args_out["data_frame"]
654-
assert len(set(df_out.columns)) == len(df_out.columns)

0 commit comments

Comments
 (0)