Skip to content

Commit cafb8dc

Browse files
committed
refactor DRY
1 parent c46e18c commit cafb8dc

32 files changed

+199
-220
lines changed

ggplotly/geoms/geom_abline.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ class geom_abline(Geom):
1010

1111
__name__ = "geom_abline"
1212

13+
default_params = {"size": 1}
14+
1315
def __init__(self, data=None, mapping=None, **params):
1416
"""
1517
Draw lines defined by slope and intercept (y = intercept + slope * x).
@@ -48,9 +50,7 @@ def __init__(self, data=None, mapping=None, **params):
4850
self.slope = params.get('slope', 1)
4951
self.intercept = params.get('intercept', 0)
5052

51-
def draw(self, fig, data=None, row=1, col=1):
52-
if "size" not in self.params:
53-
self.params["size"] = 1
53+
def _draw_impl(self, fig, data, row, col):
5454

5555
# Get color from params, or use theme default
5656
color = self.params.get("color", None)

ggplotly/geoms/geom_area.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,24 +27,21 @@ class geom_area(Geom):
2727
>>> ggplot(df, aes(x='x', y='y', fill='group')) + geom_area(alpha=0.5)
2828
"""
2929

30-
def draw(self, fig, data=None, row=1, col=1):
30+
default_params = {"size": 1}
31+
32+
def _draw_impl(self, fig, data, row, col):
3133
"""
3234
Draw area plot(s) on the figure.
3335
3436
Parameters:
3537
fig (Figure): Plotly figure object.
36-
data (DataFrame, optional): Data subset for faceting.
37-
row (int): Row position in subplot. Default is 1.
38-
col (int): Column position in subplot. Default is 1.
38+
data (DataFrame): Data (already transformed by stats).
39+
row (int): Row position in subplot.
40+
col (int): Column position in subplot.
3941
4042
Returns:
4143
None: Modifies the figure in place.
4244
"""
43-
data = data if data is not None else self.data
44-
45-
# Set default line width to 1 for area borders if not specified
46-
if "size" not in self.params:
47-
self.params["size"] = 1
4845

4946
# Remove size from mapping if present - area lines can't have variable widths
5047
# Only use size from params (literal values)

ggplotly/geoms/geom_bar.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,15 @@ class geom_bar(Geom):
3535
>>> ggplot(df, aes(x='category')) + geom_bar(width=0.5) # narrower bars
3636
"""
3737

38-
def draw(self, fig, data=None, row=1, col=1):
38+
def _apply_stats(self, data):
39+
"""Add default stat_count if no stats and stat='count'."""
40+
if self.stats == []:
41+
stat = self.params.get("stat", "count")
42+
if stat == "count":
43+
self.stats.append(stat_count(mapping=self.mapping))
44+
return super()._apply_stats(data)
45+
46+
def _draw_impl(self, fig, data, row, col):
3947
"""
4048
Draws a bar plot on the given figure.
4149
@@ -44,33 +52,20 @@ def draw(self, fig, data=None, row=1, col=1):
4452
4553
Parameters:
4654
fig (Figure): Plotly figure object.
47-
data (DataFrame): Optional data subset for faceting.
55+
data (DataFrame): Data (already transformed by stats).
4856
row (int): Row position in subplot (for faceting).
4957
col (int): Column position in subplot (for faceting).
5058
"""
5159
payload = dict()
52-
53-
# need this in case data is passed directly to the geom
54-
data = data if data is not None else self.data
5560
data = pd.DataFrame(data)
5661

57-
if self.stats == []:
58-
stat = self.params.get("stat", "count")
59-
60-
if stat == "count":
61-
self = self + stat_count()
62-
6362
if ("x" in self.mapping) & ("y" not in self.mapping):
6463
payload["orientation"] = "v"
6564
elif ("y" in self.mapping) & ("x" not in self.mapping):
6665
payload["orientation"] = "h"
6766

6867
plot = go.Bar
6968

70-
for comp in self.stats:
71-
# stack all stats on the data
72-
data, self.mapping = comp.compute(data)
73-
7469
payload["name"] = self.params.get("name", "Bar")
7570

7671
# Apply width parameter (default 0.9 to match ggplot2)

ggplotly/geoms/geom_base.py

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ class Geom:
2424
>>> ggplot(df, aes(x='x', y='y')) + geom_point(color='red', size=3)
2525
"""
2626

27+
# Default parameters for this geom. Subclasses should override this.
28+
default_params: dict = {}
29+
2730
def __init__(self, data=None, mapping=None, **params):
2831
"""
2932
Initialize the geom.
@@ -41,7 +44,8 @@ def __init__(self, data=None, mapping=None, **params):
4144
self.data = data
4245
self.mapping = mapping.mapping if mapping else {}
4346

44-
self.params = params
47+
# Merge default params with user-provided params (user params take precedence)
48+
self.params = {**self.default_params, **params}
4549
self.stats = []
4650
self.layers = []
4751
# Track whether this geom has explicit data or inherited from plot
@@ -89,6 +93,9 @@ def draw(self, fig, data=None, row=1, col=1):
8993
"""
9094
Draw the geometry on the figure.
9195
96+
This method applies any attached stats to transform the data,
97+
then delegates to _draw_impl for the actual rendering.
98+
9299
Parameters:
93100
fig (Figure): Plotly figure object.
94101
data (DataFrame, optional): Data subset for faceting.
@@ -97,11 +104,64 @@ def draw(self, fig, data=None, row=1, col=1):
97104
98105
Returns:
99106
None: Modifies the figure in place.
107+
"""
108+
data = data if data is not None else self.data
109+
110+
# Apply any stats to transform the data
111+
data = self._apply_stats(data)
112+
113+
# Delegate to subclass implementation
114+
self._draw_impl(fig, data, row, col)
115+
116+
def _apply_stats(self, data):
117+
"""
118+
Apply all attached stats to transform the data.
119+
120+
Parameters:
121+
data (DataFrame): Input data.
122+
123+
Returns:
124+
DataFrame: Transformed data after all stats applied.
125+
"""
126+
for stat in self.stats:
127+
data, self.mapping = stat.compute(data)
128+
return data
129+
130+
def _draw_impl(self, fig, data, row, col):
131+
"""
132+
Implementation of the actual drawing logic.
133+
134+
Subclasses should override this method instead of draw().
135+
136+
Parameters:
137+
fig (Figure): Plotly figure object.
138+
data (DataFrame): Data (already transformed by stats).
139+
row (int): Row position in subplot (for faceting).
140+
col (int): Column position in subplot (for faceting).
100141
101142
Raises:
102143
NotImplementedError: Must be implemented by subclasses.
103144
"""
104-
raise NotImplementedError("The draw method must be implemented by subclasses.")
145+
raise NotImplementedError("The _draw_impl method must be implemented by subclasses.")
146+
147+
def _get_style_props(self, data):
148+
"""
149+
Get style properties from aesthetic mapper.
150+
151+
This is a convenience method to reduce boilerplate in geom subclasses.
152+
153+
Parameters:
154+
data (DataFrame): The data to use for aesthetic mapping.
155+
156+
Returns:
157+
dict: Style properties from AestheticMapper.
158+
"""
159+
mapper = AestheticMapper(
160+
data, self.mapping, self.params, self.theme,
161+
global_color_map=self._global_color_map,
162+
global_shape_map=self._global_shape_map
163+
)
164+
return mapper.get_style_properties()
105165

106166
def _apply_color_targets(self, target_props: dict, style_props: dict, value_key=None, data_mask=None, shape_key=None) -> dict:
107167
"""

ggplotly/geoms/geom_boxplot.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def _shape_to_plotly_symbol(self, shape):
109109
}
110110
return shape_map.get(shape, 'circle')
111111

112-
def draw(self, fig, data=None, row=1, col=1):
112+
def _draw_impl(self, fig, data, row, col):
113113
"""
114114
Draw boxplot(s) on the figure.
115115
@@ -124,7 +124,6 @@ def draw(self, fig, data=None, row=1, col=1):
124124
col : int, default=1
125125
Column position in subplot.
126126
"""
127-
data = data if data is not None else self.data
128127

129128
plot = go.Box
130129

ggplotly/geoms/geom_candlestick.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def __init__(self, data=None, mapping=None, **params):
5555
if 'decreasing_color' not in self.params:
5656
self.params['decreasing_color'] = '#EF5350' # Red
5757

58-
def draw(self, fig, data=None, row=1, col=1):
58+
def _draw_impl(self, fig, data, row, col):
5959
"""
6060
Draw candlestick chart on the figure.
6161
@@ -68,7 +68,6 @@ def draw(self, fig, data=None, row=1, col=1):
6868
Returns:
6969
None: Modifies the figure in place.
7070
"""
71-
data = data if data is not None else self.data
7271

7372
# Validate required aesthetics
7473
required = ['x', 'open', 'high', 'low', 'close']
@@ -170,7 +169,7 @@ def __init__(self, data=None, mapping=None, **params):
170169
if 'decreasing_color' not in self.params:
171170
self.params['decreasing_color'] = '#EF5350'
172171

173-
def draw(self, fig, data=None, row=1, col=1):
172+
def _draw_impl(self, fig, data, row, col):
174173
"""
175174
Draw OHLC chart on the figure.
176175
@@ -183,7 +182,6 @@ def draw(self, fig, data=None, row=1, col=1):
183182
Returns:
184183
None: Modifies the figure in place.
185184
"""
186-
data = data if data is not None else self.data
187185

188186
# Validate required aesthetics
189187
required = ['x', 'open', 'high', 'low', 'close']

ggplotly/geoms/geom_col.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class geom_col(Geom):
2222
>>> ggplot(df, aes(x='category', y='value', fill='group')) + geom_col()
2323
"""
2424

25-
def draw(self, fig, data=None, row=1, col=1):
25+
def _draw_impl(self, fig, data, row, col):
2626
"""
2727
Draw column(s) on the figure.
2828
@@ -35,7 +35,6 @@ def draw(self, fig, data=None, row=1, col=1):
3535
Returns:
3636
None: Modifies the figure in place.
3737
"""
38-
data = data if data is not None else self.data
3938

4039
payload = dict()
4140
payload["name"] = self.params.get("name", "Column")

ggplotly/geoms/geom_contour.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22

33
import plotly.graph_objects as go
44

5-
from ..aesthetic_mapper import AestheticMapper
65
from ..stats.stat_contour import stat_contour
76
from .geom_base import Geom
87

98

109
class geom_contour(Geom):
1110
"""Geom for drawing contour lines from 2D data."""
1211

12+
default_params = {"size": 1}
13+
1314
def __init__(self, data=None, mapping=None, **params):
1415
"""
1516
Draw contour lines from 2D data.
@@ -79,14 +80,8 @@ def _compute_contour_grid(self, data):
7980
result, _ = contour_stat.compute(data)
8081
return result
8182

82-
def draw(self, fig, data=None, row=1, col=1):
83-
data = data if data is not None else self.data
84-
85-
if "size" not in self.params:
86-
self.params["size"] = 1
87-
88-
mapper = AestheticMapper(data, self.mapping, self.params, self.theme)
89-
style_props = mapper.get_style_properties()
83+
def _draw_impl(self, fig, data, row, col):
84+
style_props = self._get_style_props(data)
9085

9186
x_col = self.mapping.get("x")
9287
y_col = self.mapping.get("y")

ggplotly/geoms/geom_contour_filled.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22

33
import plotly.graph_objects as go
44

5-
from ..aesthetic_mapper import AestheticMapper
65
from ..stats.stat_contour import stat_contour
76
from .geom_base import Geom
87

98

109
class geom_contour_filled(Geom):
1110
"""Geom for drawing filled contours from 2D data."""
1211

12+
default_params = {"alpha": 0.8}
13+
1314
def __init__(self, data=None, mapping=None, **params):
1415
"""
1516
Draw filled contours from 2D data.
@@ -79,14 +80,8 @@ def _compute_contour_grid(self, data):
7980
result, _ = contour_stat.compute(data)
8081
return result
8182

82-
def draw(self, fig, data=None, row=1, col=1):
83-
data = data if data is not None else self.data
84-
85-
if "alpha" not in self.params:
86-
self.params["alpha"] = 0.8
87-
88-
mapper = AestheticMapper(data, self.mapping, self.params, self.theme)
89-
style_props = mapper.get_style_properties()
83+
def _draw_impl(self, fig, data, row, col):
84+
style_props = self._get_style_props(data)
9085

9186
x_col = self.mapping.get("x")
9287
y_col = self.mapping.get("y")

ggplotly/geoms/geom_density.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
class geom_density(Geom):
1212
"""Geom for drawing density plots."""
1313

14+
default_params = {"size": 2}
15+
1416
def __init__(self, data=None, mapping=None, bw='nrd0', adjust=1, kernel='gaussian',
1517
n=512, trim=False, **params):
1618
"""
@@ -121,11 +123,7 @@ def _compute_density_for_group(self, x_data, x_col, na_rm=False):
121123

122124
return result_df['x'].values, result_df['density'].values
123125

124-
def draw(self, fig, data=None, row=1, col=1):
125-
if "size" not in self.params:
126-
self.params["size"] = 2
127-
data = data if data is not None else self.data
128-
126+
def _draw_impl(self, fig, data, row, col):
129127
# Remove size from mapping if present - density lines can't have variable widths
130128
# Only use size from params (literal values)
131129
if "size" in self.mapping:

0 commit comments

Comments
 (0)