|
6 | 6 | from scipy.stats import ttest_ind |
7 | 7 | import numpy as np |
8 | 8 | from tqdm import tqdm |
| 9 | +import logging |
| 10 | + |
| 11 | +logger = logging.getLogger(__name__) |
9 | 12 |
|
10 | 13 |
|
11 | 14 | def calculate_t_statistics(df, max_epochs=500): |
@@ -40,10 +43,22 @@ def calculate_t_statistics(df, max_epochs=500): |
40 | 43 | & (t_df["epochs_completed"] == epoch) |
41 | 44 | ]["loss_value"].values |
42 | 45 |
|
43 | | - if len(true_losses) > 0 and len(other_losses) > 0: |
| 46 | + # Welch's t-test (equal_var=False) needs at least 2 samples per group to estimate each group's variance |
| 47 | + if len(true_losses) >= 2 and len(other_losses) >= 2: |
44 | 48 | result = ttest_ind(other_losses, true_losses, equal_var=False) |
| 49 | + if np.isnan(result.statistic): |
| 50 | + logger.debug(f"NaN t-statistic for {author} at epoch {epoch}: " |
| 51 | + f"n_true={len(true_losses)}, n_other={len(other_losses)}") |
45 | 52 | t_raws[author].append(result.statistic) |
| 53 | + elif len(true_losses) > 0 or len(other_losses) > 0: |
| 54 | + # Have some data but insufficient for t-test |
| 55 | + logger.debug(f"Insufficient data for t-test for {author} at epoch {epoch}: " |
| 56 | + f"n_true={len(true_losses)}, n_other={len(other_losses)} " |
| 57 | + f"(need at least 2 samples per group)") |
| 58 | + t_raws[author].append(np.nan) |
46 | 59 | else: |
| 60 | + # No data at all |
| 61 | + logger.debug(f"No data for {author} at epoch {epoch}") |
47 | 62 | t_raws[author].append(np.nan) |
48 | 63 |
|
49 | 64 | # Convert to long-form DataFrame |
@@ -130,22 +145,36 @@ def generate_t_test_figure( |
130 | 145 | ax.set_xlabel("Epochs completed", fontsize=12) |
131 | 146 | ax.set_ylabel("$t$-value", fontsize=12) |
132 | 147 |
|
133 | | - # Calculate dynamic y-axis limits based on data |
134 | | - # This ensures all data points are visible, including negative t-statistics |
135 | | - # (e.g., when a model performs worse on its own training author) |
136 | | - y_min = t_raws_df['t_raw'].min() |
137 | | - y_max = t_raws_df['t_raw'].max() |
| 148 | + # Calculate dynamic y-axis limits based on VALID data only |
| 149 | + # Filter out NaN/Inf values to avoid matplotlib errors |
| 150 | + valid_t_values = t_raws_df['t_raw'].replace([np.inf, -np.inf], np.nan).dropna() |
| 151 | + |
| 152 | + if len(valid_t_values) == 0: |
| 153 | + # No valid data - use reasonable defaults around threshold |
| 154 | + logger.warning("No valid t-statistics found. Using default axis limits.") |
| 155 | + y_min = -1.0 |
| 156 | + y_max = 5.0 |
| 157 | + else: |
| 158 | + y_min = valid_t_values.min() |
| 159 | + y_max = valid_t_values.max() |
138 | 160 |
|
139 | | - # Add padding for better visualization |
140 | | - y_range = y_max - y_min |
141 | | - padding = 0.05 * y_range if y_range > 0 else 0.5 |
| 161 | + # Add padding for better visualization |
| 162 | + y_range = y_max - y_min |
| 163 | + padding = 0.05 * y_range if y_range > 0 else 0.5 |
142 | 164 |
|
143 | | - # Ensure threshold line (p<0.001 at t=3.291) is always visible |
144 | | - threshold = 3.291 |
145 | | - y_max = max(y_max, threshold) + padding |
146 | | - y_min = min(y_min, 0) - padding # Allow negatives if they exist |
| 165 | + # Ensure threshold line (two-sided p<0.001 at t=3.291, the large-sample critical value) is always visible |
| 166 | + threshold = 3.291 |
| 167 | + y_max = max(y_max, threshold) + padding |
| 168 | + y_min = min(y_min, 0) - padding # Allow negatives if they exist |
| 169 | + |
| 170 | + # Final validation to ensure limits are finite and valid |
| 171 | + if not (np.isfinite(y_min) and np.isfinite(y_max) and y_min < y_max): |
| 172 | + logger.error(f"Invalid axis limits computed: y_min={y_min}, y_max={y_max}. Using defaults.") |
| 173 | + y_min = -1.0 |
| 174 | + y_max = 5.0 |
147 | 175 |
|
148 | 176 | # Add threshold line |
| 177 | + threshold = 3.291 |
149 | 178 | ax.axhline(y=threshold, linestyle="--", color="black", label="p<0.001 threshold" if show_legend else "") |
150 | 179 | ax.set_xlim(0, t_raws_df["Epoch"].max()) |
151 | 180 | ax.set_ylim(y_min, y_max) |
@@ -239,22 +268,36 @@ def generate_t_test_avg_figure( |
239 | 268 | ax.set_xlabel("Epochs completed", fontsize=12) |
240 | 269 | ax.set_ylabel("$t$-value", fontsize=12) |
241 | 270 |
|
242 | | - # Calculate dynamic y-axis limits based on data |
243 | | - # This ensures all data points are visible, including negative t-statistics |
244 | | - # (e.g., when a model performs worse on its own training author) |
245 | | - y_min = t_raws_df['t_raw'].min() |
246 | | - y_max = t_raws_df['t_raw'].max() |
| 271 | + # Calculate dynamic y-axis limits based on VALID data only |
| 272 | + # Filter out NaN/Inf values to avoid matplotlib errors |
| 273 | + valid_t_values = t_raws_df['t_raw'].replace([np.inf, -np.inf], np.nan).dropna() |
247 | 274 |
|
248 | | - # Add padding for better visualization |
249 | | - y_range = y_max - y_min |
250 | | - padding = 0.05 * y_range if y_range > 0 else 0.5 |
| 275 | + if len(valid_t_values) == 0: |
| 276 | + # No valid data - use reasonable defaults around threshold |
| 277 | + logger.warning("No valid t-statistics found for average figure. Using default axis limits.") |
| 278 | + y_min = -1.0 |
| 279 | + y_max = 5.0 |
| 280 | + else: |
| 281 | + y_min = valid_t_values.min() |
| 282 | + y_max = valid_t_values.max() |
251 | 283 |
|
252 | | - # Ensure threshold line (p<0.001 at t=3.291) is always visible |
253 | | - threshold = 3.291 |
254 | | - y_max = max(y_max, threshold) + padding |
255 | | - y_min = min(y_min, 0) - padding # Allow negatives if they exist |
| 284 | + # Add padding for better visualization |
| 285 | + y_range = y_max - y_min |
| 286 | + padding = 0.05 * y_range if y_range > 0 else 0.5 |
| 287 | + |
| 288 | + # Ensure threshold line (two-sided p<0.001 at t=3.291, the large-sample critical value) is always visible |
| 289 | + threshold = 3.291 |
| 290 | + y_max = max(y_max, threshold) + padding |
| 291 | + y_min = min(y_min, 0) - padding # Allow negatives if they exist |
| 292 | + |
| 293 | + # Final validation to ensure limits are finite and valid |
| 294 | + if not (np.isfinite(y_min) and np.isfinite(y_max) and y_min < y_max): |
| 295 | + logger.error(f"Invalid axis limits computed for average figure: y_min={y_min}, y_max={y_max}. Using defaults.") |
| 296 | + y_min = -1.0 |
| 297 | + y_max = 5.0 |
256 | 298 |
|
257 | 299 | # Add threshold line |
| 300 | + threshold = 3.291 |
258 | 301 | ax.axhline(y=threshold, linestyle="--", color="black", label="p<0.001 threshold" if show_legend else "") |
259 | 302 | ax.set_xlim(0, t_raws_df["Epoch"].max()) |
260 | 303 | ax.set_ylim(y_min, y_max) |
|
0 commit comments