|
| 1 | +# geoms/geom_norm.py |
| 2 | + |
| 3 | +import numpy as np |
| 4 | +import plotly.graph_objects as go |
| 5 | + |
| 6 | +from .geom_base import Geom |
| 7 | + |
| 8 | + |
| 9 | +class geom_norm(Geom): |
| 10 | + """ |
| 11 | + Geom for overlaying a normal distribution curve. |
| 12 | +
|
| 13 | + Automatically fits a normal distribution to the data (using mean and std) |
| 14 | + unless mean and sd are explicitly provided. Useful for comparing actual |
| 15 | + data distribution to theoretical normal distribution. |
| 16 | +
|
| 17 | + Parameters |
| 18 | + ---------- |
| 19 | + data : DataFrame, optional |
| 20 | + Data for the geom (overrides plot data). |
| 21 | + mapping : aes, optional |
| 22 | + Aesthetic mappings. Uses x aesthetic to determine data range. |
| 23 | + mean : float, optional |
| 24 | + Mean of the normal distribution. If None, computed from data. |
| 25 | + sd : float, optional |
| 26 | + Standard deviation of the normal distribution. If None, computed from data. |
| 27 | + scale : str, optional |
| 28 | + Output scale: 'density' (default) outputs PDF values, 'count' scales |
| 29 | + to match histogram counts (PDF * n * binwidth). When 'count', automatically |
| 30 | + estimates binwidth from data range and number of observations. |
| 31 | + binwidth : float, optional |
| 32 | + Bin width for count scaling. If None, estimated as data_range / 30. |
| 33 | + n : int, optional |
| 34 | + Number of points for the curve. Default is 101. |
| 35 | + color : str, optional |
| 36 | + Color of the line. Default is 'red'. |
| 37 | + size : float, optional |
| 38 | + Width of the line. Default is 2. |
| 39 | + linetype : str, optional |
| 40 | + Line style ('solid', 'dash', etc.). Default is 'solid'. |
| 41 | +
|
| 42 | + Examples |
| 43 | + -------- |
| 44 | + >>> # With density-scaled histogram (default) |
| 45 | + >>> ggplot(df, aes(x='x')) + geom_histogram(aes(y=after_stat('density'))) + geom_norm() |
| 46 | +
|
| 47 | + >>> # With count histogram (no density scaling needed on histogram) |
| 48 | + >>> ggplot(df, aes(x='x')) + geom_histogram(bins=30) + geom_norm(scale='count') |
| 49 | +
|
| 50 | + >>> # Explicit parameters |
| 51 | + >>> ggplot(df, aes(x='x')) + geom_histogram(bins=30) + geom_norm(scale='count', mean=0, sd=1) |
| 52 | +
|
| 53 | + >>> # Styled |
| 54 | + >>> ggplot(df, aes(x='x')) + geom_histogram(aes(y=after_stat('density'))) + geom_norm(color='blue', size=3) |
| 55 | + """ |
| 56 | + |
| 57 | + default_params = { |
| 58 | + "n": 101, |
| 59 | + "color": "red", |
| 60 | + "size": 2, |
| 61 | + "linetype": "solid", |
| 62 | + "scale": "density", |
| 63 | + } |
| 64 | + |
| 65 | + def __init__(self, data=None, mapping=None, mean=None, sd=None, |
| 66 | + scale="density", binwidth=None, **params): |
| 67 | + super().__init__(data, mapping, **params) |
| 68 | + self.mean = mean |
| 69 | + self.sd = sd |
| 70 | + self.scale = scale |
| 71 | + self.binwidth = binwidth |
| 72 | + |
| 73 | + def _draw_impl(self, fig, data, row, col): |
| 74 | + from scipy.stats import norm |
| 75 | + |
| 76 | + # Get parameters |
| 77 | + n = self.params.get("n", 101) |
| 78 | + color = self.params.get("color", "red") |
| 79 | + size = self.params.get("size", 2) |
| 80 | + linetype = self.params.get("linetype", "solid") |
| 81 | + |
| 82 | + # Get x column |
| 83 | + x_col = self.mapping.get('x') if self.mapping else None |
| 84 | + if x_col is None or x_col not in data.columns: |
| 85 | + raise ValueError("geom_norm requires x aesthetic") |
| 86 | + |
| 87 | + x_data = data[x_col].dropna() |
| 88 | + n_obs = len(x_data) |
| 89 | + |
| 90 | + # Compute or use provided mean/sd |
| 91 | + mean = self.mean if self.mean is not None else x_data.mean() |
| 92 | + sd = self.sd if self.sd is not None else x_data.std() |
| 93 | + |
| 94 | + # Generate x range (extend beyond data range) |
| 95 | + x_min, x_max = x_data.min(), x_data.max() |
| 96 | + x_range = x_max - x_min |
| 97 | + x_min -= x_range * 0.05 |
| 98 | + x_max += x_range * 0.05 |
| 99 | + |
| 100 | + # Compute normal PDF |
| 101 | + x_vals = np.linspace(x_min, x_max, n) |
| 102 | + y_vals = norm.pdf(x_vals, mean, sd) |
| 103 | + |
| 104 | + # Scale to counts if requested |
| 105 | + if self.scale == 'count': |
| 106 | + # Estimate binwidth if not provided (default 30 bins like geom_histogram) |
| 107 | + binwidth = self.binwidth if self.binwidth is not None else x_range / 30 |
| 108 | + y_vals = y_vals * n_obs * binwidth |
| 109 | + y_label = 'count' |
| 110 | + else: |
| 111 | + y_label = 'density' |
| 112 | + |
| 113 | + # Map linetype to Plotly dash |
| 114 | + dash_map = { |
| 115 | + 'solid': 'solid', |
| 116 | + 'dashed': 'dash', |
| 117 | + 'dash': 'dash', |
| 118 | + 'dotted': 'dot', |
| 119 | + 'dot': 'dot', |
| 120 | + 'longdash': 'longdash', |
| 121 | + 'dashdot': 'dashdot', |
| 122 | + 'twodash': 'dashdot', |
| 123 | + } |
| 124 | + dash = dash_map.get(linetype, 'solid') |
| 125 | + |
| 126 | + # Add trace |
| 127 | + fig.add_trace( |
| 128 | + go.Scatter( |
| 129 | + x=x_vals, |
| 130 | + y=y_vals, |
| 131 | + mode='lines', |
| 132 | + line=dict(color=color, width=size, dash=dash), |
| 133 | + name=f'Normal(\u03bc={mean:.2f}, \u03c3={sd:.2f})', |
| 134 | + showlegend=True, |
| 135 | + hovertemplate=f'x: %{{x:.2f}}<br>{y_label}: %{{y:.4f}}<extra></extra>', |
| 136 | + ), |
| 137 | + row=row, col=col, |
| 138 | + ) |
0 commit comments