@@ -38,33 +38,88 @@ def load_data(data_path='data/model_results.pkl', variant=None):
3838 return df
3939
4040
41- def find_twain_threshold_epoch (df , p_threshold = 0.001 ):
def find_threshold_crossing_epochs(df, p_threshold=0.001):
    """
    Report, per author, the first epoch at which p newly drops below threshold.

    For each author in AUTHORS the epochs are walked in ascending order and a
    Welch t-test (other-author losses vs. self losses) is run at every epoch
    that has at least 10 self losses and 70 other losses. An author is
    reported only if some earlier tested epoch had p >= p_threshold, i.e. the
    model started out non-significant and later became significant.

    Args:
        df: DataFrame with columns 'train_author', 'epochs_completed',
            'loss_dataset' and 'loss_value'.
        p_threshold: Significance level that must be crossed.

    Returns:
        dict: {author: (epoch, t_stat, p_value)} for authors that cross threshold
    """
    crossings = {}

    for author in AUTHORS:
        rows = df[df['train_author'] == author]

        # Row masks that do not depend on the epoch are built once per author.
        is_self = rows['loss_dataset'] == author
        is_other = rows['loss_dataset'].isin([a for a in AUTHORS if a != author])

        was_above = False  # becomes True once a tested epoch has p >= threshold

        for epoch in sorted(rows['epochs_completed'].unique()):
            at_epoch = rows['epochs_completed'] == epoch
            own = rows.loc[at_epoch & is_self, 'loss_value'].values
            foreign = rows.loc[at_epoch & is_other, 'loss_value'].values

            # Epochs without enough samples are ignored entirely.
            if len(own) < 10 or len(foreign) < 70:
                continue

            # Welch's t-test (other vs self)
            t_stat, p_value = stats.ttest_ind(foreign, own, equal_var=False)

            if p_value >= p_threshold:
                was_above = True
                continue

            if was_above and p_value < p_threshold:
                # First significant epoch after a non-significant one.
                crossings[author] = (epoch, t_stat, p_value)
                break

    return crossings
80+
81+
def find_average_threshold_crossing(df, p_threshold=0.001):
    """
    Find epoch where the across-author average statistic crosses threshold.

    At every epoch each author in AUTHORS gets a standardized separation
    score: (mean loss on other authors' text - mean loss on own text) divided
    by the square root of the mean of the two variances (np.var defaults).
    Note that this is a standardized mean difference rather than a classical
    t-statistic, although the variables keep the historical ``t`` naming.

    When a score is available for every author, a one-sample t-test checks
    whether the mean score is greater than zero. The first epoch whose
    one-tailed p-value is below ``p_threshold``, after at least one earlier
    epoch at or above it, is returned.

    Args:
        df: DataFrame with columns 'train_author', 'epochs_completed',
            'loss_dataset' and 'loss_value'.
        p_threshold: Significance level that must be crossed.

    Returns:
        tuple: (epoch, avg_t_stat, p_value) or (None, None, None)
    """
    epochs = sorted(df['epochs_completed'].unique())

    seen_above_threshold = False

    for epoch in epochs:
        epoch_df = df[df['epochs_completed'] == epoch]

        # Per-author separation scores at this epoch
        author_t_stats = []

        for author in AUTHORS:
            author_df = epoch_df[epoch_df['train_author'] == author]

            # Get self and other losses
            self_losses = author_df[author_df['loss_dataset'] == author]['loss_value'].values
            other_authors = [a for a in AUTHORS if a != author]
            other_losses = author_df[author_df['loss_dataset'].isin(other_authors)]['loss_value'].values

            if len(self_losses) == 0 or len(other_losses) == 0:
                continue

            mean_diff = np.mean(other_losses) - np.mean(self_losses)
            pooled_std = np.sqrt((np.var(other_losses) + np.var(self_losses)) / 2)
            if pooled_std > 0:
                author_t_stats.append(mean_diff / pooled_std)

        # Only epochs where every author contributed a score are tested.
        if len(author_t_stats) != len(AUTHORS):
            continue

        avg_t = np.mean(author_t_stats)

        # One-sample t-test: is the average score significantly > 0?
        t_result = stats.ttest_1samp(author_t_stats, 0)

        # Convert the two-sided p-value to a one-tailed ("greater") p-value.
        # Halving is only valid when the statistic is positive; for a
        # negative statistic the upper-tail probability is 1 - p/2, so a
        # strongly negative average is never reported as a crossing.
        if t_result.statistic > 0:
            p_value = t_result.pvalue / 2
        else:
            p_value = 1 - t_result.pvalue / 2

        if p_value >= p_threshold:
            seen_above_threshold = True
        elif seen_above_threshold and p_value < p_threshold:
            return epoch, avg_t, p_value

    return None, None, None
70125
@@ -197,6 +252,73 @@ def generate_author_comparison_table(df):
197252 return df_table , latex_table
198253
199254
def compute_cross_variant_comparisons(all_variant_data, epoch=500):
    """
    Compare t-value distributions across variants at a given epoch.

    For each variant, one t-statistic per author in AUTHORS is computed from
    that author's model at ``epoch`` (losses on other authors' text vs. losses
    on the author's own text). Every pair of variants is then compared with a
    Welch t-test on those per-author t-statistics.

    Args:
        all_variant_data: dict of {variant_name: DataFrame}
        epoch: Epoch to compare at (default: 500)

    Returns:
        DataFrame with pairwise t-test results; columns are 'Comparison',
        't-stat', 'df', 'p-value' and 'mean_diff', all formatted as strings.
        Pairs where either variant has fewer than 2 usable t-values are
        omitted.
    """
    from itertools import combinations

    # Extract per-author t-values for each variant at the requested epoch
    variant_t_values = {}

    for variant_name, df in all_variant_data.items():
        t_values = []

        for author in AUTHORS:
            # Rows for this author's model at the requested epoch
            author_df = df[(df['train_author'] == author) & (df['epochs_completed'] == epoch)]

            # Get self and other losses
            self_losses = author_df[author_df['loss_dataset'] == author]['loss_value'].values
            other_authors = [a for a in AUTHORS if a != author]
            other_losses = author_df[author_df['loss_dataset'].isin(other_authors)]['loss_value'].values

            if len(self_losses) == 0 or len(other_losses) == 0:
                continue

            if len(self_losses) == 1:
                # With a single self loss there is no self variance, so fall
                # back to a one-sample t-statistic of the other losses
                # against that value. The standard error needs the sample
                # standard deviation (ddof=1), which requires >= 2 values.
                if len(other_losses) < 2:
                    continue
                std_other = np.std(other_losses, ddof=1)
                if not std_other > 0:
                    continue
                mean_diff = np.mean(other_losses) - self_losses[0]
                t_stat = mean_diff / (std_other / np.sqrt(len(other_losses)))
            else:
                t_stat, _ = stats.ttest_ind(other_losses, self_losses, equal_var=False)

            # Drop undefined statistics regardless of which branch made them.
            if not np.isnan(t_stat):
                t_values.append(t_stat)

        variant_t_values[variant_name] = t_values

    # Pairwise comparisons
    results = []

    for var1, var2 in combinations(list(all_variant_data.keys()), 2):
        t_vals_1 = variant_t_values[var1]
        t_vals_2 = variant_t_values[var2]

        # A Welch test needs at least two observations per group.
        if len(t_vals_1) < 2 or len(t_vals_2) < 2:
            continue

        # T-test comparing distributions
        t_result = stats.ttest_ind(t_vals_1, t_vals_2, equal_var=False)

        results.append({
            'Comparison': f'{var1} vs {var2}',
            't-stat': f'{t_result.statistic:.2f}',
            'df': f'{t_result.df:.2f}',
            'p-value': f'{t_result.pvalue:.2e}',
            'mean_diff': f'{np.mean(t_vals_1) - np.mean(t_vals_2):.2f}'
        })

    return pd.DataFrame(results)
320+
321+
200322def main ():
201323 """Main function to compute and display all statistics."""
202324 import argparse
@@ -213,9 +335,50 @@ def main():
213335 default = 'data/model_results.pkl' ,
214336 help = 'Path to model results file (default: data/model_results.pkl)'
215337 )
338+ parser .add_argument (
339+ '--cross-variant-comparison' ,
340+ action = 'store_true' ,
341+ help = 'Compute pairwise comparisons across all variants'
342+ )
216343
217344 args = parser .parse_args ()
218345
346+ # Handle cross-variant comparison mode
347+ if args .cross_variant_comparison :
348+ print ("=" * 60 )
349+ print ("Cross-Variant Comparison Analysis" )
350+ print ("=" * 60 )
351+
352+ # Load all variant data
353+ all_variant_data = {}
354+ for var_name , var_key in [('baseline' , None ), ('content' , 'content' ), ('function' , 'function' ), ('pos' , 'pos' )]:
355+ pkl_file = f"data/model_results.pkl" if var_key is None else f"data/model_results_{ var_key } .pkl"
356+ if Path (pkl_file ).exists ():
357+ all_variant_data [var_name ] = load_data (pkl_file , var_key )
358+ else :
359+ print (f"Warning: { pkl_file } not found, skipping { var_name } " )
360+
361+ if len (all_variant_data ) < 2 :
362+ print ("Error: Need at least 2 variants for comparison" )
363+ return
364+
365+ print (f"\n Loaded { len (all_variant_data )} conditions: { list (all_variant_data .keys ())} " )
366+
367+ # Compute pairwise comparisons
368+ print ("\n Pairwise T-Test Comparisons (Epoch 500)" )
369+ print ("Comparing distributions of t-statistics across all authors" )
370+ print ("-" * 60 )
371+
372+ comparison_df = compute_cross_variant_comparisons (all_variant_data , epoch = 500 )
373+
374+ if not comparison_df .empty :
375+ print ("\n " + comparison_df .to_string (index = False ))
376+ else :
377+ print ("No comparisons could be computed" )
378+
379+ print ("\n " + "=" * 60 )
380+ return
381+
219382 # Update header to show variant
220383 variant_label = f" (Variant: { args .variant } )" if args .variant else " (Baseline)"
221384 print ("=" * 60 )
@@ -226,19 +389,33 @@ def main():
226389 print ("\n Loading data..." )
227390 df = load_data (data_path = args .data , variant = args .variant )
228391
229- # 1. Find Twain threshold epoch
230- print ("\n 1. Twain Model P-Threshold Analysis" )
392+ # 1. Find threshold crossing epochs per author
393+ print ("\n 1. Individual Author Threshold Crossings (p < 0.001)" )
394+ print ("-" * 40 )
395+ crossing_authors = find_threshold_crossing_epochs (df )
396+ if crossing_authors :
397+ for author in AUTHORS :
398+ if author in crossing_authors :
399+ epoch , t_stat , p_value = crossing_authors [author ]
400+ print (f"{ author .capitalize ():<12} : Epoch { epoch :3d} (t={ t_stat :.2f} , p={ p_value :.2e} )" )
401+ else :
402+ print (f"{ author .capitalize ():<12} : No threshold crossing detected" )
403+ else :
404+ print ("No authors crossed threshold (started below or never crossed)" )
405+
406+ # 2. Average t-statistic threshold crossing
407+ print ("\n 2. Average T-Statistic Threshold Crossing (p < 0.001)" )
231408 print ("-" * 40 )
232- epoch , t_stat , p_value = find_twain_threshold_epoch (df )
409+ epoch , avg_t , p_value = find_average_threshold_crossing (df )
233410 if epoch is not None :
234- print (f"First epoch where p < 0.001 : { epoch } " )
235- print (f"t-statistic at epoch { epoch } : { t_stat :.3f} " )
236- print (f"p-value at epoch { epoch } : { p_value :.3e } " )
411+ print (f"Average t-stat crossed threshold at epoch : { epoch } " )
412+ print (f"Average t-statistic: { avg_t :.3f} " )
413+ print (f"p-value: { p_value :.2e } " )
237414 else :
238- print ("Threshold not reached within training epochs " )
415+ print ("Average t-statistic did not cross threshold " )
239416
240- # 2 . Average t-test at final epoch
241- print ("\n 2 . Average T-Test Across Authors (Epoch 500)" )
417+ # 3 . Average t-test at final epoch
418+ print ("\n 3 . Average T-Test Across Authors (Epoch 500)" )
242419 print ("-" * 40 )
243420 t_stat , p_value , df_val = compute_average_t_test (df , epoch = 500 )
244421 if t_stat is not None :
@@ -252,8 +429,8 @@ def main():
252429 else :
253430 print ("Insufficient data for t-test" )
254431
255- # 3 . Author comparison table
256- print ("\n 3 . Author Model Comparison Table (Table 1)" )
432+ # 4 . Author comparison table
433+ print ("\n 4 . Author Model Comparison Table (Table 1)" )
257434 print ("-" * 40 )
258435 table , latex_table = generate_author_comparison_table (df )
259436
0 commit comments