Commit 6eab147

handle nesting: ConvertDType, ToArray, relax Concatenate
Concatenate reduces to a rename when only one key is supplied. By skipping the actual concatenation in that case, the transform can accept arbitrary inputs, as long as only one is supplied. This simplifies things, e.g. in the `BasicWorkflow`, where the user passes the `summary_variables` to be concatenated: these may form a single dict, which does not need to be concatenated.
1 parent 01aadf1 commit 6eab147
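
A minimal sketch of the relaxed behavior (the `Concatenate(keys, into=...)` call signature is inferred from `self.keys` and `self.into` in the diff below and may not match the public constructor exactly):

    import numpy as np
    from bayesflow.adapters.transforms import Concatenate

    # Hypothetical usage: with a single key, forward() only renames the entry,
    # so the value may be any structure (here a dict of arrays), not just an array.
    transform = Concatenate(["summaries"], into="summary_variables")
    data = {"summaries": {"a": np.zeros((2, 3)), "b": np.ones((2, 3))}}
    result = transform.forward(data)      # {"summary_variables": {"a": ..., "b": ...}}
    restored = transform.inverse(result)  # renamed back to {"summaries": ...}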

File tree: 8 files changed, +116 -22 lines changed

bayesflow/adapters/transforms/concatenate.py

Lines changed: 10 additions & 3 deletions
@@ -49,7 +49,7 @@ def get_config(self) -> dict:
         return serialize(config)
 
     def forward(self, data: dict[str, any], *, strict: bool = True, **kwargs) -> dict[str, any]:
-        if not strict and self.indices is None:
+        if not strict and self.indices is None and len(self.keys) != 1:
             raise ValueError("Cannot call `forward` with `strict=False` before calling `forward` with `strict=True`.")
 
         # copy to avoid side effects
@@ -69,6 +69,10 @@ def forward(self, data: dict[str, any], *, strict: bool = True, **kwargs) -> dict[str, any]:
                 data.pop(key)
 
             return data
+        elif len(required_keys) == 1:
+            # only a rename
+            data[self.into] = data.pop(self.keys[0])
+            return data
 
         if self.indices is None:
             # remember the indices of the parts in the concatenated array
@@ -86,7 +90,7 @@ def forward(self, data: dict[str, any], *, strict: bool = True, **kwargs) -> dict[str, any]:
         return data
 
     def inverse(self, data: dict[str, any], *, strict: bool = False, **kwargs) -> dict[str, any]:
-        if self.indices is None:
+        if self.indices is None and len(self.keys) != 1:
             raise RuntimeError("Cannot call `inverse` before calling `forward` at least once.")
 
         # copy to avoid side effects
@@ -98,6 +102,9 @@ def inverse(self, data: dict[str, any], *, strict: bool = False, **kwargs) -> dict[str, any]:
         elif self.into not in data:
             # nothing to do
             return data
+        elif len(self.keys) == 1:
+            data[self.keys[0]] = data.pop(self.into)
+            return data
 
         # split the concatenated array and remove the concatenated key
         keys = self.keys
@@ -141,7 +148,7 @@ def log_det_jac(
         available_keys = set(log_det_jac.keys())
         common_keys = available_keys & required_keys
 
-        if len(common_keys) == 0:
+        if len(common_keys) == 0 or len(self.keys) == 1:
             return log_det_jac
 
         parts = [log_det_jac.pop(key) for key in common_keys]

bayesflow/adapters/transforms/convert_dtype.py

Lines changed: 3 additions & 2 deletions
@@ -1,4 +1,5 @@
 import numpy as np
+from keras.tree import map_structure
 
 from bayesflow.utils.serialization import serializable, serialize
 
@@ -32,7 +33,7 @@ def get_config(self) -> dict:
         return serialize(config)
 
     def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
-        return data.astype(self.to_dtype, copy=False)
+        return map_structure(lambda d: d.astype(self.to_dtype, copy=False), data)
 
     def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
-        return data.astype(self.from_dtype, copy=False)
+        return map_structure(lambda d: d.astype(self.from_dtype, copy=False), data)
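
The nesting support comes from `keras.tree.map_structure`, which applies a function to every leaf of a nested structure. A standalone sketch of the pattern used above (illustrative values only):

    import numpy as np
    from keras.tree import map_structure

    # the cast reaches every array leaf, including arrays inside nested dicts
    data = {"x": np.zeros(2, dtype="float64"), "nested": {"y": np.ones(3, dtype="float64")}}
    cast = map_structure(lambda d: d.astype("float32", copy=False), data)
    assert cast["nested"]["y"].dtype == np.float32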

bayesflow/adapters/transforms/to_array.py

Lines changed: 26 additions & 2 deletions
@@ -2,6 +2,7 @@
 
 import numpy as np
 
+from bayesflow.utils.tree import map_dict, get_value_at_path, map_dict_with_path
 from bayesflow.utils.serialization import serializable, serialize
 
 from .elementwise_transform import ElementwiseTransform
@@ -35,13 +36,36 @@ def get_config(self) -> dict:
 
     def forward(self, data: any, **kwargs) -> np.ndarray:
         if self.original_type is None:
-            self.original_type = type(data)
+            if isinstance(data, dict):
+                self.original_type = map_dict(type, data)
+            else:
+                self.original_type = type(data)
 
+        if isinstance(self.original_type, dict):
+            # use self.original_type in check to preserve serializability
+            return map_dict(np.asarray, data)
         return np.asarray(data)
 
-    def inverse(self, data: np.ndarray, **kwargs) -> any:
+    def inverse(self, data: np.ndarray | dict, **kwargs) -> any:
         if self.original_type is None:
             raise RuntimeError("Cannot call `inverse` before calling `forward` at least once.")
+        if isinstance(self.original_type, dict):
+            # use self.original_type in check to preserve serializability
+
+            def restore_original_type(path, value):
+                try:
+                    original_type = get_value_at_path(self.original_type, path)
+                    return original_type(value)
+                except KeyError:
+                    pass
+                except TypeError:
+                    pass
+                except ValueError:
+                    # separate statements, as optree does not allow (KeyError | TypeError | ValueError)
+                    pass
+                return value
+
+            return map_dict_with_path(restore_original_type, data)
 
         if issubclass(self.original_type, Number):
             try:
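
A sketch of the intended round trip, assuming `ToArray` is importable from the same transforms package and takes no constructor arguments (illustrative values only):

    from bayesflow.adapters.transforms import ToArray

    to_array = ToArray()
    data = {"a": [1.0, 2.0], "b": {"c": 3}}

    # forward() records the leaf types in a dict mirroring the input,
    # then converts every leaf to an np.ndarray
    arrays = to_array.forward(data)

    # inverse() looks each recorded type up by path and casts back where possible
    restored = to_array.inverse(arrays)  # "a" becomes a list again, "c" an int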

bayesflow/utils/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -7,6 +7,7 @@
     logging,
     numpy_utils,
     serialization,
+    tree,
 )
 
 from .callbacks import detailed_loss_callback
@@ -104,4 +105,4 @@
 
 from ._docs import _add_imports_to_all
 
-_add_imports_to_all(include_modules=["keras_utils", "logging", "numpy_utils", "serialization"])
+_add_imports_to_all(include_modules=["keras_utils", "logging", "numpy_utils", "serialization", "tree"])

bayesflow/utils/tree.py

Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
+import optree
+
+
+def flatten_shape(structure):
+    def is_shape_tuple(x):
+        return isinstance(x, (list, tuple)) and all(isinstance(e, (int, type(None))) for e in x)
+
+    leaves, _ = optree.tree_flatten(
+        structure,
+        is_leaf=is_shape_tuple,
+        none_is_leaf=True,
+        namespace="keras",
+    )
+    return leaves
+
+
+def map_dict(func, *structures):
+    def is_not_dict(x):
+        return not isinstance(x, dict)
+
+    if not structures:
+        raise ValueError("Must provide at least one structure")
+
+    # Add check for same structures, otherwise optree just maps to shallowest.
+    def func_with_check(*args):
+        if not all(optree.tree_is_leaf(s, is_leaf=is_not_dict, none_is_leaf=True, namespace="keras") for s in args):
+            raise ValueError("Structures don't have the same nested structure.")
+        return func(*args)
+
+    map_func = func_with_check if len(structures) > 1 else func
+
+    return optree.tree_map(
+        map_func,
+        *structures,
+        is_leaf=is_not_dict,
+        none_is_leaf=True,
+        namespace="keras",
+    )
+
+
+def map_dict_with_path(func, *structures):
+    def is_not_dict(x):
+        return not isinstance(x, dict)
+
+    if not structures:
+        raise ValueError("Must provide at least one structure")
+
+    # Add check for same structures, otherwise optree just maps to shallowest.
+    def func_with_check(*args):
+        if not all(optree.tree_is_leaf(s, is_leaf=is_not_dict, none_is_leaf=True, namespace="keras") for s in args):
+            raise ValueError("Structures don't have the same nested structure.")
+        return func(*args)
+
+    map_func = func_with_check if len(structures) > 1 else func
+
+    return optree.tree_map_with_path(
+        map_func,
+        *structures,
+        is_leaf=is_not_dict,
+        none_is_leaf=True,
+        namespace="keras",
+    )
+
+
+def get_value_at_path(structure, path):
+    output = structure
+    for accessor in path:
+        output = output.__getitem__(accessor)
+    return output
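
A quick usage sketch of the new helpers (illustrative values only):

    from bayesflow.utils.tree import map_dict, get_value_at_path

    nested = {"a": {"x": 1, "y": 2}, "b": 3}

    # anything that is not a dict counts as a leaf
    print(map_dict(lambda v: v * 10, nested))     # {'a': {'x': 10, 'y': 20}, 'b': 30}

    # paths are tuples of keys, as produced by optree's *_with_path functions
    print(get_value_at_path(nested, ("a", "y")))  # 2

    # mismatched nesting across multiple structures raises instead of
    # silently mapping over the shallowest structure
    try:
        map_dict(lambda u, v: u + v, {"a": 1, "b": 2}, nested)
    except ValueError as e:
        print(e)  # Structures don't have the same nested structure.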

tests/test_adapters/conftest.py

Lines changed: 2 additions & 2 deletions
@@ -13,7 +13,9 @@ def serializable_fn(x):
 
     return (
         Adapter()
+        .group(["p1", "p2"], into="ps", prefix="p")
         .to_array()
+        .ungroup("ps", prefix="p")
         .as_set(["s1", "s2"])
         .broadcast("t1", to="t2")
         .as_time_series(["t1", "t2"])
@@ -37,8 +39,6 @@ def serializable_fn(x):
         .rename("o1", "o2")
         .random_subsample("s3", sample_size=33, axis=0)
         .take("s3", indices=np.arange(0, 32), axis=0)
-        .group(["p1", "p2"], into="ps", prefix="p")
-        .ungroup("ps", prefix="p")
     )
 
 
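The reordering above is the point of the change: `to_array` now runs while `p1` and `p2` are still grouped into the nested entry `ps`, exercising the dict-aware `ToArray`. A minimal standalone sketch mirroring the fixture (the chaining API is taken from the test itself; nothing beyond it is assumed):

    from bayesflow import Adapter

    # group first, then convert: ToArray must handle the nested dict under "ps"
    adapter = (
        Adapter()
        .group(["p1", "p2"], into="ps", prefix="p")
        .to_array()
        .ungroup("ps", prefix="p")
    )
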
tests/test_workflows/conftest.py

Lines changed: 1 addition & 8 deletions
@@ -81,13 +81,6 @@ def sample(self, batch_shape: Shape, num_observations: int = 4) -> dict[str, Tensor]:
 
         x = mean[:, None] + noise
 
-        return dict(mean=mean, a=x, b=x)
+        return dict(mean=mean, observables=dict(a=x, b=x))
 
     return FusionSimulator()
-
-
-@pytest.fixture
-def fusion_adapter():
-    from bayesflow import Adapter
-
-    return Adapter.create_default(["mean"]).group(["a", "b"], "summary_variables")

tests/test_workflows/test_basic_workflow.py

Lines changed: 3 additions & 4 deletions
@@ -34,14 +34,13 @@ def test_basic_workflow(tmp_path, inference_network, summary_network):
     assert samples["parameters"].shape == (5, 3, 2)
 
 
-def test_basic_workflow_fusion(
-    tmp_path, fusion_inference_network, fusion_summary_network, fusion_simulator, fusion_adapter
-):
+def test_basic_workflow_fusion(tmp_path, fusion_inference_network, fusion_summary_network, fusion_simulator):
     workflow = bf.BasicWorkflow(
-        adapter=fusion_adapter,
         inference_network=fusion_inference_network,
         summary_network=fusion_summary_network,
         simulator=fusion_simulator,
+        inference_variables=["mean"],
+        summary_variables=["observables"],
         checkpoint_filepath=str(tmp_path),
     )
 