Skip to content

Commit 71d0289

Browse files
neel04baberabb
andauthored
[FIX] Initial code to disable multi-proc for stderr (#3106)
* [FIX] Initial code to disable multi-proc for stderr * add docs; align no-mp bootstrap with mp --------- Co-authored-by: Baber <[email protected]>
1 parent ff41a85 commit 71d0289

File tree

2 files changed

+109
-27
lines changed

2 files changed

+109
-27
lines changed

lm_eval/api/metrics.py

Lines changed: 80 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
11
import logging
22
import math
3+
import os
34
import random
45
import re
56
import string
67
from collections.abc import Iterable
7-
from typing import List
8+
from typing import Callable, List, Optional, Sequence, TypeVar
89

910
import numpy as np
1011
import sacrebleu
1112

1213
from lm_eval.api.registry import register_aggregation, register_metric
1314

1415

16+
T = TypeVar("T")
17+
1518
eval_logger = logging.getLogger(__name__)
1619

1720

@@ -287,7 +290,7 @@ def pop_stddev(arr):
287290
return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / len(arr))
288291

289292

290-
def sample_stddev(arr):
293+
def sample_stddev(arr: Sequence[T]) -> float:
291294
mu = mean(arr)
292295
return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / (len(arr) - 1))
293296

@@ -449,11 +452,16 @@ def _sacreformat(refs, preds):
449452

450453

451454
class _bootstrap_internal:
452-
def __init__(self, f, n) -> None:
455+
"""
456+
Pool worker: `(i, xs)` → `n` bootstrap replicates
457+
of `f(xs)`using a RNG seeded with `i`.
458+
"""
459+
460+
def __init__(self, f: Callable[[Sequence[T]], float], n: int) -> None:
453461
self.f = f
454462
self.n = n
455463

456-
def __call__(self, v):
464+
def __call__(self, v: tuple[int, Sequence[T]]) -> list[float]:
457465
i, xs = v
458466
rnd = random.Random()
459467
rnd.seed(i)
@@ -463,36 +471,81 @@ def __call__(self, v):
463471
return res
464472

465473

466-
def bootstrap_stderr(f, xs, iters):
467-
import multiprocessing as mp
468-
469-
pool = mp.Pool(mp.cpu_count())
470-
# this gives a biased estimate of the stderr (i.e w/ the mean, it gives something
471-
# equivalent to stderr calculated without Bessel's correction in the stddev.
472-
# Unfortunately, I haven't been able to figure out what the right correction is
473-
# to make the bootstrap unbiased - i considered multiplying by sqrt(n/(n-1)) but
474-
# that would be ad-hoc and I can't prove that that would actually be an unbiased estimator)
475-
# Thankfully, shouldn't matter because our samples are pretty big usually anyways
474+
def _bootstrap_internal_no_mp(
475+
f: Callable[[Sequence[T]], float], xs: Sequence[T], iters: int
476+
) -> list[float]:
477+
"""
478+
Single-process fallback: compute `iters` bootstrap replicates
479+
of statistic`f(xs)`, chunked (≤ 1000 draws).
480+
"""
476481
res = []
477482
chunk_size = min(1000, iters)
478483
from tqdm import tqdm
479484

480-
print("bootstrapping for stddev:", f.__name__)
481-
for bootstrap in tqdm(
482-
pool.imap(
483-
_bootstrap_internal(f, chunk_size),
484-
[(i, xs) for i in range(iters // chunk_size)],
485-
),
486-
total=iters // chunk_size,
487-
):
488-
# sample w replacement
489-
res.extend(bootstrap)
490-
491-
pool.close()
485+
print(f"bootstrapping for stddev: {f.__name__}")
486+
487+
# A single loop replaces the multiprocessing pool.
488+
for i in tqdm(range(iters // chunk_size)):
489+
rnd = random.Random(i)
490+
for _ in range(chunk_size):
491+
res.append(f(rnd.choices(xs, k=len(xs))))
492+
493+
return res
494+
495+
496+
def bootstrap_stderr(
497+
f: Callable[[Sequence[T]], float], xs: Sequence[T], iters: int
498+
) -> float:
499+
"""
500+
Bootstrap estimate of the standard error of statistic `f(xs)`
501+
using up to `iters` resamples, chunked (≤ 1000 draws)
502+
503+
Executes in parallel unless the env-var `DISABLE_MULTIPROC` is set;
504+
"""
505+
if not os.getenv("DISABLE_MULTIPROC"):
506+
import multiprocessing as mp
507+
508+
pool = mp.Pool(mp.cpu_count())
509+
# this gives a biased estimate of the stderr (i.e w/ the mean, it gives something
510+
# equivalent to stderr calculated without Bessel's correction in the stddev.
511+
# Unfortunately, I haven't been able to figure out what the right correction is
512+
# to make the bootstrap unbiased - i considered multiplying by sqrt(n/(n-1)) but
513+
# that would be ad-hoc and I can't prove that that would actually be an unbiased estimator)
514+
# Thankfully, shouldn't matter because our samples are pretty big usually anyways
515+
res = []
516+
chunk_size = min(1000, iters)
517+
from tqdm import tqdm
518+
519+
print("bootstrapping for stddev:", f.__name__)
520+
for bootstrap in tqdm(
521+
pool.imap(
522+
_bootstrap_internal(f, chunk_size),
523+
[(i, xs) for i in range(iters // chunk_size)],
524+
),
525+
total=iters // chunk_size,
526+
):
527+
# sample w replacement
528+
res.extend(bootstrap)
529+
530+
pool.close()
531+
else:
532+
res = _bootstrap_internal_no_mp(f, xs, iters)
533+
492534
return sample_stddev(res)
493535

494536

495-
def stderr_for_metric(metric, bootstrap_iters: int):
537+
def stderr_for_metric(
538+
metric: Callable[[Sequence[T]], float], bootstrap_iters: int
539+
) -> Optional[Callable[[Sequence[T]], float]]:
540+
"""
541+
Return a function that estimates the standard error of `metric(xs)`.
542+
543+
* If `bootstrap_iters > 0` and the metric is in the pre-approved
544+
bootstrappable list, use `bootstrap_stderr` with that many draws.
545+
* If the metric has a closed-form SE (e.g. `mean`, `acc_all`), use it.
546+
* Otherwise, return `None`.
547+
"""
548+
496549
if bootstrap_iters <= 0:
497550
# return no function (don't compute stderr) if bootstrap iters = 0
498551
return None

tests/test_metrics.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import unittest.mock as mock
2+
3+
from lm_eval.api.metrics import _bootstrap_internal_no_mp, mean
14
from lm_eval.api.task import ConfigurableTask, TaskConfig
25

36

@@ -149,8 +152,34 @@ def test_acc_mutual_info_without_metric():
149152
assert result_dict["acc"] == 1.0
150153

151154

155+
def test_bootstrap_internal_no_mp():
156+
"""Test basic functionality of _bootstrap_internal_no_mp"""
157+
158+
data = [1, 2, 3, 4, 5]
159+
160+
# Mock tqdm to avoid progress bar output during testing
161+
with mock.patch("tqdm.tqdm") as mock_tqdm:
162+
mock_tqdm.return_value = range(1) # Single chunk
163+
164+
# Mock print to avoid output during testing
165+
with mock.patch("builtins.print"):
166+
result = _bootstrap_internal_no_mp(mean, data, 100)
167+
168+
# Should return 100 bootstrap replicates
169+
assert len(result) == 100
170+
171+
# All results should be numbers (means)
172+
assert all(isinstance(x, (int, float)) for x in result)
173+
174+
# Bootstrap means should be close to original mean
175+
bootstrap_mean = mean(result)
176+
original_mean = mean(data)
177+
assert abs(bootstrap_mean - original_mean) < 0.5 # Should be reasonably close
178+
179+
152180
if __name__ == "__main__":
153181
test_acc_mutual_info_slicing()
154182
test_acc_mutual_info_different_predictions()
155183
test_acc_mutual_info_without_metric()
184+
test_bootstrap_internal_no_mp()
156185
print("All tests passed!")

0 commit comments

Comments
 (0)