Skip to content

Commit f62f5f2

Browse files
committed
Enable augmentations in workflow
1 parent c709386 commit f62f5f2

File tree

1 file changed

+47
-3
lines changed

1 file changed

+47
-3
lines changed

bayesflow/workflows/basic_workflow.py

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -674,6 +674,7 @@ def fit_offline(
674674
batch_size: int = 32,
675675
keep_optimizer: bool = False,
676676
validation_data: Mapping[str, np.ndarray] | int = None,
677+
augmentations: Mapping[str, Callable] | Callable = None,
677678
**kwargs,
678679
) -> keras.callbacks.History:
679680
"""
@@ -698,6 +699,16 @@ def fit_offline(
698699
A dictionary containing validation data. If an integer is provided,
699700
that number of validation samples will be generated (if supported).
700701
By default, no validation data is used.
702+
augmentations : dict of str to Callable or Callable, optional
703+
Dictionary of augmentation functions to apply to each corresponding key in the batch
704+
or a function to apply to the entire batch (possibly adding new keys).
705+
706+
If you provide a dictionary of functions, each function should accept one element
707+
of your output batch and return the corresponding transformed element. Otherwise,
708+
your function should accept the entire dictionary output and return a dictionary.
709+
710+
Note - augmentations are applied before the adapter is called and are generally
711+
transforms that you only want to apply during training.
701712
**kwargs : dict, optional
702713
Additional keyword arguments passed to the underlying `_fit` method.
703714
@@ -709,7 +720,7 @@ def fit_offline(
709720
metric evolution over epochs.
710721
"""
711722

712-
dataset = OfflineDataset(data=data, batch_size=batch_size, adapter=self.adapter)
723+
dataset = OfflineDataset(data=data, batch_size=batch_size, adapter=self.adapter, augmentations=augmentations)
713724

714725
return self._fit(
715726
dataset, epochs, strategy="online", keep_optimizer=keep_optimizer, validation_data=validation_data, **kwargs
@@ -722,6 +733,7 @@ def fit_online(
722733
batch_size: int = 32,
723734
keep_optimizer: bool = False,
724735
validation_data: Mapping[str, np.ndarray] | int = None,
736+
augmentations: Mapping[str, Callable] | Callable = None,
725737
**kwargs,
726738
) -> keras.callbacks.History:
727739
"""
@@ -743,6 +755,16 @@ def fit_online(
743755
A dictionary containing validation data. If an integer is provided,
744756
that number of validation samples will be generated (if supported).
745757
By default, no validation data is used.
758+
augmentations : dict of str to Callable or Callable, optional
759+
Dictionary of augmentation functions to apply to each corresponding key in the batch
760+
or a function to apply to the entire batch (possibly adding new keys).
761+
762+
If you provide a dictionary of functions, each function should accept one element
763+
of your output batch and return the corresponding transformed element. Otherwise,
764+
your function should accept the entire dictionary output and return a dictionary.
765+
766+
Note - augmentations are applied before the adapter is called and are generally
767+
transforms that you only want to apply during training.
746768
**kwargs : dict, optional
747769
Additional keyword arguments passed to the underlying `_fit` method.
748770
@@ -755,7 +777,11 @@ def fit_online(
755777
"""
756778

757779
dataset = OnlineDataset(
758-
simulator=self.simulator, batch_size=batch_size, num_batches=num_batches_per_epoch, adapter=self.adapter
780+
simulator=self.simulator,
781+
batch_size=batch_size,
782+
num_batches=num_batches_per_epoch,
783+
adapter=self.adapter,
784+
augmentations=augmentations,
759785
)
760786

761787
return self._fit(
@@ -771,6 +797,7 @@ def fit_disk(
771797
epochs: int = 100,
772798
keep_optimizer: bool = False,
773799
validation_data: Mapping[str, np.ndarray] | int = None,
800+
augmentations: Mapping[str, Callable] | Callable = None,
774801
**kwargs,
775802
) -> keras.callbacks.History:
776803
"""
@@ -798,6 +825,16 @@ def fit_disk(
798825
A dictionary containing validation data. If an integer is provided,
799826
that number of validation samples will be generated (if supported).
800827
By default, no validation data is used.
828+
augmentations : dict of str to Callable or Callable, optional
829+
Dictionary of augmentation functions to apply to each corresponding key in the batch
830+
or a function to apply to the entire batch (possibly adding new keys).
831+
832+
If you provide a dictionary of functions, each function should accept one element
833+
of your output batch and return the corresponding transformed element. Otherwise,
834+
your function should accept the entire dictionary output and return a dictionary.
835+
836+
Note - augmentations are applied before the adapter is called and are generally
837+
transforms that you only want to apply during training.
801838
**kwargs : dict, optional
802839
Additional keyword arguments passed to the underlying `_fit` method.
803840
@@ -809,7 +846,14 @@ def fit_disk(
809846
metric evolution over epochs.
810847
"""
811848

812-
dataset = DiskDataset(root=root, pattern=pattern, batch_size=batch_size, load_fn=load_fn, adapter=self.adapter)
849+
dataset = DiskDataset(
850+
root=root,
851+
pattern=pattern,
852+
batch_size=batch_size,
853+
load_fn=load_fn,
854+
adapter=self.adapter,
855+
augmentations=augmentations,
856+
)
813857

814858
return self._fit(
815859
dataset, epochs, strategy="online", keep_optimizer=keep_optimizer, validation_data=validation_data, **kwargs

0 commit comments

Comments
 (0)