Skip to content

Commit dc5ee17

Browse files
committed
fix split and tests
1 parent c06c310 commit dc5ee17

File tree

3 files changed

+43
-4
lines changed

3 files changed

+43
-4
lines changed

bayesflow/adapters/transforms/split.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
11
from collections.abc import Sequence
22
import numpy as np
33

4+
from keras.saving import (
5+
deserialize_keras_object as deserialize,
6+
register_keras_serializable as serializable,
7+
serialize_keras_object as serialize,
8+
)
9+
410
from .transform import Transform
511

612

13+
@serializable(package="bayesflow.adapters")
714
class Split(Transform):
815
"""This is the effective inverse of the :py:class:`~Concatenate` Transform.
916
@@ -29,6 +36,23 @@ def __init__(self, key: str, into: Sequence[str], indices_or_sections: int | Seq
2936

3037
self.indices_or_sections = indices_or_sections
3138

39+
@classmethod
40+
def from_config(cls, config: dict, custom_objects=None) -> "Split":
41+
return cls(
42+
key=deserialize(config["key"], custom_objects),
43+
into=deserialize(config["into"], custom_objects),
44+
indices_or_sections=deserialize(config["indices_or_sections"], custom_objects),
45+
axis=deserialize(config["axis"], custom_objects),
46+
)
47+
48+
def get_config(self) -> dict:
49+
return {
50+
"key": serialize(self.key),
51+
"into": serialize(self.into),
52+
"indices_or_sections": serialize(self.indices_or_sections),
53+
"axis": serialize(self.axis),
54+
}
55+
3256
def forward(self, data: dict[str, np.ndarray], strict: bool = True, **kwargs) -> dict[str, np.ndarray]:
3357
# avoid side effects
3458
data = data.copy()
@@ -39,7 +63,7 @@ def forward(self, data: dict[str, np.ndarray], strict: bool = True, **kwargs) ->
3963
# we cannot produce a result, but also don't have to
4064
return data
4165

42-
splits = np.split(data.pop(self.key), self.indices_or_sections)
66+
splits = np.split(data.pop(self.key), self.indices_or_sections, axis=self.axis)
4367

4468
if len(splits) != len(self.into):
4569
raise ValueError(f"Requested {len(self.into)} splits, but produced {len(splits)}.")

tests/test_adapters/conftest.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,11 @@ def serializable_fn(x):
2828
.apply_serializable(include="x", forward=serializable_fn, inverse=serializable_fn)
2929
.scale("x", by=[-1, 2])
3030
.shift("x", by=2)
31-
.split("x", into=["x1", "x2"])
32-
.concatenate(["x1", "x2"], into="x")
31+
.split("key_to_split", into=["split_1", "split_2"])
3332
.standardize(exclude=["t1", "t2", "o1"])
3433
.drop("d1")
3534
.one_hot("o1", 10)
36-
.keep(["x", "y", "z1", "p1", "p2", "s1", "s2", "t1", "t2", "o1"])
35+
.keep(["x", "y", "z1", "p1", "p2", "s1", "s2", "t1", "t2", "o1", "split_1", "split_2"])
3736
.rename("o1", "o2")
3837
)
3938

@@ -57,4 +56,5 @@ def random_data():
5756
"d1": np.random.standard_normal(size=(32, 2)),
5857
"d2": np.random.standard_normal(size=(32, 2)),
5958
"o1": np.random.randint(0, 9, size=(32, 2)),
59+
"key_to_split": np.random.standard_normal(size=(32, 10)),
6060
}

tests/test_adapters/test_adapters.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,3 +177,18 @@ def registered_but_changed(x): # noqa: F811
177177
corrupt_serialized_transform["config"]["inverse"]["config"] = "nonexistent"
178178
with pytest.raises(TypeError):
179179
keras.saving.deserialize_keras_object(corrupt_serialized_transform)
180+
181+
182+
def test_split_transform(adapter, random_data):
183+
assert "key_to_split" in random_data
184+
185+
shape = random_data["key_to_split"].shape
186+
target_shape = (*shape[:-1], shape[-1] // 2)
187+
188+
processed = adapter(random_data)
189+
190+
assert "split_1" in processed
191+
assert processed["split_1"].shape == target_shape
192+
193+
assert "split_2" in processed
194+
assert processed["split_2"].shape == target_shape

0 commit comments

Comments
 (0)