Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions bayesflow/adapters/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,15 @@ def from_config(cls, config: dict, custom_objects=None) -> "Adapter":
def get_config(self) -> dict:
return {"transforms": serialize(self.transforms)}

def forward(self, data: dict[str, any], **kwargs) -> dict[str, np.ndarray]:
def forward(self, data: dict[str, any], *, stage: str = "inference", **kwargs) -> dict[str, np.ndarray]:
"""Apply the transforms in the forward direction.

Parameters
----------
data : dict
The data to be transformed.
stage : str, one of ["training", "validation", "inference"]
The stage the function is called in.
**kwargs : dict
Additional keyword arguments passed to each transform.

Expand All @@ -97,17 +99,19 @@ def forward(self, data: dict[str, any], **kwargs) -> dict[str, np.ndarray]:
data = data.copy()

for transform in self.transforms:
data = transform(data, **kwargs)
data = transform(data, stage=stage, **kwargs)

return data

def inverse(self, data: dict[str, np.ndarray], **kwargs) -> dict[str, any]:
def inverse(self, data: dict[str, np.ndarray], *, stage: str = "inference", **kwargs) -> dict[str, any]:
"""Apply the transforms in the inverse direction.

Parameters
----------
data : dict
The data to be transformed.
stage : str, one of ["training", "validation", "inference"]
The stage the function is called in.
**kwargs : dict
Additional keyword arguments passed to each transform.

Expand All @@ -119,11 +123,13 @@ def inverse(self, data: dict[str, np.ndarray], **kwargs) -> dict[str, any]:
data = data.copy()

for transform in reversed(self.transforms):
data = transform(data, inverse=True, **kwargs)
data = transform(data, stage=stage, inverse=True, **kwargs)

return data

def __call__(self, data: Mapping[str, any], *, inverse: bool = False, **kwargs) -> dict[str, np.ndarray]:
def __call__(
self, data: Mapping[str, any], *, inverse: bool = False, stage="inference", **kwargs
) -> dict[str, np.ndarray]:
"""Apply the transforms in the given direction.

Parameters
Expand All @@ -132,6 +138,8 @@ def __call__(self, data: Mapping[str, any], *, inverse: bool = False, **kwargs)
The data to be transformed.
inverse : bool, optional
If False, apply the forward transform, else apply the inverse transform (default False).
stage : str, one of ["training", "validation", "inference"]
The stage the function is called in.
**kwargs
Additional keyword arguments passed to each transform.

Expand All @@ -141,9 +149,9 @@ def __call__(self, data: Mapping[str, any], *, inverse: bool = False, **kwargs)
The transformed data.
"""
if inverse:
return self.inverse(data, **kwargs)
return self.inverse(data, stage=stage, **kwargs)

return self.forward(data, **kwargs)
return self.forward(data, stage=stage, **kwargs)

def __repr__(self):
result = ""
Expand Down
2 changes: 1 addition & 1 deletion bayesflow/adapters/transforms/standardize.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def get_config(self) -> dict:
"momentum": serialize(self.momentum),
}

def forward(self, data: np.ndarray, stage: str = "training", **kwargs) -> np.ndarray:
def forward(self, data: np.ndarray, stage: str = "inference", **kwargs) -> np.ndarray:
if self.axis is None:
self.axis = tuple(range(data.ndim - 1))

Expand Down
4 changes: 3 additions & 1 deletion bayesflow/datasets/disk_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
batch_size: int,
load_fn: callable = None,
adapter: Adapter | None,
stage: str = "training",
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -39,6 +40,7 @@
self.load_fn = load_fn or pickle_load
self.adapter = adapter
self.files = list(map(str, self.root.glob(pattern)))
self.stage = stage

Check warning on line 43 in bayesflow/datasets/disk_dataset.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/datasets/disk_dataset.py#L43

Added line #L43 was not covered by tests

self.shuffle()

Expand All @@ -55,7 +57,7 @@
batch = tree_stack(batch)

if self.adapter is not None:
batch = self.adapter(batch)
batch = self.adapter(batch, stage=self.stage)

Check warning on line 60 in bayesflow/datasets/disk_dataset.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/datasets/disk_dataset.py#L60

Added line #L60 was not covered by tests

return batch

Expand Down
5 changes: 4 additions & 1 deletion bayesflow/datasets/offline_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,15 @@ def __init__(
batch_size: int,
adapter: Adapter | None,
num_samples: int = None,
*,
stage: str = "training",
**kwargs,
):
super().__init__(**kwargs)
self.batch_size = batch_size
self.data = data
self.adapter = adapter
self.stage = stage

if num_samples is None:
self.num_samples = self._get_num_samples_from_data(data)
Expand All @@ -52,7 +55,7 @@ def __getitem__(self, item: int) -> dict[str, np.ndarray]:
}

if self.adapter is not None:
batch = self.adapter(batch)
batch = self.adapter(batch, stage=self.stage)

return batch

Expand Down
5 changes: 4 additions & 1 deletion bayesflow/datasets/online_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ def __init__(
batch_size: int,
num_batches: int,
adapter: Adapter | None,
*,
stage: str = "training",
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -24,12 +26,13 @@ def __init__(
self._num_batches = num_batches
self.adapter = adapter
self.simulator = simulator
self.stage = stage

def __getitem__(self, item: int) -> dict[str, np.ndarray]:
batch = self.simulator.sample((self.batch_size,))

if self.adapter is not None:
batch = self.adapter(batch)
batch = self.adapter(batch, stage=self.stage)

return batch

Expand Down
5 changes: 4 additions & 1 deletion bayesflow/datasets/rounds_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
num_batches: int,
epochs_per_round: int,
adapter: Adapter | None,
*,
stage: str = "training",
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -27,6 +29,7 @@
self.batch_size = batch_size
self.adapter = adapter
self.epoch = 0
self.stage = stage

Check warning on line 32 in bayesflow/datasets/rounds_dataset.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/datasets/rounds_dataset.py#L32

Added line #L32 was not covered by tests

if epochs_per_round == 1:
logging.warning(
Expand All @@ -45,7 +48,7 @@
batch = self.batches[item]

if self.adapter is not None:
batch = self.adapter(batch)
batch = self.adapter(batch, stage=self.stage)

Check warning on line 51 in bayesflow/datasets/rounds_dataset.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/datasets/rounds_dataset.py#L51

Added line #L51 was not covered by tests

return batch

Expand Down
Loading