Skip to content

Commit e5c3c92

Browse files
committed
Enhance statistical analysis with comprehensive variant comparisons
New features in code/compute_stats.py:

1. Automatic threshold crossing detection per author
   - Detects which authors cross the p < 0.001 threshold, and at which epoch
   - No longer hard-coded to Twain (which was special-cased in the baseline)
   - Shows all 8 authors individually
2. Average t-statistic threshold crossing
   - Reports when the average across all authors crosses the threshold
   - Complements the individual author analysis
3. Cross-variant pairwise comparisons (--cross-variant-comparison)
   - T-tests comparing t-value distributions between all variant pairs
   - Shows which conditions differ significantly
   - 6 pairwise comparisons: baseline vs content, baseline vs function, etc.
4. LaTeX table output
   - Formatted exactly as specified for paper inclusion
   - Scientific notation for small p-values

run_stats.sh updates:
- Automatically triggers the cross-variant comparison when --all is used
- Fixed variant data path handling

Example output:
- Baseline: 7 authors cross at epochs 1-2, Twain at epoch 77
- Average crosses at epoch 1
- Cross-variant: baseline differs from all (p < 0.01)

Related to #33
1 parent 5b573bb commit e5c3c92

File tree

2 files changed

+214
-29
lines changed

2 files changed

+214
-29
lines changed

code/compute_stats.py

Lines changed: 206 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -38,33 +38,88 @@ def load_data(data_path='data/model_results.pkl', variant=None):
3838
return df
3939

4040

41-
def find_twain_threshold_epoch(df, p_threshold=0.001):
41+
def find_threshold_crossing_epochs(df, p_threshold=0.001):
    """
    Report, per author, the first epoch at which the p-value drops below threshold.

    For every author in AUTHORS the epochs are visited in ascending order and a
    Welch t-test (other authors' losses vs. the author's own losses) is run.
    Epochs with fewer than 10 self losses or fewer than 70 other losses are
    skipped. An author is only reported when an earlier tested epoch had
    p >= p_threshold and a later one has p < p_threshold; an author whose
    p-value is below the threshold from the first tested epoch is not reported.

    Args:
        df: DataFrame with 'train_author', 'epochs_completed', 'loss_dataset'
            and 'loss_value' columns.
        p_threshold: Significance level that defines the crossing.

    Returns:
        dict: {author: (epoch, t_stat, p_value)} for authors that cross threshold
    """
    crossings = {}

    for author in AUTHORS:
        rows = df[df['train_author'] == author]
        # The comparison set does not depend on the epoch, so build it once.
        rivals = [name for name in AUTHORS if name != author]
        was_above = False

        for epoch in sorted(rows['epochs_completed'].unique()):
            at_epoch = rows[rows['epochs_completed'] == epoch]
            own = at_epoch.loc[at_epoch['loss_dataset'] == author, 'loss_value'].values
            foreign = at_epoch.loc[at_epoch['loss_dataset'].isin(rivals), 'loss_value'].values

            # Not enough samples at this epoch to run the test.
            if len(own) < 10 or len(foreign) < 70:
                continue

            # Welch t-test, other vs self.
            t_stat, p_value = stats.ttest_ind(foreign, own, equal_var=False)

            if p_value >= p_threshold:
                was_above = True
                continue

            if was_above and p_value < p_threshold:
                # First significant epoch after a non-significant one.
                crossings[author] = (epoch, t_stat, p_value)
                break

    return crossings
80+
81+
82+
def find_average_threshold_crossing(df, p_threshold=0.001, authors=None):
    """
    Find epoch where the average per-author statistic crosses the threshold.

    At each epoch (ascending) a statistic is computed for every author:
    (mean of other-author losses - mean of self losses) divided by the square
    root of the mean of the two variances (np.var with its default settings).
    NOTE: this is a standardized mean difference, not the Welch t-statistic
    used by the per-author analysis; it carries no sample-size scaling.

    When every author yields a statistic, a one-sample t-test checks whether
    their mean is greater than zero (one-tailed). A crossing is reported at the
    first epoch with p < p_threshold that follows an epoch with
    p >= p_threshold.

    Args:
        df: DataFrame with 'train_author', 'epochs_completed', 'loss_dataset'
            and 'loss_value' columns.
        p_threshold: Significance level that defines the crossing.
        authors: Optional iterable of author names to analyse. Defaults to the
            module-level AUTHORS list.

    Returns:
        tuple: (epoch, avg_t_stat, p_value) or (None, None, None)
    """
    author_names = list(AUTHORS if authors is None else authors)
    if not author_names:
        # Nothing to average over.
        return None, None, None

    seen_above_threshold = False

    for epoch in sorted(df['epochs_completed'].unique()):
        epoch_df = df[df['epochs_completed'] == epoch]

        # Per-author standardized mean differences at this epoch
        author_t_stats = []

        for author in author_names:
            author_df = epoch_df[epoch_df['train_author'] == author]

            # Get self and other losses
            self_losses = author_df[author_df['loss_dataset'] == author]['loss_value'].values
            other_authors = [a for a in author_names if a != author]
            other_losses = author_df[author_df['loss_dataset'].isin(other_authors)]['loss_value'].values

            if len(self_losses) == 0 or len(other_losses) == 0:
                continue

            mean_diff = np.mean(other_losses) - np.mean(self_losses)
            pooled_std = np.sqrt((np.var(other_losses) + np.var(self_losses)) / 2)
            if pooled_std > 0:
                author_t_stats.append(mean_diff / pooled_std)

        # Only evaluate epochs where every author contributed a statistic.
        if len(author_t_stats) != len(author_names):
            continue

        avg_t = np.mean(author_t_stats)

        # One-sample t-test: is the average statistic significantly > 0?
        t_result = stats.ttest_1samp(author_t_stats, 0)
        half_p = t_result.pvalue / 2
        # Halving the two-sided p-value is only valid when the effect points in
        # the tested direction; otherwise the one-tailed p-value is the
        # complement. Without this, a strongly negative average would be
        # reported as a significant positive crossing.
        p_value = half_p if t_result.statistic > 0 else 1 - half_p

        if p_value >= p_threshold:
            seen_above_threshold = True
        elif seen_above_threshold and p_value < p_threshold:
            return epoch, avg_t, p_value

    return None, None, None
70125

@@ -197,6 +252,73 @@ def generate_author_comparison_table(df):
197252
return df_table, latex_table
198253

199254

255+
def compute_cross_variant_comparisons(all_variant_data, epoch=500):
    """
    Compare per-author t-value distributions between every pair of variants.

    For each variant, one other-vs-self t-statistic is computed per author in
    AUTHORS at the requested epoch. Every pair of variants that each have at
    least two t-values is then compared with a Welch t-test.

    Args:
        all_variant_data: dict of {variant_name: DataFrame}
        epoch: Epoch to compare at (default: 500)

    Returns:
        DataFrame with one row per compared pair and the string-formatted
        columns 'Comparison', 't-stat', 'df', 'p-value' and 'mean_diff'.
    """
    from itertools import combinations

    def author_t_statistic(frame, author):
        """Return the other-vs-self t-statistic for one author, or None if unavailable."""
        subset = frame[(frame['train_author'] == author) & (frame['epochs_completed'] == epoch)]
        own = subset[subset['loss_dataset'] == author]['loss_value'].values
        rivals = [name for name in AUTHORS if name != author]
        foreign = subset[subset['loss_dataset'].isin(rivals)]['loss_value'].values

        if len(own) == 0 or len(foreign) == 0:
            return None

        if len(own) == 1:
            # A single self loss has no variance of its own: compare the mean of
            # the other losses against it using their standard error.
            spread = np.std(foreign)
            if spread > 0:
                return (np.mean(foreign) - own[0]) / (spread / np.sqrt(len(foreign)))
            return None

        welch_t, _ = stats.ttest_ind(foreign, own, equal_var=False)
        return None if np.isnan(welch_t) else welch_t

    # Collect the usable t-values of each variant.
    t_by_variant = {}
    for variant_name, frame in all_variant_data.items():
        candidates = (author_t_statistic(frame, author) for author in AUTHORS)
        t_by_variant[variant_name] = [t for t in candidates if t is not None]

    # Pairwise comparisons
    rows = []
    for left, right in combinations(list(all_variant_data.keys()), 2):
        left_t = t_by_variant[left]
        right_t = t_by_variant[right]

        # A t-test needs at least two observations on each side.
        if len(left_t) < 2 or len(right_t) < 2:
            continue

        outcome = stats.ttest_ind(left_t, right_t, equal_var=False)
        rows.append({
            'Comparison': f'{left} vs {right}',
            't-stat': f'{outcome.statistic:.2f}',
            'df': f'{outcome.df:.2f}',
            'p-value': f'{outcome.pvalue:.2e}',
            'mean_diff': f'{np.mean(left_t) - np.mean(right_t):.2f}'
        })

    return pd.DataFrame(rows)
320+
321+
200322
def main():
201323
"""Main function to compute and display all statistics."""
202324
import argparse
@@ -213,9 +335,50 @@ def main():
213335
default='data/model_results.pkl',
214336
help='Path to model results file (default: data/model_results.pkl)'
215337
)
338+
parser.add_argument(
339+
'--cross-variant-comparison',
340+
action='store_true',
341+
help='Compute pairwise comparisons across all variants'
342+
)
216343

217344
args = parser.parse_args()
218345

346+
# Handle cross-variant comparison mode
347+
if args.cross_variant_comparison:
348+
print("=" * 60)
349+
print("Cross-Variant Comparison Analysis")
350+
print("=" * 60)
351+
352+
# Load all variant data
353+
all_variant_data = {}
354+
for var_name, var_key in [('baseline', None), ('content', 'content'), ('function', 'function'), ('pos', 'pos')]:
355+
pkl_file = f"data/model_results.pkl" if var_key is None else f"data/model_results_{var_key}.pkl"
356+
if Path(pkl_file).exists():
357+
all_variant_data[var_name] = load_data(pkl_file, var_key)
358+
else:
359+
print(f"Warning: {pkl_file} not found, skipping {var_name}")
360+
361+
if len(all_variant_data) < 2:
362+
print("Error: Need at least 2 variants for comparison")
363+
return
364+
365+
print(f"\nLoaded {len(all_variant_data)} conditions: {list(all_variant_data.keys())}")
366+
367+
# Compute pairwise comparisons
368+
print("\nPairwise T-Test Comparisons (Epoch 500)")
369+
print("Comparing distributions of t-statistics across all authors")
370+
print("-" * 60)
371+
372+
comparison_df = compute_cross_variant_comparisons(all_variant_data, epoch=500)
373+
374+
if not comparison_df.empty:
375+
print("\n" + comparison_df.to_string(index=False))
376+
else:
377+
print("No comparisons could be computed")
378+
379+
print("\n" + "=" * 60)
380+
return
381+
219382
# Update header to show variant
220383
variant_label = f" (Variant: {args.variant})" if args.variant else " (Baseline)"
221384
print("=" * 60)
@@ -226,19 +389,33 @@ def main():
226389
print("\nLoading data...")
227390
df = load_data(data_path=args.data, variant=args.variant)
228391

229-
# 1. Find Twain threshold epoch
230-
print("\n1. Twain Model P-Threshold Analysis")
392+
# 1. Find threshold crossing epochs per author
393+
print("\n1. Individual Author Threshold Crossings (p < 0.001)")
394+
print("-" * 40)
395+
crossing_authors = find_threshold_crossing_epochs(df)
396+
if crossing_authors:
397+
for author in AUTHORS:
398+
if author in crossing_authors:
399+
epoch, t_stat, p_value = crossing_authors[author]
400+
print(f"{author.capitalize():<12}: Epoch {epoch:3d} (t={t_stat:.2f}, p={p_value:.2e})")
401+
else:
402+
print(f"{author.capitalize():<12}: No threshold crossing detected")
403+
else:
404+
print("No authors crossed threshold (started below or never crossed)")
405+
406+
# 2. Average t-statistic threshold crossing
407+
print("\n2. Average T-Statistic Threshold Crossing (p < 0.001)")
231408
print("-" * 40)
232-
epoch, t_stat, p_value = find_twain_threshold_epoch(df)
409+
epoch, avg_t, p_value = find_average_threshold_crossing(df)
233410
if epoch is not None:
234-
print(f"First epoch where p < 0.001: {epoch}")
235-
print(f"t-statistic at epoch {epoch}: {t_stat:.3f}")
236-
print(f"p-value at epoch {epoch}: {p_value:.3e}")
411+
print(f"Average t-stat crossed threshold at epoch: {epoch}")
412+
print(f"Average t-statistic: {avg_t:.3f}")
413+
print(f"p-value: {p_value:.2e}")
237414
else:
238-
print("Threshold not reached within training epochs")
415+
print("Average t-statistic did not cross threshold")
239416

240-
# 2. Average t-test at final epoch
241-
print("\n2. Average T-Test Across Authors (Epoch 500)")
417+
# 3. Average t-test at final epoch
418+
print("\n3. Average T-Test Across Authors (Epoch 500)")
242419
print("-" * 40)
243420
t_stat, p_value, df_val = compute_average_t_test(df, epoch=500)
244421
if t_stat is not None:
@@ -252,8 +429,8 @@ def main():
252429
else:
253430
print("Insufficient data for t-test")
254431

255-
# 3. Author comparison table
256-
print("\n3. Author Model Comparison Table (Table 1)")
432+
# 4. Author comparison table
433+
print("\n4. Author Model Comparison Table (Table 1)")
257434
print("-" * 40)
258435
table, latex_table = generate_author_comparison_table(df)
259436

run_stats.sh

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,4 +139,12 @@ for variant in "${VARIANTS[@]}"; do
139139
echo
140140
done
141141

142+
# If --all was specified, compute cross-variant comparisons
143+
if [ ${#VARIANTS[@]} -eq 4 ]; then
144+
echo
145+
print_info "Computing cross-variant comparisons..."
146+
python code/compute_stats.py --cross-variant-comparison
147+
echo
148+
fi
149+
142150
print_success "Statistical analysis complete!"

0 commit comments

Comments
 (0)