@@ -1,17 +1,20 @@
 import logging
 import math
+import os
 import random
 import re
 import string
 from collections.abc import Iterable
-from typing import List
+from typing import Callable, List, Optional, Sequence, TypeVar

 import numpy as np
 import sacrebleu

 from lm_eval.api.registry import register_aggregation, register_metric


+T = TypeVar("T")
+
 eval_logger = logging.getLogger(__name__)

@@ -287,7 +290,7 @@ def pop_stddev(arr):
     return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / len(arr))


-def sample_stddev(arr):
+def sample_stddev(arr: Sequence[T]) -> float:
     mu = mean(arr)
     return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / (len(arr) - 1))

@@ -449,11 +452,16 @@ def _sacreformat(refs, preds):


 class _bootstrap_internal:
-    def __init__(self, f, n) -> None:
+    """
+    Pool worker: `(i, xs)` → `n` bootstrap replicates
+    of `f(xs)` using an RNG seeded with `i`.
+    """
+
+    def __init__(self, f: Callable[[Sequence[T]], float], n: int) -> None:
         self.f = f
         self.n = n

-    def __call__(self, v):
+    def __call__(self, v: tuple[int, Sequence[T]]) -> list[float]:
         i, xs = v
         rnd = random.Random()
         rnd.seed(i)
@@ -463,36 +471,81 @@ def __call__(self, v):
             return res


-def bootstrap_stderr(f, xs, iters):
-    import multiprocessing as mp
-
-    pool = mp.Pool(mp.cpu_count())
-    # this gives a biased estimate of the stderr (i.e w/ the mean, it gives something
-    # equivalent to stderr calculated without Bessel's correction in the stddev.
-    # Unfortunately, I haven't been able to figure out what the right correction is
-    # to make the bootstrap unbiased - i considered multiplying by sqrt(n/(n-1)) but
-    # that would be ad-hoc and I can't prove that that would actually be an unbiased estimator)
-    # Thankfully, shouldn't matter because our samples are pretty big usually anyways
+def _bootstrap_internal_no_mp(
+    f: Callable[[Sequence[T]], float], xs: Sequence[T], iters: int
+) -> list[float]:
+    """
+    Single-process fallback: compute `iters` bootstrap replicates
+    of the statistic `f(xs)`, chunked (≤ 1000 draws per chunk).
+    """
     res = []
     chunk_size = min(1000, iters)
     from tqdm import tqdm

-    print("bootstrapping for stddev:", f.__name__)
-    for bootstrap in tqdm(
-        pool.imap(
-            _bootstrap_internal(f, chunk_size),
-            [(i, xs) for i in range(iters // chunk_size)],
-        ),
-        total=iters // chunk_size,
-    ):
-        # sample w replacement
-        res.extend(bootstrap)
-
-    pool.close()
+    print(f"bootstrapping for stddev: {f.__name__}")
+
+    # A single loop replaces the multiprocessing pool.
+    for i in tqdm(range(iters // chunk_size)):
+        rnd = random.Random(i)
+        for _ in range(chunk_size):
+            res.append(f(rnd.choices(xs, k=len(xs))))
+
+    return res
+
+
+def bootstrap_stderr(
+    f: Callable[[Sequence[T]], float], xs: Sequence[T], iters: int
+) -> float:
+    """
+    Bootstrap estimate of the standard error of the statistic `f(xs)`,
+    using up to `iters` resamples, chunked (≤ 1000 draws per chunk).
+
+    Executes in parallel unless the env var `DISABLE_MULTIPROC` is set.
+    """
+    if not os.getenv("DISABLE_MULTIPROC"):
+        import multiprocessing as mp
+
+        pool = mp.Pool(mp.cpu_count())
+        # this gives a biased estimate of the stderr (i.e w/ the mean, it gives something
+        # equivalent to stderr calculated without Bessel's correction in the stddev.
+        # Unfortunately, I haven't been able to figure out what the right correction is
+        # to make the bootstrap unbiased - i considered multiplying by sqrt(n/(n-1)) but
+        # that would be ad-hoc and I can't prove that that would actually be an unbiased estimator)
+        # Thankfully, shouldn't matter because our samples are pretty big usually anyways
+        res = []
+        chunk_size = min(1000, iters)
+        from tqdm import tqdm
+
+        print("bootstrapping for stddev:", f.__name__)
+        for bootstrap in tqdm(
+            pool.imap(
+                _bootstrap_internal(f, chunk_size),
+                [(i, xs) for i in range(iters // chunk_size)],
+            ),
+            total=iters // chunk_size,
+        ):
+            # sample w replacement
+            res.extend(bootstrap)
+
+        pool.close()
+    else:
+        res = _bootstrap_internal_no_mp(f, xs, iters)
+
     return sample_stddev(res)


-def stderr_for_metric(metric, bootstrap_iters: int):
+def stderr_for_metric(
+    metric: Callable[[Sequence[T]], float], bootstrap_iters: int
+) -> Optional[Callable[[Sequence[T]], float]]:
+    """
+    Return a function that estimates the standard error of `metric(xs)`.
+
+    * If `bootstrap_iters > 0` and the metric is in the pre-approved
+      bootstrappable list, use `bootstrap_stderr` with that many draws.
+    * If the metric has a closed-form SE (e.g. `mean`, `acc_all`), use it.
+    * Otherwise, return `None`.
+    """
+
     if bootstrap_iters <= 0:
         # return no function (don't compute stderr) if bootstrap iters = 0
         return None
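
For context, the single-process path added in this diff amounts to: draw up to `iters` bootstrap resamples of `xs` with replacement, one RNG seed per chunk of at most 1000 draws, evaluate the statistic on each resample, and report the sample standard deviation of the replicates as the standard error. The following is a minimal, self-contained sketch of that idea, not part of the patch; the name `bootstrap_stderr_sketch` and the example `scores` list are illustrative only.

import math
import random
from statistics import mean


def bootstrap_stderr_sketch(f, xs, iters=2000):
    # One RNG seed per chunk of at most 1000 draws, mirroring the chunking above.
    chunk_size = min(1000, iters)
    replicates = []
    for i in range(iters // chunk_size):
        rnd = random.Random(i)
        for _ in range(chunk_size):
            # Resample xs with replacement and re-evaluate the statistic.
            replicates.append(f(rnd.choices(xs, k=len(xs))))
    mu = mean(replicates)
    # Sample standard deviation of the replicates approximates the stderr of f(xs).
    return math.sqrt(sum((r - mu) ** 2 for r in replicates) / (len(replicates) - 1))


if __name__ == "__main__":
    scores = [0, 1, 1, 0, 1, 1, 1, 0, 1, 1]  # e.g. per-example accuracies
    print(bootstrap_stderr_sketch(mean, scores))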