Skip to content

Commit ee28392

Browse files
committed
clean up take transform and docs
1 parent 37598b0 commit ee28392

File tree

3 files changed

+21
-18
lines changed

3 files changed

+21
-18
lines changed

bayesflow/adapters/adapter.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -765,7 +765,7 @@ def standardize(
765765
Names of variables to include in the transform.
766766
exclude : str or Sequence of str, optional
767767
Names of variables to exclude from the transform.
768-
**kwargs : dict
768+
**kwargs :
769769
Additional keyword arguments passed to the transform.
770770
"""
771771
transform = FilterTransform(
@@ -780,33 +780,36 @@ def standardize(
780780

781781
def take(
782782
self,
783-
indices,
784-
axis,
783+
include: str | Sequence[str] = None,
785784
*,
785+
indices: Sequence[int],
786+
axis: int = -1,
786787
predicate: Predicate = None,
787-
include: str | Sequence[str] = None,
788788
exclude: str | Sequence[str] = None,
789-
**kwargs,
790789
):
791790
"""
792791
Append a :py:class:`~transforms.Take` transform to the adapter.
793792
794793
Parameters
795794
----------
796-
predicate : Predicate, optional
797-
Function that indicates which variables should be transformed.
798795
include : str or Sequence of str, optional
799796
Names of variables to include in the transform.
797+
indices : Sequence of int
798+
Which indices to take from the data.
799+
axis : int, optional
800+
Which axis to take from. The last axis is used by default.
801+
predicate : Predicate, optional
802+
Function that indicates which variables should be transformed.
800803
exclude : str or Sequence of str, optional
801804
Names of variables to exclude from the transform.
802-
**kwargs : dict
803-
Additional keyword arguments passed to the transform."""
805+
"""
804806
transform = FilterTransform(
805-
transform_constructor=Take(indices=indices, axis=axis),
807+
transform_constructor=Take,
806808
predicate=predicate,
807809
include=include,
808810
exclude=exclude,
809-
**kwargs,
811+
indices=indices,
812+
axis=axis,
810813
)
811814
self.transforms.append(transform)
812815
return self
Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
from collections.abc import Sequence
12
import numpy as np
2-
from keras.saving import register_keras_serializable as serializable
3+
4+
from bayesflow.utils.serialization import serializable
35

46
from .elementwise_transform import ElementwiseTransform
57

@@ -8,19 +10,17 @@
810
class Take(ElementwiseTransform):
911
"""
1012
A transform to reduce the dimensionality of arrays output by the summary network
11-
Axis is a mandatory argument and will default to the last axis.
1213
Example: adapter.take("x", np.arange(0,3), axis=-1)
13-
1414
"""
1515

16-
def __init__(self, indices, axis=-1):
16+
def __init__(self, indices: Sequence[int], axis: int = -1):
1717
super().__init__()
1818
self.indices = indices
1919
self.axis = axis
2020

21-
def forward(self, data):
21+
def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
2222
return np.take(data, self.indices, self.axis)
2323

24-
def inverse(self, data):
24+
def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
2525
# not a true invertible function
2626
return data

tests/test_adapters/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def serializable_fn(x):
3232
.standardize(exclude=["t1", "t2", "o1"])
3333
.drop("d1")
3434
.one_hot("o1", 10)
35-
.keep(["x", "y", "z1", "p1", "p2", "s1", "s2", "t1", "t2", "o1", "split_1", "split_2"])
35+
.keep(["x", "y", "z1", "p1", "p2", "s1", "s2", "s3", "t1", "t2", "o1", "split_1", "split_2"])
3636
.rename("o1", "o2")
3737
.random_subsample("s3", sample_size=33, axis=0)
3838
.take("s3", indices=np.arange(0, 32), axis=0)

0 commit comments

Comments
 (0)