@@ -674,6 +674,7 @@ def fit_offline(
674674 batch_size : int = 32 ,
675675 keep_optimizer : bool = False ,
676676 validation_data : Mapping [str , np .ndarray ] | int = None ,
677+ augmentations : Mapping [str , Callable ] | Callable = None ,
677678 ** kwargs ,
678679 ) -> keras .callbacks .History :
679680 """
@@ -698,6 +699,16 @@ def fit_offline(
698699 A dictionary containing validation data. If an integer is provided,
699700 that number of validation samples will be generated (if supported).
700701 By default, no validation data is used.
702+ augmentations : dict of str to Callable or Callable, optional
703+ Dictionary of augmentation functions to apply to each corresponding key in the batch
704+ or a function to apply to the entire batch (possibly adding new keys).
705+
706+ If you provide a dictionary of functions, each function should accept one element
707+ of your output batch and return the corresponding transformed element. Otherwise,
708+ your function should accept the entire dictionary output and return a dictionary.
709+
710+ Note - augmentations are applied before the adapter is called and are generally
711+ transforms that you only want to apply during training.
701712 **kwargs : dict, optional
702713 Additional keyword arguments passed to the underlying `_fit` method.
703714
@@ -709,7 +720,7 @@ def fit_offline(
709720 metric evolution over epochs.
710721 """
711722
712- dataset = OfflineDataset (data = data , batch_size = batch_size , adapter = self .adapter )
723+ dataset = OfflineDataset (data = data , batch_size = batch_size , adapter = self .adapter , augmentations = augmentations )
713724
714725 return self ._fit (
715726 dataset , epochs , strategy = "online" , keep_optimizer = keep_optimizer , validation_data = validation_data , ** kwargs
@@ -722,6 +733,7 @@ def fit_online(
722733 batch_size : int = 32 ,
723734 keep_optimizer : bool = False ,
724735 validation_data : Mapping [str , np .ndarray ] | int = None ,
736+ augmentations : Mapping [str , Callable ] | Callable = None ,
725737 ** kwargs ,
726738 ) -> keras .callbacks .History :
727739 """
@@ -743,6 +755,16 @@ def fit_online(
743755 A dictionary containing validation data. If an integer is provided,
744756 that number of validation samples will be generated (if supported).
745757 By default, no validation data is used.
758+ augmentations : dict of str to Callable or Callable, optional
759+ Dictionary of augmentation functions to apply to each corresponding key in the batch
760+ or a function to apply to the entire batch (possibly adding new keys).
761+
762+ If you provide a dictionary of functions, each function should accept one element
763+ of your output batch and return the corresponding transformed element. Otherwise,
764+ your function should accept the entire dictionary output and return a dictionary.
765+
766+ Note - augmentations are applied before the adapter is called and are generally
767+ transforms that you only want to apply during training.
746768 **kwargs : dict, optional
747769 Additional keyword arguments passed to the underlying `_fit` method.
748770
@@ -755,7 +777,11 @@ def fit_online(
755777 """
756778
757779 dataset = OnlineDataset (
758- simulator = self .simulator , batch_size = batch_size , num_batches = num_batches_per_epoch , adapter = self .adapter
780+ simulator = self .simulator ,
781+ batch_size = batch_size ,
782+ num_batches = num_batches_per_epoch ,
783+ adapter = self .adapter ,
784+ augmentations = augmentations ,
759785 )
760786
761787 return self ._fit (
@@ -771,6 +797,7 @@ def fit_disk(
771797 epochs : int = 100 ,
772798 keep_optimizer : bool = False ,
773799 validation_data : Mapping [str , np .ndarray ] | int = None ,
800+ augmentations : Mapping [str , Callable ] | Callable = None ,
774801 ** kwargs ,
775802 ) -> keras .callbacks .History :
776803 """
@@ -798,6 +825,16 @@ def fit_disk(
798825 A dictionary containing validation data. If an integer is provided,
799826 that number of validation samples will be generated (if supported).
800827 By default, no validation data is used.
828+ augmentations : dict of str to Callable or Callable, optional
829+ Dictionary of augmentation functions to apply to each corresponding key in the batch
830+ or a function to apply to the entire batch (possibly adding new keys).
831+
832+ If you provide a dictionary of functions, each function should accept one element
833+ of your output batch and return the corresponding transformed element. Otherwise,
834+ your function should accept the entire dictionary output and return a dictionary.
835+
836+ Note - augmentations are applied before the adapter is called and are generally
837+ transforms that you only want to apply during training.
801838 **kwargs : dict, optional
802839 Additional keyword arguments passed to the underlying `_fit` method.
803840
@@ -809,7 +846,14 @@ def fit_disk(
809846 metric evolution over epochs.
810847 """
811848
812- dataset = DiskDataset (root = root , pattern = pattern , batch_size = batch_size , load_fn = load_fn , adapter = self .adapter )
849+ dataset = DiskDataset (
850+ root = root ,
851+ pattern = pattern ,
852+ batch_size = batch_size ,
853+ load_fn = load_fn ,
854+ adapter = self .adapter ,
855+ augmentations = augmentations ,
856+ )
813857
814858 return self ._fit (
815859 dataset , epochs , strategy = "online" , keep_optimizer = keep_optimizer , validation_data = validation_data , ** kwargs
0 commit comments