Skip to content

Commit 9853265

Browse files
authored
Merge pull request #397 from bayesflow-org/allow-positional-include-transforms
Allow positional include transforms
2 parents effbb00 + 759abbc commit 9853265

File tree

3 files changed

+36
-4
lines changed

3 files changed

+36
-4
lines changed

bayesflow/adapters/adapter.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -233,11 +233,11 @@ def __len__(self):
233233

234234
def apply(
235235
self,
236+
include: str | Sequence[str] = None,
236237
*,
237238
forward: np.ufunc | str,
238239
inverse: np.ufunc | str = None,
239240
predicate: Predicate = None,
240-
include: str | Sequence[str] = None,
241241
exclude: str | Sequence[str] = None,
242242
**kwargs,
243243
):
@@ -388,6 +388,7 @@ def convert_dtype(
388388
exclude: str | Sequence[str] = None,
389389
):
390390
"""Append a :py:class:`~transforms.ConvertDType` transform to the adapter.
391+
See also :py:meth:`~bayesflow.adapters.Adapter.map_dtype`.
391392
392393
Parameters
393394
----------
@@ -525,6 +526,24 @@ def log(self, keys: str | Sequence[str], *, p1: bool = False):
525526
self.transforms.append(transform)
526527
return self
527528

529+
def map_dtype(self, keys: str | Sequence[str], to_dtype: str):
530+
"""Append a :py:class:`~transforms.ConvertDType` transform to the adapter.
531+
See also :py:meth:`~bayesflow.adapters.Adapter.convert_dtype`.
532+
533+
Parameters
534+
----------
535+
keys : str or Sequence of str
536+
The names of the variables to transform.
537+
to_dtype : str
538+
Target dtype
539+
"""
540+
if isinstance(keys, str):
541+
keys = [keys]
542+
543+
transform = MapTransform({key: ConvertDType(to_dtype) for key in keys})
544+
self.transforms.append(transform)
545+
return self
546+
528547
def one_hot(self, keys: str | Sequence[str], num_classes: int):
529548
"""Append a :py:class:`~transforms.OneHot` transform to the adapter.
530549
@@ -590,9 +609,9 @@ def sqrt(self, keys: str | Sequence[str]):
590609

591610
def standardize(
592611
self,
612+
include: str | Sequence[str] = None,
593613
*,
594614
predicate: Predicate = None,
595-
include: str | Sequence[str] = None,
596615
exclude: str | Sequence[str] = None,
597616
**kwargs,
598617
):
@@ -621,9 +640,9 @@ def standardize(
621640

622641
def to_array(
623642
self,
643+
include: str | Sequence[str] = None,
624644
*,
625645
predicate: Predicate = None,
626-
include: str | Sequence[str] = None,
627646
exclude: str | Sequence[str] = None,
628647
**kwargs,
629648
):

bayesflow/adapters/transforms/filter_transform.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,10 @@ class FilterTransform(Transform):
2929

3030
def __init__(
3131
self,
32+
include: str | Sequence[str] = None,
3233
*,
3334
transform_constructor: Callable[..., ElementwiseTransform],
3435
predicate: Predicate = None,
35-
include: str | Sequence[str] = None,
3636
exclude: str | Sequence[str] = None,
3737
**kwargs,
3838
):

bayesflow/utils/empty.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
class empty:
2+
"""
3+
Placeholder value for arguments left empty
4+
5+
Usage:
6+
7+
def f(x=empty):
8+
if x is empty:
9+
# we know the user did not pass x
10+
if x is None:
11+
# the user could have passed None explicitly
12+
13+
"""

0 commit comments

Comments
 (0)