Skip to content

Commit e792f3e

Browse files
jeremymanningclaude
andcommitted
Fix issue #28: Handle NaN/Inf in t-test visualization axis limits
Implemented comprehensive fix for t-test visualization ValueError when computing axis limits with NaN/Inf t-statistics: **Core Fix in t_tests.py:** - Added data validation in calculate_t_statistics(): * Check for n>=2 samples per group before running t-test * Log insufficient data cases (0<n<2) and missing data (n=0) * Append NaN for invalid cases instead of computing t-test - Added robust axis limit calculation in both functions: * Filter out NaN/Inf values before computing limits * Add padding for visualization clarity * Ensure threshold line (p<0.001) is always visible * Final validation to ensure finite and valid limits * Fallback to sensible defaults (-1.0 to 5.0) **Comprehensive Test Suite:** - Created tests/test_t_test_edge_cases.py with 11 tests: 1. All NaN t-statistics (single sample per group) 2. Mixed NaN and valid data 3. All Infinite t-statistics (zero variance) 4. Mixed Infinite and valid data 5. Empty data groups (missing epochs) 6. Extreme outliers 7. Very small sample sizes (n=2, minimum required) 8. Average figure with all NaN 9. Average figure with mixed data 10. Logging verification 11. Normal data baseline All tests pass with real data (no mocks). Addresses issue #28 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
1 parent 2e7fbc3 commit e792f3e

File tree

2 files changed

+622
-25
lines changed

2 files changed

+622
-25
lines changed

llm_stylometry/visualization/t_tests.py

Lines changed: 68 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
from scipy.stats import ttest_ind
77
import numpy as np
88
from tqdm import tqdm
9+
import logging
10+
11+
logger = logging.getLogger(__name__)
912

1013

1114
def calculate_t_statistics(df, max_epochs=500):
@@ -40,10 +43,22 @@ def calculate_t_statistics(df, max_epochs=500):
4043
& (t_df["epochs_completed"] == epoch)
4144
]["loss_value"].values
4245

43-
if len(true_losses) > 0 and len(other_losses) > 0:
46+
# T-test requires at least 2 samples per group for meaningful results
47+
if len(true_losses) >= 2 and len(other_losses) >= 2:
4448
result = ttest_ind(other_losses, true_losses, equal_var=False)
49+
if np.isnan(result.statistic):
50+
logger.debug(f"NaN t-statistic for {author} at epoch {epoch}: "
51+
f"n_true={len(true_losses)}, n_other={len(other_losses)}")
4552
t_raws[author].append(result.statistic)
53+
elif len(true_losses) > 0 or len(other_losses) > 0:
54+
# Have some data but insufficient for t-test
55+
logger.debug(f"Insufficient data for t-test for {author} at epoch {epoch}: "
56+
f"n_true={len(true_losses)}, n_other={len(other_losses)} "
57+
f"(need at least 2 samples per group)")
58+
t_raws[author].append(np.nan)
4659
else:
60+
# No data at all
61+
logger.debug(f"No data for {author} at epoch {epoch}")
4762
t_raws[author].append(np.nan)
4863

4964
# Convert to long-form DataFrame
@@ -130,22 +145,36 @@ def generate_t_test_figure(
130145
ax.set_xlabel("Epochs completed", fontsize=12)
131146
ax.set_ylabel("$t$-value", fontsize=12)
132147

133-
# Calculate dynamic y-axis limits based on data
134-
# This ensures all data points are visible, including negative t-statistics
135-
# (e.g., when a model performs worse on its own training author)
136-
y_min = t_raws_df['t_raw'].min()
137-
y_max = t_raws_df['t_raw'].max()
148+
# Calculate dynamic y-axis limits based on VALID data only
149+
# Filter out NaN/Inf values to avoid matplotlib errors
150+
valid_t_values = t_raws_df['t_raw'].replace([np.inf, -np.inf], np.nan).dropna()
151+
152+
if len(valid_t_values) == 0:
153+
# No valid data - use reasonable defaults around threshold
154+
logger.warning("No valid t-statistics found. Using default axis limits.")
155+
y_min = -1.0
156+
y_max = 5.0
157+
else:
158+
y_min = valid_t_values.min()
159+
y_max = valid_t_values.max()
138160

139-
# Add padding for better visualization
140-
y_range = y_max - y_min
141-
padding = 0.05 * y_range if y_range > 0 else 0.5
161+
# Add padding for better visualization
162+
y_range = y_max - y_min
163+
padding = 0.05 * y_range if y_range > 0 else 0.5
142164

143-
# Ensure threshold line (p<0.001 at t=3.291) is always visible
144-
threshold = 3.291
145-
y_max = max(y_max, threshold) + padding
146-
y_min = min(y_min, 0) - padding # Allow negatives if they exist
165+
# Ensure threshold line (p<0.001 at t=3.291) is always visible
166+
threshold = 3.291
167+
y_max = max(y_max, threshold) + padding
168+
y_min = min(y_min, 0) - padding # Allow negatives if they exist
169+
170+
# Final validation to ensure limits are finite and valid
171+
if not (np.isfinite(y_min) and np.isfinite(y_max) and y_min < y_max):
172+
logger.error(f"Invalid axis limits computed: y_min={y_min}, y_max={y_max}. Using defaults.")
173+
y_min = -1.0
174+
y_max = 5.0
147175

148176
# Add threshold line
177+
threshold = 3.291
149178
ax.axhline(y=threshold, linestyle="--", color="black", label="p<0.001 threshold" if show_legend else "")
150179
ax.set_xlim(0, t_raws_df["Epoch"].max())
151180
ax.set_ylim(y_min, y_max)
@@ -239,22 +268,36 @@ def generate_t_test_avg_figure(
239268
ax.set_xlabel("Epochs completed", fontsize=12)
240269
ax.set_ylabel("$t$-value", fontsize=12)
241270

242-
# Calculate dynamic y-axis limits based on data
243-
# This ensures all data points are visible, including negative t-statistics
244-
# (e.g., when a model performs worse on its own training author)
245-
y_min = t_raws_df['t_raw'].min()
246-
y_max = t_raws_df['t_raw'].max()
271+
# Calculate dynamic y-axis limits based on VALID data only
272+
# Filter out NaN/Inf values to avoid matplotlib errors
273+
valid_t_values = t_raws_df['t_raw'].replace([np.inf, -np.inf], np.nan).dropna()
247274

248-
# Add padding for better visualization
249-
y_range = y_max - y_min
250-
padding = 0.05 * y_range if y_range > 0 else 0.5
275+
if len(valid_t_values) == 0:
276+
# No valid data - use reasonable defaults around threshold
277+
logger.warning("No valid t-statistics found for average figure. Using default axis limits.")
278+
y_min = -1.0
279+
y_max = 5.0
280+
else:
281+
y_min = valid_t_values.min()
282+
y_max = valid_t_values.max()
251283

252-
# Ensure threshold line (p<0.001 at t=3.291) is always visible
253-
threshold = 3.291
254-
y_max = max(y_max, threshold) + padding
255-
y_min = min(y_min, 0) - padding # Allow negatives if they exist
284+
# Add padding for better visualization
285+
y_range = y_max - y_min
286+
padding = 0.05 * y_range if y_range > 0 else 0.5
287+
288+
# Ensure threshold line (p<0.001 at t=3.291) is always visible
289+
threshold = 3.291
290+
y_max = max(y_max, threshold) + padding
291+
y_min = min(y_min, 0) - padding # Allow negatives if they exist
292+
293+
# Final validation to ensure limits are finite and valid
294+
if not (np.isfinite(y_min) and np.isfinite(y_max) and y_min < y_max):
295+
logger.error(f"Invalid axis limits computed for average figure: y_min={y_min}, y_max={y_max}. Using defaults.")
296+
y_min = -1.0
297+
y_max = 5.0
256298

257299
# Add threshold line
300+
threshold = 3.291
258301
ax.axhline(y=threshold, linestyle="--", color="black", label="p<0.001 threshold" if show_legend else "")
259302
ax.set_xlim(0, t_raws_df["Epoch"].max())
260303
ax.set_ylim(y_min, y_max)

0 commit comments

Comments
 (0)