33import pandas as pd
44import seaborn as sns
55import matplotlib .pyplot as plt
6- from scipy .stats import ttest_ind
6+ from scipy .stats import ttest_ind , t as t_dist
77import numpy as np
88from tqdm import tqdm
99import logging
1212
1313
1414def calculate_t_statistics (df , max_epochs = 500 ):
15- """Calculate t-statistics comparing same vs other author losses."""
15+ """
16+ Calculate t-statistics and df comparing same vs other author losses.
17+
18+ Returns:
19+ tuple: (t_raws_df, t_raws, df_values, thresholds)
20+ - t_raws_df: Long-form DataFrame with columns [Epoch, Author, t_raw]
21+ - t_raws: Dict mapping author to list of t-values
22+ - df_values: Dict mapping author to list of degrees of freedom
23+ - thresholds: Dict mapping author to list of t-thresholds for p=0.001
24+ """
1625
1726 # Define authors
1827 AUTHORS = ["baum" , "thompson" , "dickens" , "melville" , "wells" , "austen" , "fitzgerald" , "twain" ]
@@ -27,6 +36,8 @@ def calculate_t_statistics(df, max_epochs=500):
2736 authors = sorted (t_df ["train_author" ].unique ())
2837 epochs = sorted (t_df ["epochs_completed" ].unique ())
2938 t_raws = {author : [] for author in authors }
39+ df_values = {author : [] for author in authors }
40+ thresholds = {author : [] for author in authors }
3041
3142 # Compute Welch's t-statistic for each author/epoch
3243 for author in tqdm (authors , desc = "Processing authors" ):
@@ -50,16 +61,25 @@ def calculate_t_statistics(df, max_epochs=500):
5061 logger .debug (f"NaN t-statistic for { author } at epoch { epoch } : "
5162 f"n_true={ len (true_losses )} , n_other={ len (other_losses )} " )
5263 t_raws [author ].append (result .statistic )
64+ df_values [author ].append (result .df )
65+
66+ # Compute t-threshold for p=0.001 (one-tailed) given this df
67+ t_threshold = t_dist .ppf (1 - 0.001 , result .df )
68+ thresholds [author ].append (t_threshold )
5369 elif len (true_losses ) > 0 or len (other_losses ) > 0 :
5470 # Have some data but insufficient for t-test
5571 logger .debug (f"Insufficient data for t-test for { author } at epoch { epoch } : "
5672 f"n_true={ len (true_losses )} , n_other={ len (other_losses )} "
5773 f"(need at least 2 samples per group)" )
5874 t_raws [author ].append (np .nan )
75+ df_values [author ].append (np .nan )
76+ thresholds [author ].append (np .nan )
5977 else :
6078 # No data at all
6179 logger .debug (f"No data for { author } at epoch { epoch } " )
6280 t_raws [author ].append (np .nan )
81+ df_values [author ].append (np .nan )
82+ thresholds [author ].append (np .nan )
6383
6484 # Convert to long-form DataFrame
6585 t_raws_df = (
@@ -69,7 +89,7 @@ def calculate_t_statistics(df, max_epochs=500):
6989 .rename (columns = {"index" : "Epoch" })
7090 )
7191
72- return t_raws_df , t_raws
92+ return t_raws_df , t_raws , df_values , thresholds
7393
7494
7595def generate_t_test_figure (
@@ -113,7 +133,25 @@ def generate_t_test_figure(
113133 raise ValueError ("No variant column in data" )
114134 df = df [df ['variant' ] == variant ].copy ()
115135
116- t_raws_df , _ = calculate_t_statistics (df )
136+ t_raws_df , _ , df_values , thresholds = calculate_t_statistics (df )
137+
138+ # Compute average threshold across authors at each epoch (for plotting)
139+ epochs = sorted (t_raws_df ["Epoch" ].unique ())
140+ threshold_data = []
141+ for epoch in epochs :
142+ epoch_thresholds = []
143+ for author in thresholds .keys ():
144+ epoch_idx = list (epochs ).index (epoch )
145+ if epoch_idx < len (thresholds [author ]):
146+ thresh = thresholds [author ][epoch_idx ]
147+ if not np .isnan (thresh ):
148+ epoch_thresholds .append (thresh )
149+
150+ # For each epoch, add one row per author's threshold (for bootstrap CI calculation)
151+ for thresh in epoch_thresholds :
152+ threshold_data .append ({'Epoch' : epoch , 'threshold' : thresh })
153+
154+ threshold_df = pd .DataFrame (threshold_data )
117155
118156 # Define color palette
119157 unique_authors = sorted (t_raws_df ["Author" ].unique ())
@@ -124,6 +162,7 @@ def generate_t_test_figure(
124162 # Create figure
125163 fig , ax = plt .subplots (figsize = figsize )
126164
165+ # Plot author t-statistics
127166 sns .lineplot (
128167 data = t_raws_df ,
129168 x = "Epoch" ,
@@ -135,47 +174,46 @@ def generate_t_test_figure(
135174 legend = show_legend ,
136175 )
137176
177+ # Plot adaptive threshold with bootstrap 95% CI (solid black line)
178+ if not threshold_df .empty :
179+ sns .lineplot (
180+ data = threshold_df ,
181+ x = "Epoch" ,
182+ y = "threshold" ,
183+ ax = ax ,
184+ color = "black" ,
185+ linewidth = 2 ,
186+ linestyle = "-" , # Solid line
187+ errorbar = 'ci' , # Bootstrap 95% CI
188+ label = "p<0.001 threshold" if show_legend else ""
189+ )
190+
138191 sns .despine (ax = ax , top = True , right = True )
139- # Remove title as requested
140- # ax.set_title(
141- # "$t$-values: training author vs. other authors",
142- # fontsize=12,
143- # pad=10,
144- # )
145192 ax .set_xlabel ("Epochs completed" , fontsize = 12 )
146193 ax .set_ylabel ("$t$-value" , fontsize = 12 )
147194
148195 # Calculate dynamic y-axis limits based on VALID data only
149- # Filter out NaN/Inf values to avoid matplotlib errors
150196 valid_t_values = t_raws_df ['t_raw' ].replace ([np .inf , - np .inf ], np .nan ).dropna ()
151197
152198 if len (valid_t_values ) == 0 :
153- # No valid data - use reasonable defaults around threshold
154199 logger .warning ("No valid t-statistics found. Using default axis limits." )
155200 y_min = - 1.0
156201 y_max = 5.0
157202 else :
158203 y_min = valid_t_values .min ()
159204 y_max = valid_t_values .max ()
160205
161- # Add padding for better visualization
206+ # Add padding
162207 y_range = y_max - y_min
163208 padding = 0.05 * y_range if y_range > 0 else 0.5
209+ y_min = min (y_min , 0 ) - padding
210+ y_max = y_max + padding
164211
165- # Ensure threshold line (p<0.001 at t=3.291) is always visible
166- threshold = 3.291
167- y_max = max (y_max , threshold ) + padding
168- y_min = min (y_min , 0 ) - padding # Allow negatives if they exist
169-
170- # Final validation to ensure limits are finite and valid
212+ # Final validation
171213 if not (np .isfinite (y_min ) and np .isfinite (y_max ) and y_min < y_max ):
172214 logger .error (f"Invalid axis limits computed: y_min={ y_min } , y_max={ y_max } . Using defaults." )
173215 y_min = - 1.0
174216 y_max = 5.0
175-
176- # Add threshold line
177- threshold = 3.291
178- ax .axhline (y = threshold , linestyle = "--" , color = "black" , label = "p<0.001 threshold" if show_legend else "" )
179217 ax .set_xlim (0 , t_raws_df ["Epoch" ].max ())
180218 ax .set_ylim (y_min , y_max )
181219
@@ -244,11 +282,29 @@ def generate_t_test_avg_figure(
244282 raise ValueError ("No variant column in data" )
245283 df = df [df ['variant' ] == variant ].copy ()
246284
247- t_raws_df , _ = calculate_t_statistics (df )
285+ t_raws_df , _ , df_values , thresholds = calculate_t_statistics (df )
286+
287+ # Compute average threshold across authors at each epoch
288+ epochs = sorted (t_raws_df ["Epoch" ].unique ())
289+ threshold_data = []
290+ for epoch in epochs :
291+ epoch_thresholds = []
292+ for author in thresholds .keys ():
293+ epoch_idx = list (epochs ).index (epoch )
294+ if epoch_idx < len (thresholds [author ]):
295+ thresh = thresholds [author ][epoch_idx ]
296+ if not np .isnan (thresh ):
297+ epoch_thresholds .append (thresh )
298+
299+ for thresh in epoch_thresholds :
300+ threshold_data .append ({'Epoch' : epoch , 'threshold' : thresh })
301+
302+ threshold_df = pd .DataFrame (threshold_data )
248303
249304 # Create figure
250305 fig , ax = plt .subplots (figsize = figsize )
251306
307+ # Plot average t-statistic
252308 sns .lineplot (
253309 data = t_raws_df ,
254310 x = "Epoch" ,
@@ -258,47 +314,47 @@ def generate_t_test_avg_figure(
258314 color = "black" , # Set line color to black
259315 )
260316
317+ # Plot adaptive threshold with bootstrap 95% CI (solid gray line)
318+ if not threshold_df .empty :
319+ sns .lineplot (
320+ data = threshold_df ,
321+ x = "Epoch" ,
322+ y = "threshold" ,
323+ ax = ax ,
324+ color = "gray" ,
325+ linewidth = 2 ,
326+ linestyle = "-" , # Solid line
327+ errorbar = 'ci' , # Bootstrap 95% CI
328+ label = "p<0.001 threshold" if show_legend else ""
329+ )
330+
261331 sns .despine (ax = ax , top = True , right = True )
262- # Remove title as requested
263- # ax.set_title(
264- # "Average $t$-values: training author vs. other authors",
265- # fontsize=12,
266- # pad=10,
267- # )
268332 ax .set_xlabel ("Epochs completed" , fontsize = 12 )
269333 ax .set_ylabel ("$t$-value" , fontsize = 12 )
270334
271- # Calculate dynamic y-axis limits based on VALID data only
272- # Filter out NaN/Inf values to avoid matplotlib errors
335+ # Calculate dynamic y-axis limits
273336 valid_t_values = t_raws_df ['t_raw' ].replace ([np .inf , - np .inf ], np .nan ).dropna ()
274337
275338 if len (valid_t_values ) == 0 :
276- # No valid data - use reasonable defaults around threshold
277339 logger .warning ("No valid t-statistics found for average figure. Using default axis limits." )
278340 y_min = - 1.0
279341 y_max = 5.0
280342 else :
281343 y_min = valid_t_values .min ()
282344 y_max = valid_t_values .max ()
283345
284- # Add padding for better visualization
346+ # Add padding
285347 y_range = y_max - y_min
286348 padding = 0.05 * y_range if y_range > 0 else 0.5
349+ y_min = min (y_min , 0 ) - padding
350+ y_max = y_max + padding
287351
288- # Ensure threshold line (p<0.001 at t=3.291) is always visible
289- threshold = 3.291
290- y_max = max (y_max , threshold ) + padding
291- y_min = min (y_min , 0 ) - padding # Allow negatives if they exist
292-
293- # Final validation to ensure limits are finite and valid
352+ # Final validation
294353 if not (np .isfinite (y_min ) and np .isfinite (y_max ) and y_min < y_max ):
295354 logger .error (f"Invalid axis limits computed for average figure: y_min={ y_min } , y_max={ y_max } . Using defaults." )
296355 y_min = - 1.0
297356 y_max = 5.0
298357
299- # Add threshold line
300- threshold = 3.291
301- ax .axhline (y = threshold , linestyle = "--" , color = "black" , label = "p<0.001 threshold" if show_legend else "" )
302358 ax .set_xlim (0 , t_raws_df ["Epoch" ].max ())
303359 ax .set_ylim (y_min , y_max )
304360
0 commit comments