Skip to content

Commit 8d38aa0

Browse files
committed
Enable augmentation to parts of the data or the whole data
1 parent 3d6c53c commit 8d38aa0

File tree

3 files changed

+33
-15
lines changed

3 files changed

+33
-15
lines changed

bayesflow/datasets/disk_dataset.py

Lines changed: 11 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,7 @@ def __init__(
3636
load_fn: Callable = None,
3737
adapter: Adapter | None,
3838
stage: str = "training",
39-
augmentations: Mapping[str, Callable] = None,
39+
augmentations: Mapping[str, Callable] | Callable = None,
4040
**kwargs,
4141
):
4242
"""
@@ -57,9 +57,11 @@ def __init__(
5757
Optional adapter to transform the loaded batch.
5858
stage : str, default="training"
5959
Current stage (e.g., "training", "validation", etc.) used by the adapter.
60-
augmentations : Mapping[str, Callable], optional
61-
Dictionary of augmentation functions to apply to each corresponding key in the batch.
62-
Note - augmentations are applied before the adapter.
60+
augmentations : dict of str to Callable or Callable, optional
61+
Dictionary of augmentation functions to apply to each corresponding key in the batch
62+
or a function to apply to the entire batch (possibly adding new keys).
63+
Note - augmentations are applied before the adapter is called and are generally
64+
transforms that you only want to apply during training.
6365
**kwargs
6466
Additional keyword arguments passed to the base `PyDataset`.
6567
"""
@@ -85,9 +87,13 @@ def __getitem__(self, item) -> dict[str, np.ndarray]:
8587

8688
batch = tree_stack(batch)
8789

88-
if self.augmentations is not None:
90+
if isinstance(self.augmentations, Mapping):
8991
for key in self.augmentations:
9092
batch[key] = self.augmentations[key](batch[key])
93+
elif isinstance(self.augmentations, Callable):
94+
batch = self.augmentations(batch)
95+
else:
96+
raise RuntimeError(f"Could not apply augmentations of type {type(self.augmentations)}.")
9197

9298
if self.adapter is not None:
9399
batch = self.adapter(batch, stage=self.stage)

bayesflow/datasets/offline_dataset.py

Lines changed: 11 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@ def __init__(
2323
num_samples: int = None,
2424
*,
2525
stage: str = "training",
26-
augmentations: Mapping[str, Callable] = None,
26+
augmentations: Mapping[str, Callable] | Callable = None,
2727
**kwargs,
2828
):
2929
"""
@@ -41,9 +41,11 @@ def __init__(
4141
Number of samples in the dataset. If None, it will be inferred from the data.
4242
stage : str, default="training"
4343
Current stage (e.g., "training", "validation", etc.) used by the adapter.
44-
augmentations : Mapping[str, Callable], optional
45-
Dictionary of augmentation functions to apply to each corresponding key in the batch.
46-
Note - augmentations are applied before the adapter.
44+
augmentations : dict of str to Callable or Callable, optional
45+
Dictionary of augmentation functions to apply to each corresponding key in the batch
46+
or a function to apply to the entire batch (possibly adding new keys).
47+
Note - augmentations are applied before the adapter is called and are generally
48+
transforms that you only want to apply during training.
4749
**kwargs
4850
Additional keyword arguments passed to the base `PyDataset`.
4951
"""
@@ -95,9 +97,13 @@ def __getitem__(self, item: int) -> dict[str, np.ndarray]:
9597
for key, value in self.data.items()
9698
}
9799

98-
if self.augmentations is not None:
100+
if isinstance(self.augmentations, Mapping):
99101
for key in self.augmentations:
100102
batch[key] = self.augmentations[key](batch[key])
103+
elif isinstance(self.augmentations, Callable):
104+
batch = self.augmentations(batch)
105+
else:
106+
raise RuntimeError(f"Could not apply augmentations of type {type(self.augmentations)}.")
101107

102108
if self.adapter is not None:
103109
batch = self.adapter(batch, stage=self.stage)

bayesflow/datasets/online_dataset.py

Lines changed: 11 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@ def __init__(
2020
adapter: Adapter | None,
2121
*,
2222
stage: str = "training",
23-
augmentations: Mapping[str, Callable] = None,
23+
augmentations: Mapping[str, Callable] | Callable = None,
2424
**kwargs,
2525
):
2626
"""
@@ -38,9 +38,11 @@ def __init__(
3838
Optional adapter to transform the simulated batch.
3939
stage : str, default="training"
4040
Current stage (e.g., "training", "validation", etc.) used by the adapter.
41-
augmentations : dict of str to Callable, optional
42-
Dictionary of augmentation functions to apply to each corresponding key in the batch.
43-
Note - augmentations are applied before the adapter.
41+
augmentations : dict of str to Callable or Callable, optional
42+
Dictionary of augmentation functions to apply to each corresponding key in the batch
43+
or a function to apply to the entire batch (possibly adding new keys).
44+
Note - augmentations are applied before the adapter is called and are generally
45+
transforms that you only want to apply during training.
4446
**kwargs
4547
Additional keyword arguments passed to the base `PyDataset`.
4648
"""
@@ -69,9 +71,13 @@ def __getitem__(self, item: int) -> dict[str, np.ndarray]:
6971
"""
7072
batch = self.simulator.sample((self.batch_size,))
7173

72-
if self.augmentations is not None:
74+
if isinstance(self.augmentations, Mapping):
7375
for key in self.augmentations:
7476
batch[key] = self.augmentations[key](batch[key])
77+
elif isinstance(self.augmentations, Callable):
78+
batch = self.augmentations(batch)
79+
else:
80+
raise RuntimeError(f"Could not apply augmentations of type {type(self.augmentations)}.")
7581

7682
if self.adapter is not None:
7783
batch = self.adapter(batch, stage=self.stage)

0 commit comments

Comments
 (0)