2525 Standardize ,
2626 ToArray ,
2727 Transform ,
28+ RandomSubsample ,
29+ Take ,
2830)
2931from .transforms .filter_transform import Predicate
3032
3133
32- @serializable
34+ @serializable ( "bayesflow.adapters" )
3335class Adapter (MutableSequence [Transform ]):
3436 """
3537 Defines an adapter to apply various transforms to data.
@@ -79,7 +81,9 @@ def get_config(self) -> dict:
7981
8082 return serialize (config )
8183
82- def forward (self , data : dict [str , any ], * , stage : str = "inference" , ** kwargs ) -> dict [str , np .ndarray ]:
84+ def forward (
85+ self , data : dict [str , any ], * , stage : str = "inference" , log_det_jac : bool = False , ** kwargs
86+ ) -> dict [str , np .ndarray ] | tuple [dict [str , np .ndarray ], dict [str , np .ndarray ]]:
8387 """Apply the transforms in the forward direction.
8488
8589 Parameters
@@ -88,22 +92,33 @@ def forward(self, data: dict[str, any], *, stage: str = "inference", **kwargs) -
8892 The data to be transformed.
8993 stage : str, one of ["training", "validation", "inference"]
9094 The stage the function is called in.
95+ log_det_jac: bool, optional
96+ Whether to return the log determinant of the Jacobian of the transforms.
9197 **kwargs : dict
9298 Additional keyword arguments passed to each transform.
9399
94100 Returns
95101 -------
96- dict
97- The transformed data.
102+ dict | tuple[dict, dict]
103+ The transformed data or tuple of transformed data and log determinant of the Jacobian .
98104 """
99105 data = data .copy ()
106+ if not log_det_jac :
107+ for transform in self .transforms :
108+ data = transform (data , stage = stage , ** kwargs )
109+ return data
100110
111+ log_det_jac = {}
101112 for transform in self .transforms :
102- data = transform (data , stage = stage , ** kwargs )
113+ transformed_data = transform (data , stage = stage , ** kwargs )
114+ log_det_jac = transform .log_det_jac (data , log_det_jac , ** kwargs )
115+ data = transformed_data
103116
104- return data
117+ return data , log_det_jac
105118
106- def inverse (self , data : dict [str , np .ndarray ], * , stage : str = "inference" , ** kwargs ) -> dict [str , any ]:
119+ def inverse (
120+ self , data : dict [str , np .ndarray ], * , stage : str = "inference" , log_det_jac : bool = False , ** kwargs
121+ ) -> dict [str , np .ndarray ] | tuple [dict [str , np .ndarray ], dict [str , np .ndarray ]]:
107122 """Apply the transforms in the inverse direction.
108123
109124 Parameters
@@ -112,24 +127,32 @@ def inverse(self, data: dict[str, np.ndarray], *, stage: str = "inference", **kw
112127 The data to be transformed.
113128 stage : str, one of ["training", "validation", "inference"]
114129 The stage the function is called in.
130+ log_det_jac: bool, optional
131+ Whether to return the log determinant of the Jacobian of the transforms.
115132 **kwargs : dict
116133 Additional keyword arguments passed to each transform.
117134
118135 Returns
119136 -------
120- dict
121- The transformed data.
137+ dict | tuple[dict, dict]
138+ The transformed data or tuple of transformed data and log determinant of the Jacobian .
122139 """
123140 data = data .copy ()
141+ if not log_det_jac :
142+ for transform in reversed (self .transforms ):
143+ data = transform (data , stage = stage , inverse = True , ** kwargs )
144+ return data
124145
146+ log_det_jac = {}
125147 for transform in reversed (self .transforms ):
126148 data = transform (data , stage = stage , inverse = True , ** kwargs )
149+ log_det_jac = transform .log_det_jac (data , log_det_jac , inverse = True , ** kwargs )
127150
128- return data
151+ return data , log_det_jac
129152
130153 def __call__ (
131154 self , data : Mapping [str , any ], * , inverse : bool = False , stage = "inference" , ** kwargs
132- ) -> dict [str , np .ndarray ]:
155+ ) -> dict [str , np .ndarray ] | tuple [ dict [ str , np . ndarray ], dict [ str , np . ndarray ]] :
133156 """Apply the transforms in the given direction.
134157
135158 Parameters
@@ -145,8 +168,8 @@ def __call__(
145168
146169 Returns
147170 -------
148- dict
149- The transformed data.
171+ dict | tuple[dict, dict]
172+ The transformed data or tuple of transformed data and log determinant of the Jacobian .
150173 """
151174 if inverse :
152175 return self .inverse (data , stage = stage , ** kwargs )
@@ -644,6 +667,28 @@ def one_hot(self, keys: str | Sequence[str], num_classes: int):
644667 self .transforms .append (transform )
645668 return self
646669
670+ def random_subsample (self , key : str , * , sample_size : int | float , axis : int = - 1 ):
671+ """
672+ Append a :py:class:`~transforms.RandomSubsample` transform to the adapter.
673+
674+ Parameters
675+ ----------
676+ key : str or Sequence of str
677+ The name of the variable to subsample.
678+ sample_size : int or float
679+ The number of samples to draw, or a fraction between 0 and 1 of the total number of samples to draw.
680+ axis: int, optional
681+ Which axis to draw samples over. The last axis is used by default.
682+ """
683+
684+ if not isinstance (key , str ):
685+ raise TypeError ("Can only subsample one batch entry at a time." )
686+
687+ transform = MapTransform ({key : RandomSubsample (sample_size = sample_size , axis = axis )})
688+
689+ self .transforms .append (transform )
690+ return self
691+
647692 def rename (self , from_key : str , to_key : str ):
648693 """Append a :py:class:`~transforms.Rename` transform to the adapter.
649694
@@ -720,7 +765,7 @@ def standardize(
720765 Names of variables to include in the transform.
721766 exclude : str or Sequence of str, optional
722767 Names of variables to exclude from the transform.
723- **kwargs : dict
768+ **kwargs :
724769 Additional keyword arguments passed to the transform.
725770 """
726771 transform = FilterTransform (
@@ -733,6 +778,42 @@ def standardize(
733778 self .transforms .append (transform )
734779 return self
735780
781+ def take (
782+ self ,
783+ include : str | Sequence [str ] = None ,
784+ * ,
785+ indices : Sequence [int ],
786+ axis : int = - 1 ,
787+ predicate : Predicate = None ,
788+ exclude : str | Sequence [str ] = None ,
789+ ):
790+ """
791+ Append a :py:class:`~transforms.Take` transform to the adapter.
792+
793+ Parameters
794+ ----------
795+ include : str or Sequence of str, optional
796+ Names of variables to include in the transform.
797+ indices : Sequence of int
798+ Which indices to take from the data.
799+ axis : int, optional
800+ Which axis to take from. The last axis is used by default.
801+ predicate : Predicate, optional
802+ Function that indicates which variables should be transformed.
803+ exclude : str or Sequence of str, optional
804+ Names of variables to exclude from the transform.
805+ """
806+ transform = FilterTransform (
807+ transform_constructor = Take ,
808+ predicate = predicate ,
809+ include = include ,
810+ exclude = exclude ,
811+ indices = indices ,
812+ axis = axis ,
813+ )
814+ self .transforms .append (transform )
815+ return self
816+
736817 def to_array (
737818 self ,
738819 include : str | Sequence [str ] = None ,
0 commit comments