1414 Drop ,
1515 ExpandDims ,
1616 FilterTransform ,
17+ Group ,
1718 Keep ,
1819 Log ,
1920 MapTransform ,
2021 NumpyTransform ,
2122 OneHot ,
2223 Rename ,
2324 SerializableCustomTransform ,
25+ Squeeze ,
2426 Sqrt ,
2527 Standardize ,
2628 ToArray ,
2729 Transform ,
30+ Ungroup ,
31+ RandomSubsample ,
32+ Take ,
2833 NanToNum ,
2934)
3035from .transforms .filter_transform import Predicate
3136
3237
33- @serializable
38+ @serializable ( "bayesflow.adapters" )
3439class Adapter (MutableSequence [Transform ]):
3540 """
3641 Defines an adapter to apply various transforms to data.
@@ -599,6 +604,52 @@ def expand_dims(self, keys: str | Sequence[str], *, axis: int | tuple):
599604 self .transforms .append (transform )
600605 return self
601606
607+ def group (self , keys : Sequence [str ], into : str , * , prefix : str = "" ):
608+ """Append a :py:class:`~transforms.Group` transform to the adapter.
609+
610+ Groups the given variables as a dictionary in the key `into`. As most transforms do
611+ not support nested structures, this should usually be the last transform in the adapter.
612+
613+ Parameters
614+ ----------
615+ keys : Sequence of str
616+ The names of the variables to group together.
617+ into : str
618+ The name of the variable to store the grouped variables in.
619+ prefix : str, optional
620+ An optional common prefix of the variable names before grouping, which will be removed after grouping.
621+
622+ Raises
623+ ------
624+ ValueError
625+ If a prefix is specified, but a provided key does not start with the prefix.
626+ """
627+ if isinstance (keys , str ):
628+ keys = [keys ]
629+
630+ transform = Group (keys = keys , into = into , prefix = prefix )
631+ self .transforms .append (transform )
632+ return self
633+
634+ def ungroup (self , key : str , * , prefix : str = "" ):
635+ """Append an :py:class:`~transforms.Ungroup` transform to the adapter.
636+
637+ Ungroups the the variables in `key` from a dictionary into individual entries. Most transforms do
638+ not support nested structures, so this can be used to flatten a nested structure.
639+ The nesting can be re-established after the transforms using the :py:meth:`group` method.
640+
641+ Parameters
642+ ----------
643+ key : str
644+ The name of the variable to ungroup. The corresponding variable has to be a dictionary.
645+ prefix : str, optional
646+ An optional common prefix that will be added to the ungrouped variable names. This can be necessary
647+ to avoid duplicate names.
648+ """
649+ transform = Ungroup (key = key , prefix = prefix )
650+ self .transforms .append (transform )
651+ return self
652+
602653 def keep (self , keys : str | Sequence [str ]):
603654 """Append a :py:class:`~transforms.Keep` transform to the adapter.
604655
@@ -666,6 +717,28 @@ def one_hot(self, keys: str | Sequence[str], num_classes: int):
666717 self .transforms .append (transform )
667718 return self
668719
720+ def random_subsample (self , key : str , * , sample_size : int | float , axis : int = - 1 ):
721+ """
722+ Append a :py:class:`~transforms.RandomSubsample` transform to the adapter.
723+
724+ Parameters
725+ ----------
726+ key : str or Sequence of str
727+ The name of the variable to subsample.
728+ sample_size : int or float
729+ The number of samples to draw, or a fraction between 0 and 1 of the total number of samples to draw.
730+ axis: int, optional
731+ Which axis to draw samples over. The last axis is used by default.
732+ """
733+
734+ if not isinstance (key , str ):
735+ raise TypeError ("Can only subsample one batch entry at a time." )
736+
737+ transform = MapTransform ({key : RandomSubsample (sample_size = sample_size , axis = axis )})
738+
739+ self .transforms .append (transform )
740+ return self
741+
669742 def rename (self , from_key : str , to_key : str ):
670743 """Append a :py:class:`~transforms.Rename` transform to the adapter.
671744
@@ -709,6 +782,24 @@ def split(self, key: str, *, into: Sequence[str], indices_or_sections: int | Seq
709782
710783 return self
711784
785+ def squeeze (self , keys : str | Sequence [str ], * , axis : int | tuple ):
786+ """Append a :py:class:`~transforms.Squeeze` transform to the adapter.
787+
788+ Parameters
789+ ----------
790+ keys : str or Sequence of str
791+ The names of the variables to squeeze.
792+ axis : int or tuple
793+ The axis to squeeze. As the number of batch dimensions might change, we advise using negative
794+ numbers (i.e., indexing from the end instead of the start).
795+ """
796+ if isinstance (keys , str ):
797+ keys = [keys ]
798+
799+ transform = MapTransform ({key : Squeeze (axis = axis ) for key in keys })
800+ self .transforms .append (transform )
801+ return self
802+
712803 def sqrt (self , keys : str | Sequence [str ]):
713804 """Append an :py:class:`~transforms.Sqrt` transform to the adapter.
714805
@@ -742,7 +833,7 @@ def standardize(
742833 Names of variables to include in the transform.
743834 exclude : str or Sequence of str, optional
744835 Names of variables to exclude from the transform.
745- **kwargs : dict
836+ **kwargs :
746837 Additional keyword arguments passed to the transform.
747838 """
748839 transform = FilterTransform (
@@ -755,6 +846,42 @@ def standardize(
755846 self .transforms .append (transform )
756847 return self
757848
849+ def take (
850+ self ,
851+ include : str | Sequence [str ] = None ,
852+ * ,
853+ indices : Sequence [int ],
854+ axis : int = - 1 ,
855+ predicate : Predicate = None ,
856+ exclude : str | Sequence [str ] = None ,
857+ ):
858+ """
859+ Append a :py:class:`~transforms.Take` transform to the adapter.
860+
861+ Parameters
862+ ----------
863+ include : str or Sequence of str, optional
864+ Names of variables to include in the transform.
865+ indices : Sequence of int
866+ Which indices to take from the data.
867+ axis : int, optional
868+ Which axis to take from. The last axis is used by default.
869+ predicate : Predicate, optional
870+ Function that indicates which variables should be transformed.
871+ exclude : str or Sequence of str, optional
872+ Names of variables to exclude from the transform.
873+ """
874+ transform = FilterTransform (
875+ transform_constructor = Take ,
876+ predicate = predicate ,
877+ include = include ,
878+ exclude = exclude ,
879+ indices = indices ,
880+ axis = axis ,
881+ )
882+ self .transforms .append (transform )
883+ return self
884+
758885 def to_array (
759886 self ,
760887 include : str | Sequence [str ] = None ,
0 commit comments