Skip to content

Commit f733831

Browse files
manage ugly name collisions
1 parent f3039ac commit f733831

File tree

3 files changed

+59
-52
lines changed

3 files changed

+59
-52
lines changed

packages/python/plotly/plotly/express/_core.py

Lines changed: 34 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -970,27 +970,24 @@ def _isinstance_listlike(x):
970970
return True
971971

972972

973-
def process_args_into_dataframe(args, wide_mode, var_name):
973+
def _escape_col_name(df_input, col_name):
974+
while df_input is not None and col_name in df_input.columns:
975+
col_name = "_" + col_name
976+
return col_name
977+
978+
979+
def process_args_into_dataframe(args, wide_mode, var_name, value_name):
974980
"""
975981
After this function runs, the `all_attrables` keys of `args` all contain only
976982
references to columns of `df_output`. This function handles the extraction of data
977983
from `args["attrable"]` and column-name-generation as appropriate, and adds the
978984
data to `df_output` and then replaces `args["attrable"]` with the appropriate
979985
reference.
980986
"""
981-
for field in args:
982-
if field in array_attrables and args[field] is not None:
983-
args[field] = (
984-
OrderedDict(args[field])
985-
if isinstance(args[field], dict)
986-
else list(args[field])
987-
)
988-
# Cast data_frame argument to DataFrame (it could be a numpy array, dict etc.)
989-
df_provided = args["data_frame"] is not None
990-
if df_provided and not isinstance(args["data_frame"], pd.DataFrame):
991-
args["data_frame"] = pd.DataFrame(args["data_frame"])
987+
992988
df_input = args["data_frame"]
993989
df_provided = df_input is not None
990+
994991
df_output = pd.DataFrame()
995992
constants = dict()
996993
ranges = list()
@@ -1083,7 +1080,7 @@ def process_args_into_dataframe(args, wide_mode, var_name):
10831080
)
10841081
# Check validity of column name
10851082
if argument not in df_input.columns:
1086-
if wide_mode and argument in ("value", var_name):
1083+
if wide_mode and argument in (value_name, var_name):
10871084
continue
10881085
else:
10891086
err_msg = (
@@ -1205,10 +1202,11 @@ def build_dataframe(args, constructor):
12051202
wide_y = False if no_y else _is_col_list(df_input, args["y"])
12061203

12071204
wide_mode = False
1208-
var_name = None
1205+
var_name = None # will likely be "variable" in wide_mode
1206+
wide_cross_name = None # will likely be "index" in wide_mode
1207+
value_name = "value"
12091208
hist2d_types = [go.Histogram2d, go.Histogram2dContour]
12101209
if constructor in cartesians:
1211-
wide_cross_name = None
12121210
if wide_x and wide_y:
12131211
raise ValueError(
12141212
"Cannot accept list of column references or list of columns for both `x` and `y`."
@@ -1266,26 +1264,33 @@ def build_dataframe(args, constructor):
12661264
args["wide_cross"] = df_input.index
12671265
wide_cross_name = df_input.index.name or "index"
12681266
else:
1269-
args["wide_cross"] = Range(label="index")
1270-
wide_cross_name = "index"
1267+
wide_cross_name = _escape_col_name(df_input, "index")
1268+
args["wide_cross"] = Range(label=wide_cross_name)
1269+
1270+
if wide_mode:
1271+
var_name = _escape_col_name(df_input, var_name)
1272+
value_name = _escape_col_name(df_input, value_name)
12711273

12721274
# now that things have been prepped, we do the systematic rewriting of `args`
12731275

1274-
df_output, wide_id_vars = process_args_into_dataframe(args, wide_mode, var_name)
1276+
df_output, wide_id_vars = process_args_into_dataframe(
1277+
args, wide_mode, var_name, value_name
1278+
)
12751279

12761280
# now that `df_output` exists and `args` contains only references, we complete
12771281
# the special-case and wide-mode handling by further rewriting args and/or mutating
12781282
# df_output
12791283

1284+
count_name = _escape_col_name(df_output, "count")
12801285
if not wide_mode and missing_bar_dim and constructor == go.Bar:
12811286
# now that we've populated df_output, we check to see if the non-missing
12821287
# dimension is categorical: if so, then setting the missing dimension to a
12831288
# constant 1 is a less-insane thing to do than setting it to the index by
12841289
# default and we let the normal auto-orientation-code do its thing later
12851290
other_dim = "x" if missing_bar_dim == "y" else "y"
12861291
if not _is_continuous(df_output, args[other_dim]):
1287-
args[missing_bar_dim] = "count"
1288-
df_output["count"] = 1
1292+
args[missing_bar_dim] = count_name
1293+
df_output[count_name] = 1
12891294
else:
12901295
# on the other hand, if the non-missing dimension is continuous, then we
12911296
# can use this information to override the normal auto-orientation code
@@ -1306,7 +1311,7 @@ def build_dataframe(args, constructor):
13061311
id_vars=wide_id_vars,
13071312
value_vars=wide_value_vars,
13081313
var_name=var_name,
1309-
value_name="value",
1314+
value_name=value_name,
13101315
)
13111316
df_output[var_name] = df_output[var_name].astype(str)
13121317
orient_v = wide_orientation == "v"
@@ -1317,24 +1322,24 @@ def build_dataframe(args, constructor):
13171322

13181323
if constructor in [go.Scatter, go.Funnel] + hist2d_types:
13191324
args["x" if orient_v else "y"] = wide_cross_name
1320-
args["y" if orient_v else "x"] = "value"
1325+
args["y" if orient_v else "x"] = value_name
13211326
if constructor != go.Histogram2d:
13221327
args["color"] = args["color"] or var_name
13231328
if constructor == go.Bar:
1324-
if _is_continuous(df_output, "value"):
1329+
if _is_continuous(df_output, value_name):
13251330
args["x" if orient_v else "y"] = wide_cross_name
1326-
args["y" if orient_v else "x"] = "value"
1331+
args["y" if orient_v else "x"] = value_name
13271332
args["color"] = args["color"] or var_name
13281333
else:
1329-
args["x" if orient_v else "y"] = "value"
1330-
args["y" if orient_v else "x"] = "count"
1331-
df_output["count"] = 1
1334+
args["x" if orient_v else "y"] = value_name
1335+
args["y" if orient_v else "x"] = count_name
1336+
df_output[count_name] = 1
13321337
args["color"] = args["color"] or var_name
13331338
if constructor in [go.Violin, go.Box]:
13341339
args["x" if orient_v else "y"] = wide_cross_name or var_name
1335-
args["y" if orient_v else "x"] = "value"
1340+
args["y" if orient_v else "x"] = value_name
13361341
if constructor == go.Histogram:
1337-
args["x" if orient_v else "y"] = "value"
1342+
args["x" if orient_v else "y"] = value_name
13381343
args["y" if orient_v else "x"] = wide_cross_name
13391344
args["color"] = args["color"] or var_name
13401345

packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,7 @@ def test_with_index():
3838
# We do not allow "x=index"
3939
with pytest.raises(ValueError) as err_msg:
4040
fig = px.scatter(tips, x="index", y="total_bill")
41-
assert "To use the index, pass it in directly as `df.index`." in str(
42-
err_msg.value
43-
)
41+
assert "To use the index, pass it in directly as `df.index`." in str(err_msg.value)
4442
tips = px.data.tips()
4543
tips.index.name = "item"
4644
fig = px.scatter(tips, x=tips.index, y="total_bill")
@@ -75,10 +73,10 @@ def test_several_dataframes():
7573
# Name conflict
7674
with pytest.raises(NameError) as err_msg:
7775
fig = px.scatter(df, x="z", y=df2.money, size="y")
78-
assert "A name conflict was encountered for argument y" in str(err_msg.value)
76+
assert "A name conflict was encountered for argument y" in str(err_msg.value)
7977
with pytest.raises(NameError) as err_msg:
8078
fig = px.scatter(df, x="z", y=df2.money, size=df.y)
81-
assert "A name conflict was encountered for argument y" in str(err_msg.value)
79+
assert "A name conflict was encountered for argument y" in str(err_msg.value)
8280

8381
# No conflict when the dataframe is not given, fields are used
8482
df = pd.DataFrame(dict(x=[0, 1], y=[3, 4]))
@@ -157,41 +155,41 @@ def test_arrayattrable_numpy():
157155
def test_wrong_column_name():
158156
with pytest.raises(ValueError) as err_msg:
159157
px.scatter(px.data.tips(), x="bla", y="wrong")
160-
assert "Value of 'x' is not the name of a column in 'data_frame'" in str(
161-
err_msg.value
162-
)
158+
assert "Value of 'x' is not the name of a column in 'data_frame'" in str(
159+
err_msg.value
160+
)
163161

164162

165163
def test_missing_data_frame():
166164
with pytest.raises(ValueError) as err_msg:
167165
px.scatter(x="arg1", y="arg2")
168-
assert "String or int arguments are only possible" in str(err_msg.value)
166+
assert "String or int arguments are only possible" in str(err_msg.value)
169167

170168

171169
def test_wrong_dimensions_of_array():
172170
with pytest.raises(ValueError) as err_msg:
173171
px.scatter(x=[1, 2, 3], y=[2, 3, 4, 5])
174-
assert "All arguments should have the same length." in str(err_msg.value)
172+
assert "All arguments should have the same length." in str(err_msg.value)
175173

176174

177175
def test_wrong_dimensions_mixed_case():
178176
with pytest.raises(ValueError) as err_msg:
179177
df = pd.DataFrame(dict(time=[1, 2, 3], temperature=[20, 30, 25]))
180178
px.scatter(df, x="time", y="temperature", color=[1, 3, 9, 5])
181-
assert "All arguments should have the same length." in str(err_msg.value)
179+
assert "All arguments should have the same length." in str(err_msg.value)
182180

183181

184182
def test_wrong_dimensions():
185183
with pytest.raises(ValueError) as err_msg:
186184
px.scatter(px.data.tips(), x="tip", y=[1, 2, 3])
187-
assert "All arguments should have the same length." in str(err_msg.value)
185+
assert "All arguments should have the same length." in str(err_msg.value)
188186
# the order matters
189187
with pytest.raises(ValueError) as err_msg:
190188
px.scatter(px.data.tips(), x=[1, 2, 3], y="tip")
191-
assert "All arguments should have the same length." in str(err_msg.value)
189+
assert "All arguments should have the same length." in str(err_msg.value)
192190
with pytest.raises(ValueError):
193191
px.scatter(px.data.tips(), x=px.data.iris().index, y="tip")
194-
# assert "All arguments should have the same length." in str(err_msg.value)
192+
assert "All arguments should have the same length." in str(err_msg.value)
195193

196194

197195
def test_multiindex_raise_error():
@@ -203,9 +201,7 @@ def test_multiindex_raise_error():
203201
px.scatter(df, x="A", y="B")
204202
with pytest.raises(TypeError) as err_msg:
205203
px.scatter(df, x=df.index, y="B")
206-
assert "pandas MultiIndex is not supported by plotly express" in str(
207-
err_msg.value
208-
)
204+
assert "pandas MultiIndex is not supported by plotly express" in str(err_msg.value)
209205

210206

211207
def test_build_df_from_lists():

packages/python/plotly/plotly/tests/test_core/test_px/test_px_wide.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -635,14 +635,20 @@ def test_multi_index():
635635
df.index = [["a", "a", "b", "b"], ["c", "d", "c", "d"]]
636636
with pytest.raises(TypeError) as err_msg:
637637
px.scatter(df)
638-
assert "pandas MultiIndex is not supported by plotly express" in str(
639-
err_msg.value
640-
)
638+
assert "pandas MultiIndex is not supported by plotly express" in str(err_msg.value)
641639

642640
df = pd.DataFrame([[1, 2, 3, 4], [3, 4, 5, 6], [1, 2, 3, 4], [3, 4, 5, 6]])
643641
df.columns = [["e", "e", "f", "f"], ["g", "h", "g", "h"]]
644642
with pytest.raises(TypeError) as err_msg:
645643
px.scatter(df)
646-
assert "pandas MultiIndex is not supported by plotly express" in str(
647-
err_msg.value
648-
)
644+
assert "pandas MultiIndex is not supported by plotly express" in str(err_msg.value)
645+
646+
647+
def test_special_name_collisions():
    """
    Call `build_dataframe` for `go.Scatter` with no `x`/`y`, using an input frame
    that already has columns named "value" and "variable", and check that the
    output data frame contains no duplicated column names.
    """
    columns = dict(a=range(10), b=range(10), value=range(10), variable=range(10))
    df = pd.DataFrame(columns)
    args_out = build_dataframe(
        dict(data_frame=df, color="value", symbol="variable"), go.Scatter
    )
    out_columns = list(args_out["data_frame"].columns)
    assert len(out_columns) == len(set(out_columns))

0 commit comments

Comments
 (0)