Skip to content

Commit ddf45c7

Browse files
committed
add to_dict transform
1 parent dc5ee17 commit ddf45c7

File tree

2 files changed

+42
-0
lines changed

2 files changed

+42
-0
lines changed

bayesflow/adapters/transforms/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from .sqrt import Sqrt
2222
from .standardize import Standardize
2323
from .to_array import ToArray
24+
from .to_dict import ToDict
2425
from .transform import Transform
2526

2627
from ...utils._docs import _add_imports_to_all
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import numpy as np
2+
import pandas as pd
3+
4+
from keras.saving import (
5+
register_keras_serializable as serializable,
6+
)
7+
8+
from .transform import Transform
9+
10+
11+
@serializable(package="bayesflow.adapters")
12+
class ToDict(Transform):
13+
"""Convert non-dict batches (e.g., pandas.DataFrame) to dict batches"""
14+
15+
@classmethod
16+
def from_config(cls, config: dict, custom_objects=None):
17+
return cls()
18+
19+
def get_config(self) -> dict:
20+
return {}
21+
22+
def forward(self, data, **kwargs) -> dict[str, np.ndarray]:
23+
data = dict(data)
24+
25+
for key, value in data.items():
26+
if isinstance(data[key], pd.Series):
27+
if value.dtype == "object":
28+
value = value.astype("category")
29+
30+
if value.dtype == "category":
31+
value = pd.get_dummies(value)
32+
33+
value = np.asarray(value).astype("float32", copy=False)
34+
35+
data[key] = value
36+
37+
return data
38+
39+
def inverse(self, data: dict[str, np.ndarray], **kwargs) -> dict[str, np.ndarray]:
40+
# non-invertible transform
41+
return data

0 commit comments

Comments
 (0)