Skip to content

Commit 7d611fb

Browse files
committed
use return_type directly when building datasets
1 parent bb327d5 commit 7d611fb

File tree

9 files changed

+94
-118
lines changed

9 files changed

+94
-118
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ This project adheres to [Semantic Versioning](http://semver.org/).
1010

1111
- Updated plotly.py to use base64 encoding of arrays in plotly JSON to improve performance.
1212
- Add `subtitle` attribute to all Plotly Express traces
13-
- Allow to load plotly data directly via pandas, polars and pyarrow, without depending directly on any [#4843](https://github.com/plotly/plotly.py/pull/4843)
13+
- Make plotly-express dataframe agnostic via Narwhals [#4790](https://github.com/plotly/plotly.py/pull/4790)
1414

1515
## [5.24.1] - 2024-09-12
1616

packages/python/plotly/plotly/tests/test_optional/test_px/conftest.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,8 @@ def pyarrow_table_constructor(obj) -> IntoDataFrame:
4141
@pytest.fixture(params=constructors)
4242
def constructor(request: pytest.FixtureRequest):
4343
return request.param # type: ignore[no-any-return]
44+
45+
46+
@pytest.fixture(params=["pandas", "pyarrow", "polars"])
47+
def backend(request: pytest.FixtureRequest) -> str:
48+
return request.param # type: ignore[no-any-return]

packages/python/plotly/plotly/tests/test_optional/test_px/test_facets.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,8 @@
44
import random
55

66

7-
def test_facets(constructor):
8-
data = px.data.tips().to_dict(orient="list")
9-
df = constructor(data)
7+
def test_facets(backend):
8+
df = px.data.tips(return_type=backend)
109

1110
fig = px.scatter(df, x="total_bill", y="tip")
1211
assert "xaxis2" not in fig.layout
@@ -47,9 +46,8 @@ def test_facets(constructor):
4746
assert fig.layout.yaxis4.domain[0] - fig.layout.yaxis.domain[1] == approx(0.08)
4847

4948

50-
def test_facets_with_marginals(constructor):
51-
data = px.data.tips().to_dict(orient="list")
52-
df = constructor(data)
49+
def test_facets_with_marginals(backend):
50+
df = px.data.tips(return_type=backend)
5351

5452
fig = px.histogram(df, x="total_bill", facet_col="sex", marginal="rug")
5553
assert len(fig.data) == 4

packages/python/plotly/plotly/tests/test_optional/test_px/test_marginals.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,8 @@
55
@pytest.mark.parametrize("px_fn", [px.scatter, px.density_heatmap, px.density_contour])
66
@pytest.mark.parametrize("marginal_x", [None, "histogram", "box", "violin"])
77
@pytest.mark.parametrize("marginal_y", [None, "rug"])
8-
def test_xy_marginals(constructor, px_fn, marginal_x, marginal_y):
9-
data = px.data.tips().to_dict(orient="list")
10-
df = constructor(data)
8+
def test_xy_marginals(backend, px_fn, marginal_x, marginal_y):
9+
df = px.data.tips(return_type=backend)
1110

1211
fig = px_fn(
1312
df, x="total_bill", y="tip", marginal_x=marginal_x, marginal_y=marginal_y
@@ -18,9 +17,8 @@ def test_xy_marginals(constructor, px_fn, marginal_x, marginal_y):
1817
@pytest.mark.parametrize("px_fn", [px.histogram, px.ecdf])
1918
@pytest.mark.parametrize("marginal", [None, "rug", "histogram", "box", "violin"])
2019
@pytest.mark.parametrize("orientation", ["h", "v"])
21-
def test_single_marginals(constructor, px_fn, marginal, orientation):
22-
data = px.data.tips().to_dict(orient="list")
23-
df = constructor(data)
20+
def test_single_marginals(backend, px_fn, marginal, orientation):
21+
df = px.data.tips(return_type=backend)
2422

2523
fig = px_fn(
2624
df, x="total_bill", y="total_bill", marginal=marginal, orientation=orientation

packages/python/plotly/plotly/tests/test_optional/test_px/test_px.py

Lines changed: 20 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,8 @@
66
from itertools import permutations
77

88

9-
def test_scatter(constructor):
10-
data = px.data.iris().to_dict(orient="list")
11-
iris = nw.from_native(constructor(data))
9+
def test_scatter(backend):
10+
iris = nw.from_native(px.data.iris(return_type=backend))
1211
fig = px.scatter(iris.to_native(), x="sepal_width", y="sepal_length")
1312
assert fig.data[0].type == "scatter"
1413
assert np.all(fig.data[0].x == iris.get_column("sepal_width").to_numpy())
@@ -17,9 +16,8 @@ def test_scatter(constructor):
1716
assert fig.data[0].mode == "markers"
1817

1918

20-
def test_custom_data_scatter(constructor):
21-
data = px.data.iris().to_dict(orient="list")
22-
iris = nw.from_native(constructor(data))
19+
def test_custom_data_scatter(backend):
20+
iris = nw.from_native(px.data.iris(return_type=backend))
2321
# No hover, no custom data
2422
fig = px.scatter(
2523
iris.to_native(), x="sepal_width", y="sepal_length", color="species"
@@ -67,9 +65,8 @@ def test_custom_data_scatter(constructor):
6765
)
6866

6967

70-
def test_labels(constructor):
71-
data = px.data.tips().to_dict(orient="list")
72-
tips = nw.from_native(constructor(data))
68+
def test_labels(backend):
69+
tips = nw.from_native(px.data.tips(return_type=backend))
7370
fig = px.scatter(
7471
tips.to_native(),
7572
x="total_bill",
@@ -100,10 +97,8 @@ def test_labels(constructor):
10097
({"text": "continent"}, "lines+markers+text"),
10198
],
10299
)
103-
def test_line_mode(constructor, extra_kwargs, expected_mode):
104-
data = px.data.gapminder().to_dict(orient="list")
105-
gapminder = constructor(data)
106-
100+
def test_line_mode(backend, extra_kwargs, expected_mode):
101+
gapminder = px.data.gapminder(return_type=backend)
107102
fig = px.line(
108103
gapminder,
109104
x="year",
@@ -114,12 +109,11 @@ def test_line_mode(constructor, extra_kwargs, expected_mode):
114109
assert fig.data[0].mode == expected_mode
115110

116111

117-
def test_px_templates(constructor):
112+
def test_px_templates(backend):
118113
try:
119114
import plotly.graph_objects as go
120115

121-
data = px.data.tips().to_dict(orient="list")
122-
tips = constructor(data)
116+
tips = px.data.tips(return_type=backend)
123117

124118
# use the normal defaults
125119
fig = px.scatter()
@@ -245,12 +239,11 @@ def test_px_defaults():
245239
pio.templates.default = "plotly"
246240

247241

248-
def assert_orderings(constructor, days_order, days_check, times_order, times_check):
242+
def assert_orderings(backend, days_order, days_check, times_order, times_check):
249243
symbol_sequence = ["circle", "diamond", "square", "cross", "circle", "diamond"]
250244
color_sequence = ["red", "blue", "red", "blue", "red", "blue", "red", "blue"]
251245

252-
data = px.data.tips().to_dict(orient="list")
253-
tips = nw.from_native(constructor(data))
246+
tips = nw.from_native(px.data.tips(return_type=backend))
254247

255248
fig = px.scatter(
256249
tips.to_native(),
@@ -284,16 +277,16 @@ def assert_orderings(constructor, days_order, days_check, times_order, times_che
284277

285278
@pytest.mark.parametrize("days", permutations(["Sun", "Sat", "Fri", "x"]))
286279
@pytest.mark.parametrize("times", permutations(["Lunch", "x"]))
287-
def test_orthogonal_and_missing_orderings(constructor, days, times):
280+
def test_orthogonal_and_missing_orderings(backend, days, times):
288281
assert_orderings(
289-
constructor, days, list(days) + ["Thur"], times, list(times) + ["Dinner"]
282+
backend, days, list(days) + ["Thur"], times, list(times) + ["Dinner"]
290283
)
291284

292285

293286
@pytest.mark.parametrize("days", permutations(["Sun", "Sat", "Fri", "Thur"]))
294287
@pytest.mark.parametrize("times", permutations(["Lunch", "Dinner"]))
295-
def test_orthogonal_orderings(constructor, days, times):
296-
assert_orderings(constructor, days, days, times, times)
288+
def test_orthogonal_orderings(backend, days, times):
289+
assert_orderings(backend, days, days, times, times)
297290

298291

299292
def test_permissive_defaults():
@@ -302,9 +295,8 @@ def test_permissive_defaults():
302295
px.defaults.should_not_work = "test"
303296

304297

305-
def test_marginal_ranges(constructor):
306-
data = px.data.tips().to_dict(orient="list")
307-
df = constructor(data)
298+
def test_marginal_ranges(backend):
299+
df = px.data.tips(return_type=backend)
308300
fig = px.scatter(
309301
df,
310302
x="total_bill",
@@ -318,9 +310,8 @@ def test_marginal_ranges(constructor):
318310
assert fig.layout.yaxis3.range is None
319311

320312

321-
def test_render_mode(constructor):
322-
data = px.data.gapminder().to_dict(orient="list")
323-
df = nw.from_native(constructor(data))
313+
def test_render_mode(backend):
314+
df = nw.from_native(px.data.gapminder(return_type=backend))
324315
df2007 = df.filter(nw.col("year") == 2007)
325316

326317
fig = px.scatter(df2007.to_native(), x="gdpPercap", y="lifeExp", trendline="ols")

packages/python/plotly/plotly/tests/test_optional/test_px/test_px_functions.py

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -187,31 +187,28 @@ def test_sunburst_treemap_with_path(constructor):
187187
assert fig.data[0].values[-1] == 8
188188

189189

190-
def test_sunburst_treemap_with_path_and_hover(constructor):
191-
data = px.data.tips().to_dict(orient="list")
192-
df = constructor(data)
190+
def test_sunburst_treemap_with_path_and_hover(backend):
191+
df = px.data.tips(return_type=backend)
193192
fig = px.sunburst(
194193
df, path=["sex", "day", "time", "smoker"], color="smoker", hover_data=["smoker"]
195194
)
196195
assert "smoker" in fig.data[0].hovertemplate
197196

198-
data = px.data.gapminder().query("year == 2007").to_dict(orient="list")
199-
df = constructor(data)
200-
197+
df = nw.from_native(px.data.gapminder(year=2007, return_type=backend))
201198
fig = px.sunburst(
202-
df, path=["continent", "country"], color="lifeExp", hover_data=df.columns
199+
df.to_native(),
200+
path=["continent", "country"],
201+
color="lifeExp",
202+
hover_data=df.columns,
203203
)
204204
assert fig.layout.coloraxis.colorbar.title.text == "lifeExp"
205205

206-
data = px.data.tips().to_dict(orient="list")
207-
df = constructor(data)
208-
206+
df = px.data.tips(return_type=backend)
209207
fig = px.sunburst(df, path=["sex", "day", "time", "smoker"], hover_name="smoker")
210208
assert "smoker" not in fig.data[0].hovertemplate # represented as '%{hovertext}'
211209
assert "%{hovertext}" in fig.data[0].hovertemplate # represented as '%{hovertext}'
212210

213-
data = px.data.tips().to_dict(orient="list")
214-
df = constructor(data)
211+
df = px.data.tips(return_type=backend)
215212
fig = px.sunburst(df, path=["sex", "day", "time", "smoker"], custom_data=["smoker"])
216213
assert fig.data[0].customdata[0][0] in ["Yes", "No"]
217214
assert "smoker" not in fig.data[0].hovertemplate
@@ -414,9 +411,8 @@ def test_funnel():
414411
assert len(fig.data) == 2
415412

416413

417-
def test_parcats_dimensions_max(constructor):
418-
data = px.data.tips().to_dict(orient="list")
419-
df = constructor(data)
414+
def test_parcats_dimensions_max(backend):
415+
df = px.data.tips(return_type=backend)
420416

421417
# default behaviour
422418
fig = px.parallel_categories(df)
@@ -449,13 +445,12 @@ def test_parcats_dimensions_max(constructor):
449445

450446

451447
@pytest.mark.parametrize("histfunc,y", [(None, None), ("count", "tip")])
452-
def test_histfunc_hoverlabels_univariate(constructor, histfunc, y):
448+
def test_histfunc_hoverlabels_univariate(backend, histfunc, y):
453449
def check_label(label, fig):
454450
assert fig.layout.yaxis.title.text == label
455451
assert label + "=" in fig.data[0].hovertemplate
456452

457-
data = px.data.tips().to_dict(orient="list")
458-
df = constructor(data)
453+
df = px.data.tips(return_type=backend)
459454

460455
# base case, just "count" (note count(tip) is same as count())
461456
fig = px.histogram(df, x="total_bill", y=y, histfunc=histfunc)
@@ -481,13 +476,12 @@ def check_label(label, fig):
481476
check_label("%s (normalized as %s)" % (histnorm, barnorm), fig)
482477

483478

484-
def test_histfunc_hoverlabels_bivariate(constructor):
479+
def test_histfunc_hoverlabels_bivariate(backend):
485480
def check_label(label, fig):
486481
assert fig.layout.yaxis.title.text == label
487482
assert label + "=" in fig.data[0].hovertemplate
488483

489-
data = px.data.tips().to_dict(orient="list")
490-
df = constructor(data)
484+
df = px.data.tips(return_type=backend)
491485

492486
# with y, should be same as forcing histfunc to sum
493487
fig = px.histogram(df, x="total_bill", y="tip")

packages/python/plotly/plotly/tests/test_optional/test_px/test_px_hover.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,8 @@
66
from collections import OrderedDict # an OrderedDict is needed for Python 2
77

88

9-
def test_skip_hover(constructor):
10-
data = px.data.iris().to_dict(orient="list")
11-
df = constructor(data)
9+
def test_skip_hover(backend):
10+
df = px.data.iris(return_type=backend)
1211
fig = px.scatter(
1312
df,
1413
x="petal_length",
@@ -19,9 +18,8 @@ def test_skip_hover(constructor):
1918
assert fig.data[0].hovertemplate == "species_id=%{marker.size}<extra></extra>"
2019

2120

22-
def test_hover_data_string_column(constructor):
23-
data = px.data.tips().to_dict(orient="list")
24-
df = constructor(data)
21+
def test_hover_data_string_column(backend):
22+
df = px.data.tips(return_type=backend)
2523
fig = px.scatter(
2624
df,
2725
x="tip",
@@ -31,9 +29,8 @@ def test_hover_data_string_column(constructor):
3129
assert "sex" in fig.data[0].hovertemplate
3230

3331

34-
def test_composite_hover(constructor):
35-
data = px.data.tips().to_dict(orient="list")
36-
df = constructor(data)
32+
def test_composite_hover(backend):
33+
df = px.data.tips(return_type=backend)
3734
hover_dict = OrderedDict(
3835
{"day": False, "time": False, "sex": True, "total_bill": ":.1f"}
3936
)
@@ -91,9 +88,8 @@ def test_newdatain_hover_data():
9188
)
9289

9390

94-
def test_formatted_hover_and_labels(constructor):
95-
data = px.data.tips().to_dict(orient="list")
96-
df = constructor(data)
91+
def test_formatted_hover_and_labels(backend):
92+
df = px.data.tips(return_type=backend)
9793
fig = px.scatter(
9894
df,
9995
x="tip",
@@ -176,9 +172,8 @@ def test_fail_wrong_column():
176172
)
177173

178174

179-
def test_sunburst_hoverdict_color(constructor):
180-
data = px.data.gapminder().query("year == 2007").to_dict(orient="list")
181-
df = constructor(data)
175+
def test_sunburst_hoverdict_color(backend):
176+
df = px.data.gapminder(year=2007, return_type=backend)
182177
fig = px.sunburst(
183178
df,
184179
path=["continent", "country"],
@@ -189,7 +184,7 @@ def test_sunburst_hoverdict_color(constructor):
189184
assert "color" in fig.data[0].hovertemplate
190185

191186

192-
def test_date_in_hover(request, constructor):
187+
def test_date_in_hover(constructor):
193188
df = nw.from_native(
194189
constructor({"date": ["2015-04-04 19:31:30+01:00"], "value": [3]})
195190
).with_columns(date=nw.col("date").str.to_datetime(format="%Y-%m-%d %H:%M:%S%z"))

0 commit comments

Comments
 (0)