Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions bayesflow/adapters/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,3 +755,10 @@ def to_array(
)
self.transforms.append(transform)
return self

def to_dict(self):
    """Append a ``ToDict`` transform (non-dict batch -> dict batch) and return this adapter.

    Returning ``self`` lets the call be chained with further adapter methods.
    """
    # Function-scope import: resolved when the method is called, not at module import.
    from .transforms import ToDict

    self.transforms.append(ToDict())
    return self
1 change: 1 addition & 0 deletions bayesflow/adapters/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .sqrt import Sqrt
from .standardize import Standardize
from .to_array import ToArray
from .to_dict import ToDict
from .transform import Transform

from ...utils._docs import _add_imports_to_all
Expand Down
41 changes: 41 additions & 0 deletions bayesflow/adapters/transforms/to_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import numpy as np
import pandas as pd

from keras.saving import (
register_keras_serializable as serializable,
)

from .transform import Transform


@serializable(package="bayesflow.adapters")
class ToDict(Transform):
    """Convert non-dict batches (e.g., pandas.DataFrame) to dict batches.

    Every value in the resulting dict is a float32 ``numpy.ndarray``.
    ``pandas.Series`` values holding object, string, or categorical data are
    one-hot encoded via ``pandas.get_dummies`` before the cast; all other
    values are cast to float32 directly.
    """

    @classmethod
    def from_config(cls, config: dict, custom_objects=None):
        """Rebuild the transform. It has no parameters, so ``config`` is not consulted."""
        return cls()

    def get_config(self) -> dict:
        """Return an empty config, since the transform carries no state."""
        return {}

    def forward(self, data, **kwargs) -> dict[str, np.ndarray]:
        """Turn ``data`` into a dict of float32 arrays.

        Parameters
        ----------
        data
            Anything accepted by ``dict(...)``, e.g. a ``pandas.DataFrame``
            (which yields one ``pandas.Series`` per column) or a mapping.
        **kwargs
            Accepted for interface compatibility; unused here.

        Returns
        -------
        dict[str, np.ndarray]
            One float32 array per key. One-hot encoded entries gain a trailing
            axis with one indicator column per category.

        Raises
        ------
        ValueError
            Propagated from NumPy if a value cannot be cast to float32
            (e.g. non-numeric data that is not a ``pandas.Series``).
        """
        data = dict(data)

        for key, value in data.items():
            if isinstance(value, pd.Series):
                dtype = value.dtype

                # Text may be stored as NumPy object dtype or as a pandas
                # StringDtype (explicit "string" columns, or the inferred
                # default in newer pandas). Both need one-hot encoding, since
                # neither can be cast to float32 directly.
                if pd.api.types.is_object_dtype(dtype) or isinstance(dtype, pd.StringDtype):
                    value = value.astype("category")

                # For an existing categorical dtype, get_dummies emits one
                # column per declared category, including categories absent
                # from this particular batch.
                if value.dtype == "category":
                    value = pd.get_dummies(value)

            # Only values are reassigned (the key set is unchanged), so
            # writing back while iterating over items() is safe.
            data[key] = np.asarray(value).astype("float32", copy=False)

        return data

    def inverse(self, data: dict[str, np.ndarray], **kwargs) -> dict[str, np.ndarray]:
        """Return ``data`` unchanged; the forward conversion is not invertible."""
        # non-invertible transform
        return data
40 changes: 40 additions & 0 deletions tests/test_adapters/test_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import numpy as np
import pytest

import bayesflow as bf


def test_cycle_consistency(adapter, random_data):
processed = adapter(random_data)
Expand Down Expand Up @@ -192,3 +194,41 @@ def test_split_transform(adapter, random_data):

assert "split_2" in processed
assert processed["split_2"].shape == target_shape


def test_to_dict_transform():
    """Check that ``Adapter().to_dict()`` turns a DataFrame batch into float32 arrays."""
    import pandas as pd

    frame = pd.DataFrame(
        {
            "int32": [1, 2, 3, 4, 5],
            "int64": [1, 2, 3, 4, 5],
            "float32": [1.0, 2.0, 3.0, 4.0, 5.0],
            "float64": [1.0, 2.0, 3.0, 4.0, 5.0],
            "object": ["a", "b", "c", "d", "e"],
            "category": ["one", "two", "three", "four", "five"],
        }
    )

    # Each column is named after the dtype it should carry, so cast by name.
    for column in list(frame.columns):
        frame[column] = frame[column].astype(column)

    adapter = bf.Adapter().to_dict()

    # drop one element to simulate non-complete data
    partial_batch = frame.iloc[:-1]

    result = adapter(partial_batch)

    assert isinstance(result, dict)

    expected_keys = ["int32", "int64", "float32", "float64", "object", "category"]
    assert list(result) == expected_keys

    for array in result.values():
        assert isinstance(array, np.ndarray)
        assert array.dtype == "float32"

    # category should have 5 one-hot categories, even though it was only passed 4
    assert result["category"].shape[-1] == 5
Loading