Skip to content

Commit 087fdbb

Browse files
committed
Move maybe_add_argument to functional.py
1 parent 2ba97e7 commit 087fdbb

File tree

3 files changed

+25
-28
lines changed

3 files changed

+25
-28
lines changed

src/pydvl/utils/functional.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from functools import partial
55
from typing import Callable, Set, Tuple, Union
66

7-
__all__ = ["get_free_args_fn", "fn_accept_additional_argument"]
7+
__all__ = ["maybe_add_argument"]
88

99

1010
def fn_accept_additional_argument(*args, fn: Callable, arg: str, **kwargs):
@@ -95,3 +95,25 @@ def _rec_unroll_partial_function_args(g: Union[Callable, partial]) -> Callable:
9595
wrapped_fn = _rec_unroll_partial_function_args(fun)
9696
sig = inspect.signature(wrapped_fn)
9797
return args_set_by_partial | set(sig.parameters.keys())
98+
99+
100+
def maybe_add_argument(fun: Callable, new_arg: str) -> Callable:
101+
"""Wraps a function to accept the given keyword parameter if it doesn't
102+
already.
103+
104+
If `fun` already takes a keyword parameter of name `new_arg`, then it is
105+
returned as is. Otherwise, a wrapper is returned which merely ignores the
106+
argument.
107+
108+
Args:
109+
fun: The function to wrap
110+
new_arg: The name of the argument that the new function will accept
111+
(and ignore).
112+
113+
Returns:
114+
A new function accepting one more keyword argument.
115+
"""
116+
if new_arg in free_arguments(fun):
117+
return fun
118+
119+
return functools.partial(_accept_additional_argument, fun=fun, arg=new_arg)

src/pydvl/utils/parallel/map_reduce.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
ReduceFunction,
2121
Seed,
2222
ensure_seed_sequence,
23-
maybe_add_argument,
2423
)
24+
from ..functional import maybe_add_argument
2525
from .backend import init_parallel_backend
2626

2727
__all__ = ["MapReduceJob"]

src/pydvl/utils/types.py

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,12 @@
33
"""
44
from __future__ import annotations
55

6-
import functools
76
from abc import ABCMeta
8-
from typing import Any, Callable, Optional, Protocol, TypeVar, Union, cast
7+
from typing import Any, Optional, Protocol, TypeVar, Union, cast
98

109
from numpy.random import Generator, SeedSequence
1110
from numpy.typing import NDArray
1211

13-
from pydvl.utils.functional import fn_accept_additional_argument, get_free_args_fn
14-
1512
__all__ = ["SupervisedModel", "MapFunction", "ReduceFunction", "NoPublicConstructor"]
1613

1714
R = TypeVar("R", covariant=True)
@@ -45,28 +42,6 @@ def score(self, x: NDArray, y: NDArray) -> float:
4542
pass
4643

4744

48-
def maybe_add_argument(fun: Callable, new_arg: str) -> Callable:
49-
"""Wraps a function to accept the given keyword parameter if it doesn't
50-
already.
51-
52-
If `fun` already takes a keyword parameter of name `new_arg`, then it is
53-
returned as is. Otherwise, a wrapper is returned which merely ignores the
54-
argument.
55-
56-
Args:
57-
fun: The function to wrap
58-
new_arg: The name of the argument that the new function will accept
59-
(and ignore).
60-
61-
Returns:
62-
A new function accepting one more keyword argument.
63-
"""
64-
if new_arg in get_free_args_fn(fun):
65-
return fun
66-
67-
return functools.partial(fn_accept_additional_argument, fn=fun, arg=new_arg)
68-
69-
7045
class NoPublicConstructor(ABCMeta):
7146
"""Metaclass that ensures a private constructor
7247

0 commit comments

Comments
 (0)