Skip to content

Commit f8fa433

Browse files
authored
Merge pull request #396 from aai-institute/392-explicit-randomization-for-subprocesses
Add solution for explicit randomization in subprocesses.
2 parents ec8aac0 + e38e1ff commit f8fa433

File tree

22 files changed

+678
-101
lines changed

22 files changed

+678
-101
lines changed

CHANGELOG.md

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
# Changelog
22

3-
## 0.7.0 - 📚 Documentation overhaul, new methods and bug fixes 💥
3+
## 0.7.0 - 📚🆕 Documentation and IF overhaul, new methods and bug fixes 💥🐞
44

55
This is our first β release! We have worked hard to deliver improvements across
6-
the board, with a focus on documentation and usability.
6+
the board, with a focus on documentation and usability. We have also reworked
7+
the internals of the `influence` module, improved parallelism and handling of
8+
randomness.
79

810
### Added
911

@@ -13,8 +15,13 @@ the board, with a focus on documentation and usability.
1315
[PR #406](https://github.com/aai-institute/pyDVL/pull/406)
1416
- Added more abbreviations to documentation
1517
[PR #415](https://github.com/aai-institute/pyDVL/pull/415)
18+
- Added seed to functions from `pydvl.utils.numeric`, `pydvl.value.shapley` and
19+
`pydvl.value.semivalues`. Introduced new type `Seed` and conversion function
20+
`ensure_seed_sequence`.
21+
[PR #396](https://github.com/aai-institute/pyDVL/pull/396)
1622

1723
### Changed
24+
1825
- Replaced sphinx with mkdocs for documentation. Major overhaul of documentation
1926
[PR #352](https://github.com/aai-institute/pyDVL/pull/352)
2027
- Made ray an optional dependency, relying on joblib as default parallel backend

docs/css/extra.css

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,12 @@ a.autorefs-external:hover::after {
7777
user-select: none;
7878
}
7979

80+
/* Nicer style of headers in generated API */
81+
h2 code {
82+
font-size: large!important;
83+
background-color: inherit!important;
84+
}
85+
8086
/* Remove cell input and output prompt */
8187
.jp-InputArea-prompt, .jp-OutputArea-prompt {
8288
display: none !important;

src/pydvl/utils/functional.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
"""
2+
Supporting utilities for manipulating arguments of functions.
3+
"""
4+
5+
from __future__ import annotations
6+
7+
import inspect
8+
from functools import partial
9+
from typing import Callable, Set, Union
10+
11+
__all__ = ["maybe_add_argument"]
12+
13+
14+
def _accept_additional_argument(*args, fun: Callable, arg: str, **kwargs):
    """Calls `fun` with the given positional and keyword arguments, after
    discarding `arg` from the keyword arguments.

    Args:
        args: Positional arguments to pass to the function.
        fun: The function to call.
        arg: The name of the keyword argument to drop before the call.
        kwargs: Keyword arguments to pass to the function.

    Returns:
        The return value of the function.
    """
    # The caller may or may not supply `arg`; a missing key is not an error.
    kwargs.pop(arg, None)
    return fun(*args, **kwargs)


def free_arguments(fun: Union[Callable, partial]) -> Set[str]:
    """Computes the set of argument names that a function or
    [functools.partial][] object is known to take.

    Nested partials are unrolled down to the innermost wrapped function. The
    result is the union of:

    - every parameter name in the signature of the innermost function, and
    - every keyword name pre-set by an intermediate partial.

    Values bound *positionally* by a partial carry no name and are ignored.
    Wrappers of the form `partial(_accept_additional_argument, fun=..., arg=...)`
    (as created by `maybe_add_argument`) are unrolled too: the name they
    swallow is discarded from the names collected so far and the search
    continues with the wrapped `fun`.

    Args:
        fun: A callable or a [partial object][].

    Returns:
        The set of argument names of `fun` as described above.

    !!! tip "New in version 0.7.0"
    """
    # NOTE(review): earlier documentation stated that names set by a partial
    # are *excluded* from the result, but the implementation has always
    # returned the union. The union is kept here so existing callers see the
    # same result -- confirm which semantics is intended.
    names_set_by_partial: Set[str] = set()

    target = fun
    while isinstance(target, partial):
        if target.func is _accept_additional_argument:
            # The wrapper drops this keyword before calling the wrapped
            # function, so an outer partial setting it says nothing about
            # the function underneath.
            names_set_by_partial.discard(target.keywords["arg"])
            target = target.keywords["fun"]
        else:
            # Only keyword bindings have names. `target.args` holds argument
            # *values*: adding them would pollute the set and raise
            # TypeError for unhashable values such as lists or arrays.
            names_set_by_partial.update(target.keywords)
            target = target.func

    signature = inspect.signature(target)
    return names_set_by_partial | set(signature.parameters)


def maybe_add_argument(fun: Callable, new_arg: str) -> Callable:
    """Wraps a function to accept the given keyword parameter if it doesn't
    already.

    If `new_arg` is among the names reported by `free_arguments(fun)`, then
    `fun` is returned as is. Otherwise, a wrapper is returned which merely
    ignores the argument.

    Args:
        fun: The function to wrap
        new_arg: The name of the argument that the new function will accept
            (and ignore).

    Returns:
        Either `fun` itself or a new function accepting one more keyword
        argument.

    !!! tip "Changed in version 0.7.0"
        Ability to work with partials.
    """
    if new_arg in free_arguments(fun):
        return fun

    # Built as a partial of a module-level function (not a closure);
    # `free_arguments` recognises exactly this construct when unrolling.
    return partial(_accept_additional_argument, fun=fun, arg=new_arg)

src/pydvl/utils/numeric.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
import numpy as np
1111
from numpy.typing import NDArray
1212

13+
from pydvl.utils.types import Seed
14+
1315
__all__ = [
1416
"running_moments",
1517
"num_samples_permutation_hoeffding",
@@ -68,24 +70,32 @@ def num_samples_permutation_hoeffding(eps: float, delta: float, u_range: float)
6870
return int(np.ceil(np.log(2 / delta) * 2 * u_range**2 / eps**2))
6971

7072

71-
def random_subset(s: NDArray[T], q: float = 0.5) -> NDArray[T]:
72-
"""Returns one subset at random from `s`.
73+
def random_subset(
    s: NDArray[T],
    q: float = 0.5,
    seed: Optional[Seed] = None,
) -> NDArray[T]:
    """Returns one subset at random from ``s``.

    Every element of ``s`` is included independently with probability ``q``.

    Args:
        s: set to sample from
        q: Sampling probability for elements. The default 0.5 yields a
            uniform distribution over the power set of s.
        seed: Either an instance of a numpy random number generator or a seed for it.

    Returns:
        The subset
    """
    rng = np.random.default_rng(seed)
    # uniform() draws from [0, 1), so `draw < q` is true with probability q:
    # q=1 keeps every element and q=0 keeps none.
    selection = rng.uniform(size=len(s)) < q
    return s[selection]
8592

8693

8794
def random_powerset(
88-
s: NDArray[T], n_samples: Optional[int] = None, q: float = 0.5
95+
s: NDArray[T],
96+
n_samples: Optional[int] = None,
97+
q: float = 0.5,
98+
seed: Optional[Seed] = None,
8999
) -> Generator[NDArray[T], None, None]:
90100
"""Samples subsets from the power set of the argument, without
91101
pre-generating all subsets and in no order.
@@ -103,6 +113,7 @@ def random_powerset(
103113
Defaults to `np.iinfo(np.int32).max`
104114
q: Sampling probability for elements. The default 0.5 yields a
105115
uniform distribution over the power set of s.
116+
seed: Either an instance of a numpy random number generator or a seed for it.
106117
107118
Returns:
108119
Samples from the power set of `s`.
@@ -114,21 +125,27 @@ def random_powerset(
114125
if q < 0 or q > 1:
115126
raise ValueError("Element sampling probability must be in [0,1]")
116127

128+
rng = np.random.default_rng(seed)
117129
total = 1
118130
if n_samples is None:
119131
n_samples = np.iinfo(np.int32).max
120132
while total <= n_samples:
121-
yield random_subset(s, q)
133+
yield random_subset(s, q, seed=rng)
122134
total += 1
123135

124136

125-
def random_subset_of_size(s: NDArray[T], size: int) -> NDArray[T]:
137+
def random_subset_of_size(
138+
s: NDArray[T],
139+
size: int,
140+
seed: Optional[Seed] = None,
141+
) -> NDArray[T]:
126142
"""Samples a random subset of given size uniformly from the powerset
127143
of `s`.
128144
129145
Args:
130146
s: Set to sample from
131147
size: Size of the subset to generate
148+
seed: Either an instance of a numpy random number generator or a seed for it.
132149
133150
Returns:
134151
The subset
@@ -138,11 +155,13 @@ def random_subset_of_size(s: NDArray[T], size: int) -> NDArray[T]:
138155
"""
139156
if size > len(s):
140157
raise ValueError("Cannot sample subset larger than set")
141-
rng = np.random.default_rng()
158+
rng = np.random.default_rng(seed)
142159
return rng.choice(s, size=size, replace=False)
143160

144161

145-
def random_matrix_with_condition_number(n: int, condition_number: float) -> NDArray:
162+
def random_matrix_with_condition_number(
163+
n: int, condition_number: float, seed: Optional[Seed] = None
164+
) -> NDArray:
146165
"""Constructs a square matrix with a given condition number.
147166
148167
Taken from:
@@ -156,6 +175,7 @@ def random_matrix_with_condition_number(n: int, condition_number: float) -> NDAr
156175
Args:
157176
n: size of the matrix
158177
condition_number: the desired condition number of the returned matrix; must be greater than 1
178+
seed: Either an instance of a numpy random number generator or a seed for it.
159179
160180
Returns:
161181
An (n,n) matrix with the requested condition number.
@@ -166,6 +186,7 @@ def random_matrix_with_condition_number(n: int, condition_number: float) -> NDAr
166186
if condition_number <= 1:
167187
raise ValueError("Condition number must be greater than 1")
168188

189+
rng = np.random.default_rng(seed)
169190
log_condition_number = np.log(condition_number)
170191
exp_vec = np.arange(
171192
-log_condition_number / 4.0,
@@ -175,8 +196,8 @@ def random_matrix_with_condition_number(n: int, condition_number: float) -> NDAr
175196
exp_vec = exp_vec[:n]
176197
s: np.ndarray = np.exp(exp_vec)
177198
S = np.diag(s)
178-
U, _ = np.linalg.qr((np.random.rand(n, n) - 5.0) * 200)
179-
V, _ = np.linalg.qr((np.random.rand(n, n) - 5.0) * 200)
199+
U, _ = np.linalg.qr((rng.uniform(size=(n, n)) - 5.0) * 200)
200+
V, _ = np.linalg.qr((rng.uniform(size=(n, n)) - 5.0) * 200)
180201
P: np.ndarray = U.dot(S).dot(V.T)
181202
P = P.dot(P.T)
182203
return P

src/pydvl/utils/parallel/map_reduce.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,17 @@
66
This interface might be deprecated or changed in a future release before 1.0
77
88
"""
9+
from functools import reduce
910
from itertools import accumulate, repeat
1011
from typing import Any, Collection, Dict, Generic, List, Optional, TypeVar, Union
1112

1213
from joblib import Parallel, delayed
14+
from numpy.random import SeedSequence
1315
from numpy.typing import NDArray
1416

1517
from ..config import ParallelConfig
16-
from ..types import MapFunction, ReduceFunction, maybe_add_argument
18+
from ..functional import maybe_add_argument
19+
from ..types import MapFunction, ReduceFunction, Seed, ensure_seed_sequence
1720
from .backend import init_parallel_backend
1821

1922
__all__ = ["MapReduceJob"]
@@ -104,25 +107,42 @@ def __init__(
104107
self.map_kwargs = map_kwargs if map_kwargs is not None else dict()
105108
self.reduce_kwargs = reduce_kwargs if reduce_kwargs is not None else dict()
106109

107-
self._map_func = maybe_add_argument(map_func, "job_id")
110+
self._map_func = reduce(maybe_add_argument, ["job_id", "seed"], map_func)
108111
self._reduce_func = reduce_func
109112

110113
def __call__(
    self,
    seed: Optional[Union[Seed, SeedSequence]] = None,
) -> R:
    """
    Runs the map-reduce job.

    The inputs are split into `n_jobs` chunks. Each chunk is handed to the
    map function together with its index (`job_id`) and its own child seed,
    spawned from the seed sequence derived from `seed`. The list of map
    results is then passed to the reduce function.

    Args:
        seed: Seed material for the job; it is converted with
            `ensure_seed_sequence` and one child is spawned per chunk.

    Returns:
        The result of the reduce function.
    """
    # The config value "joblib" is passed to Parallel as backend "loky".
    configured_backend = self.config.backend
    backend = "loky" if configured_backend == "joblib" else configured_backend

    # In joblib the levels are reversed.
    # 0 means no logging and 50 means log everything to stdout
    verbose = 50 - self.config.logging_level

    seed_seq = ensure_seed_sequence(seed)
    with Parallel(backend=backend, n_jobs=self.n_jobs, verbose=verbose) as parallel:
        chunks = self._chunkify(self.inputs_, n_chunks=self.n_jobs)
        # One child seed per chunk.
        chunk_seeds = seed_seq.spawn(len(chunks))
        map_results: List[R] = parallel(
            delayed(self._map_func)(
                chunk, job_id=job_id, seed=chunk_seed, **self.map_kwargs
            )
            for job_id, (chunk, chunk_seed) in enumerate(zip(chunks, chunk_seeds))
        )

    reduce_results: R = self._reduce_func(map_results, **self.reduce_kwargs)
    return reduce_results
128148

0 commit comments

Comments
 (0)