Skip to content

Commit 49938c2

Browse files
committed
fix test
1 parent 5704bcb commit 49938c2

File tree

8 files changed

+134
-92
lines changed

8 files changed

+134
-92
lines changed

.github/workflows/ci.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,16 @@ jobs:
2626
run: |
2727
python -m pip install --upgrade pip
2828
pip install -e ".[dev,docs]"
29-
pip install igraph networkx searoute
29+
pip install igraph networkx searoute nbmake
3030
3131
- name: Run tests
3232
run: |
3333
pytest --cov=ggplotly --cov-report=xml
3434
35+
- name: Test notebooks
36+
run: |
37+
pytest --nbmake docs/
38+
3539
- name: Upload coverage to Codecov
3640
if: matrix.python-version == '3.12'
3741
uses: codecov/codecov-action@v4

docs/guide/aesthetics.ipynb

Lines changed: 4 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -27,41 +27,7 @@
2727
"execution_count": null,
2828
"metadata": {},
2929
"outputs": [],
30-
"source": [
31-
"import pandas as pd\n",
32-
"import numpy as np\n",
33-
"from ggplotly import *\n",
34-
"\n",
35-
"# Sample data\n",
36-
"np.random.seed(42)\n",
37-
"df = pd.DataFrame({\n",
38-
" 'x': np.random.randn(100),\n",
39-
" 'y': np.random.randn(100),\n",
40-
" 'species': np.random.choice(['A', 'B', 'C'], 100),\n",
41-
" 'value': np.random.rand(100) * 10\n",
42-
"})\n",
43-
"\n",
44-
"# Time series data\n",
45-
"ts_df = pd.DataFrame({\n",
46-
" 'date': pd.date_range('2024-01-01', periods=50, freq='D'),\n",
47-
" 'value': np.cumsum(np.random.randn(50)) + 50\n",
48-
"})\n",
49-
"ts_df.index = ts_df['date']\n",
50-
"ts_df.index.name = 'date'\n",
51-
"\n",
52-
"# Bar data\n",
53-
"bar_df = pd.DataFrame({\n",
54-
" 'category': ['A', 'B', 'C', 'D'],\n",
55-
" 'count': [25, 40, 35, 30]\n",
56-
"})\n",
57-
"\n",
58-
"# Line data with groups\n",
59-
"line_df = pd.DataFrame({\n",
60-
" 'x': np.tile(np.arange(10), 3),\n",
61-
" 'y': np.random.randn(30),\n",
62-
" 'id': np.repeat(['A', 'B', 'C'], 10)\n",
63-
"})"
64-
]
30+
"source": "import pandas as pd\nimport numpy as np\nfrom ggplotly import *\n\n# Sample data\nnp.random.seed(42)\ndf = pd.DataFrame({\n 'x': np.random.randn(100),\n 'y': np.random.randn(100),\n 'species': np.random.choice(['A', 'B', 'C'], 100),\n 'value': np.random.rand(100) * 10\n})\n\n# Time series data\nts_df = pd.DataFrame({\n 'date': pd.date_range('2024-01-01', periods=50, freq='D'),\n 'value': np.cumsum(np.random.randn(50)) + 50\n})\nts_df = ts_df.set_index('date') # Set date as index, removes it from columns\n\n# Bar data\nbar_df = pd.DataFrame({\n 'category': ['A', 'B', 'C', 'D'],\n 'count': [25, 40, 35, 30]\n})\n\n# Line data with groups\nline_df = pd.DataFrame({\n 'x': np.tile(np.arange(10), 3),\n 'y': np.random.randn(30),\n 'id': np.repeat(['A', 'B', 'C'], 10)\n})"
6531
},
6632
{
6733
"cell_type": "markdown",
@@ -292,18 +258,14 @@
292258
"execution_count": null,
293259
"metadata": {},
294260
"outputs": [],
295-
"source": [
296-
"ggplot(ts_df, aes(x='date', y='value')) + geom_line(linetype='dash')"
297-
]
261+
"source": "ggplot(ts_df, aes(y='value')) + geom_line(linetype='dash')"
298262
},
299263
{
300264
"cell_type": "code",
301265
"execution_count": null,
302266
"metadata": {},
303267
"outputs": [],
304-
"source": [
305-
"ggplot(ts_df, aes(x='date', y='value')) + geom_line(linetype='dot')"
306-
]
268+
"source": "ggplot(ts_df, aes(y='value')) + geom_line(linetype='dot')"
307269
}
308270
],
309271
"metadata": {
@@ -319,4 +281,4 @@
319281
},
320282
"nbformat": 4,
321283
"nbformat_minor": 4
322-
}
284+
}

docs/guide/coordinates.ipynb

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -108,16 +108,7 @@
108108
"execution_count": null,
109109
"metadata": {},
110110
"outputs": [],
111-
"source": [
112-
"# Pie chart\n",
113-
"pie_df = pd.DataFrame({\n",
114-
" 'category': ['A', 'B', 'C', 'D'],\n",
115-
" 'value': [25, 30, 20, 25]\n",
116-
"})\n",
117-
"ggplot(pie_df, aes(x='', y='value', fill='category')) + \\\n",
118-
" geom_bar(stat='identity', width=1) + \\\n",
119-
" coord_polar(theta='y')"
120-
]
111+
"source": "# Pie chart\npie_df = pd.DataFrame({\n 'category': ['A', 'B', 'C', 'D'],\n 'value': [25, 30, 20, 25]\n})\npie_df['x'] = 1 # Constant x for stacked bar -> pie conversion\n\nggplot(pie_df, aes(x='x', y='value', fill='category')) + \\\n geom_bar(stat='identity', width=1) + \\\n coord_polar(theta='y')"
121112
},
122113
{
123114
"cell_type": "markdown",
@@ -179,4 +170,4 @@
179170
},
180171
"nbformat": 4,
181172
"nbformat_minor": 4
182-
}
173+
}

ggplotly/facets.py

Lines changed: 48 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,10 @@ def _has_3d_geoms(self, plot):
142142
return True
143143
return False
144144

145+
def _is_no_facet(self, value):
146+
"""Check if value means 'no faceting on this dimension'."""
147+
return value is None or value == '.'
148+
145149
def apply(self, plot):
146150
"""
147151
Apply facet grid to the plot.
@@ -155,25 +159,29 @@ def apply(self, plot):
155159
Raises:
156160
FacetColumnNotFoundError: If row or column variable doesn't exist in the data.
157161
"""
158-
# Validate row facet column exists
159-
if self.rows not in plot.data.columns:
162+
# Check if rows/cols are disabled (None or '.')
163+
has_rows = not self._is_no_facet(self.rows)
164+
has_cols = not self._is_no_facet(self.cols)
165+
166+
# Validate row facet column exists (if specified)
167+
if has_rows and self.rows not in plot.data.columns:
160168
raise FacetColumnNotFoundError(
161169
self.rows,
162170
list(plot.data.columns),
163171
facet_type="facet_grid rows"
164172
)
165173

166-
# Validate column facet column exists
167-
if self.cols not in plot.data.columns:
174+
# Validate column facet column exists (if specified)
175+
if has_cols and self.cols not in plot.data.columns:
168176
raise FacetColumnNotFoundError(
169177
self.cols,
170178
list(plot.data.columns),
171179
facet_type="facet_grid cols"
172180
)
173181

174182
# Get unique values for the row and column variables
175-
row_facets = plot.data[self.rows].unique()
176-
col_facets = plot.data[self.cols].unique()
183+
row_facets = plot.data[self.rows].unique() if has_rows else [None]
184+
col_facets = plot.data[self.cols].unique() if has_cols else [None]
177185

178186
nrows = len(row_facets)
179187
ncols = len(col_facets)
@@ -192,14 +200,22 @@ def apply(self, plot):
192200
shared_y = self.scales in ('fixed', 'free_x')
193201

194202
# Generate labels
195-
labels = [self._get_label(self.rows, row, self.cols, col)
196-
for row in row_facets for col in col_facets]
203+
def get_facet_label(row, col):
204+
if has_rows and has_cols:
205+
return self._get_label(self.rows, row, self.cols, col)
206+
elif has_rows:
207+
return str(row)
208+
elif has_cols:
209+
return str(col)
210+
else:
211+
return ""
212+
labels = [get_facet_label(row, col) for row in row_facets for col in col_facets]
197213

198214
# Calculate column widths and row heights based on space parameter
199215
column_widths = None
200216
row_heights = None
201217

202-
if self.space in ('free', 'free_x'):
218+
if self.space in ('free', 'free_x') and has_cols:
203219
# Calculate width proportional to x-axis data range per column
204220
x_col = plot.mapping.get('x')
205221
if x_col and x_col in plot.data.columns:
@@ -214,7 +230,7 @@ def apply(self, plot):
214230
if total > 0:
215231
column_widths = [r / total for r in col_ranges]
216232

217-
if self.space in ('free', 'free_y'):
233+
if self.space in ('free', 'free_y') and has_rows:
218234
# Calculate height proportional to y-axis data range per row
219235
y_col = plot.mapping.get('y')
220236
if y_col and y_col in plot.data.columns:
@@ -269,10 +285,17 @@ def apply(self, plot):
269285
scene_key = None
270286

271287
# Subset data for the current facet (row and column combination)
272-
facet_data = plot.data[
273-
(plot.data[self.rows] == row_value)
274-
& (plot.data[self.cols] == col_value)
275-
]
288+
if has_rows and has_cols:
289+
facet_data = plot.data[
290+
(plot.data[self.rows] == row_value)
291+
& (plot.data[self.cols] == col_value)
292+
]
293+
elif has_rows:
294+
facet_data = plot.data[plot.data[self.rows] == row_value]
295+
elif has_cols:
296+
facet_data = plot.data[plot.data[self.cols] == col_value]
297+
else:
298+
facet_data = plot.data
276299

277300
# Draw each geom on the subplot for the current facet
278301
for geom in plot.layers:
@@ -281,10 +304,17 @@ def apply(self, plot):
281304

282305
# If geom has its own explicit data, use that for faceting instead of plot.data
283306
if hasattr(geom, '_has_explicit_data') and geom._has_explicit_data:
284-
geom_facet_data = geom.data[
285-
(geom.data[self.rows] == row_value)
286-
& (geom.data[self.cols] == col_value)
287-
]
307+
if has_rows and has_cols:
308+
geom_facet_data = geom.data[
309+
(geom.data[self.rows] == row_value)
310+
& (geom.data[self.cols] == col_value)
311+
]
312+
elif has_rows:
313+
geom_facet_data = geom.data[geom.data[self.rows] == row_value]
314+
elif has_cols:
315+
geom_facet_data = geom.data[geom.data[self.cols] == col_value]
316+
else:
317+
geom_facet_data = geom.data
288318
geom.setup_data(geom_facet_data, plot.mapping)
289319
else:
290320
geom.setup_data(facet_data, plot.mapping)

ggplotly/geoms/geom_map.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -131,13 +131,15 @@ def _draw_impl(self, fig, data, row, col):
131131

132132
# Determine the mode:
133133
# 1. GeoJSON mode (sf-like) - when geojson is provided or data is GeoDataFrame
134-
# 2. Choropleth mode - when map_id is provided
135-
# 3. Base map mode - no data aesthetics
134+
# 2. Choropleth mode - when map_id is provided (with optional fill)
135+
# 3. Base map mode - no map_id (fill is ignored - may be inherited from other geoms)
136136
is_geojson_mode = geojson is not None
137-
is_base_map = map_id_col is None and fill_col is None and not is_geojson_mode
137+
is_choropleth = map_id_col is not None
138+
is_base_map = not is_geojson_mode and not is_choropleth
138139

139140
if is_base_map:
140141
# Just set up the geo layout - no data traces needed
142+
# Note: fill may be inherited from ggplot aes but is ignored for base maps
141143
self._setup_geo_layout(fig)
142144
return
143145

@@ -146,9 +148,7 @@ def _draw_impl(self, fig, data, row, col):
146148
self._draw_geojson(fig, data, geojson, fill_col)
147149
return
148150

149-
# Choropleth mode - require map_id
150-
if map_id_col is None:
151-
raise ValueError("geom_map choropleth requires a 'map_id' aesthetic")
151+
# Choropleth mode - map_id is required (already checked above)
152152

153153
mapper = AestheticMapper(data, self.mapping, self.params, self.theme)
154154
style_props = mapper.get_style_properties()

ggplotly/geoms/geom_point_3d.py

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -131,15 +131,23 @@ def should_show_legend(legendgroup):
131131
fig._shown_legendgroups.add(legendgroup)
132132
return True
133133

134-
# Determine grouping strategy
135-
has_color_grouping = style_props['color_series'] is not None or style_props['fill_series'] is not None
134+
# Check for continuous (numeric) color mapping
135+
has_continuous_color = (
136+
style_props.get('color_is_continuous', False) or
137+
style_props.get('fill_is_continuous', False)
138+
)
139+
140+
# Determine grouping strategy - only count as categorical if not continuous
141+
has_color_grouping = (
142+
style_props['color_series'] is not None or style_props['fill_series'] is not None
143+
) and not has_continuous_color
136144
has_shape_grouping = shape_series is not None
137145

138146
# Get categorical aesthetic info
139-
if style_props['color_series'] is not None:
147+
if style_props['color_series'] is not None and not has_continuous_color:
140148
cat_col = style_props['color']
141149
cat_map = style_props['color_map']
142-
elif style_props['fill_series'] is not None:
150+
elif style_props['fill_series'] is not None and not has_continuous_color:
143151
cat_col = style_props['fill']
144152
cat_map = style_props['fill_map']
145153
else:
@@ -233,7 +241,46 @@ def should_show_legend(legendgroup):
233241
trace_props, alpha, legend_name, should_show_legend, row, col
234242
)
235243

236-
# Case 5: No grouping - single trace
244+
# Case 5: Continuous color mapping (numeric values with colorscale)
245+
elif has_continuous_color:
246+
# Get the numeric color values
247+
if style_props.get('color_is_continuous'):
248+
color_values = style_props['color_series']
249+
else:
250+
color_values = style_props['fill_series']
251+
252+
# Build marker dict with colorscale
253+
marker_dict = dict(
254+
size=style_props['size'],
255+
color=color_values,
256+
colorscale='Viridis', # Default, may be overridden by scale
257+
showscale=True,
258+
opacity=alpha,
259+
)
260+
261+
# Get colorscale from scale if available
262+
colorscale = style_props.get('colorscale')
263+
if colorscale:
264+
marker_dict['colorscale'] = colorscale
265+
266+
# Use scene key for faceted plots (3D traces use scene, not row/col)
267+
scene_key = self.params.get('_scene_key', 'scene')
268+
269+
trace_name = self.params.get('name', '3D Scatter')
270+
fig.add_trace(
271+
go.Scatter3d(
272+
x=x,
273+
y=y,
274+
z=z,
275+
mode='markers',
276+
marker=marker_dict,
277+
name=trace_name,
278+
showlegend=False, # Colorbar replaces discrete legend
279+
scene=scene_key,
280+
)
281+
)
282+
283+
# Case 6: No grouping - single trace
237284
else:
238285
trace_props = self._apply_color_targets(
239286
{'color': 'marker_color', 'size': 'marker_size', 'shape': 'marker_symbol'},

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ dev = [
5252
"pytest>=7.0.0",
5353
"pytest-cov>=4.0.0",
5454
"pytest-ruff>=0.4.0",
55+
"nbmake>=1.5.0",
5556
"black>=23.0.0",
5657
"ruff>=0.1.0",
5758
"mypy>=1.0.0",
@@ -92,9 +93,9 @@ include = ["ggplotly*"]
9293
ggplotly = ["data/*.csv"]
9394

9495
[tool.pytest.ini_options]
95-
testpaths = ["pytest"]
96+
testpaths = ["pytest", "docs"]
9697
python_files = ["test_*.py"]
97-
addopts = "-v --tb=short"
98+
addopts = "-v --tb=short --nbmake"
9899

99100
[tool.black]
100101
line-length = 88

pytest/test_geom_map.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -297,17 +297,24 @@ def test_base_map_custom_projection(self):
297297
class TestGeomMapChoropleth:
298298
"""Tests for choropleth functionality."""
299299

300-
def test_choropleth_requires_map_id(self):
301-
"""Test that choropleth mode requires map_id aesthetic."""
300+
def test_fill_without_map_id_creates_base_map(self):
301+
"""Test that fill without map_id creates base map (fill is ignored).
302+
303+
This allows geom_map to be used as a background when fill is inherited
304+
from ggplot aes but intended for other geoms like geom_tile.
305+
"""
302306
data = pd.DataFrame({
303307
'value': [100, 200, 300],
304308
})
305309

306-
with pytest.raises(ValueError, match="map_id"):
307-
fig = (
308-
ggplot(data, aes(fill='value'))
309-
+ geom_map()
310-
).draw()
310+
# Should not raise - creates base map, fill is ignored
311+
fig = (
312+
ggplot(data, aes(fill='value'))
313+
+ geom_map()
314+
).draw()
315+
316+
# Base map doesn't add data traces, just sets up geo layout
317+
assert fig is not None
311318

312319
def test_choropleth_with_fill_creates_choropleth_trace(self):
313320
"""Test choropleth with numeric fill creates correct trace type."""

0 commit comments

Comments
 (0)