1- from collections .abc import MutableSequence , Sequence
1+ from collections .abc import MutableSequence , Sequence , Mapping
22
33import numpy as np
4+
45from keras .saving import (
56 deserialize_keras_object as deserialize ,
67 register_keras_serializable as serializable ,
@@ -121,16 +122,16 @@ def inverse(self, data: dict[str, np.ndarray], **kwargs) -> dict[str, any]:
121122
122123 return data
123124
124- def __call__ (self , data : dict [str , any ], * , inverse : bool = False , ** kwargs ) -> dict [str , np .ndarray ]:
125+ def __call__ (self , data : Mapping [str , any ], * , inverse : bool = False , ** kwargs ) -> dict [str , np .ndarray ]:
125126 """Apply the transforms in the given direction.
126127
127128 Parameters
128129 ----------
129- data : dict
130+ data : Mapping[str, any]
130131 The data to be transformed.
131132 inverse : bool, optional
132133 If False, apply the forward transform, else apply the inverse transform (default False).
133- **kwargs : dict
134+ **kwargs
134135 Additional keyword arguments passed to each transform.
135136
136137 Returns
@@ -233,28 +234,25 @@ def __len__(self):
233234
234235 def apply (
235236 self ,
237+ include : str | Sequence [str ] = None ,
236238 * ,
237239 forward : np .ufunc | str ,
238240 inverse : np .ufunc | str = None ,
239241 predicate : Predicate = None ,
240- include : str | Sequence [str ] = None ,
241242 exclude : str | Sequence [str ] = None ,
242243 ** kwargs ,
243244 ):
244245 """Append a :py:class:`~transforms.NumpyTransform` to the adapter.
245246
246247 Parameters
247248 ----------
248- forward: callable, no lambda
249- Function to transform the data in the forward pass.
250- For the adapter to be serializable, this function has to be serializable
251- as well (see Notes). Therefore, only proper functions and no lambda
252- functions should be used here.
253- inverse: callable, no lambda
254- Function to transform the data in the inverse pass.
255- For the adapter to be serializable, this function has to be serializable
256- as well (see Notes). Therefore, only proper functions and no lambda
257- functions should be used here.
249+ forward : str or np.ufunc
250+ The name of the NumPy function to use for the forward transformation.
251+ inverse : str or np.ufunc, optional
252+ The name of the NumPy function to use for the inverse transformation.
253+ By default, the inverse is inferred from the forward argument for supported methods.
254+ You can find the supported methods in
255+ :py:const:`~bayesflow.adapters.transforms.NumpyTransform.INVERSE_METHODS`.
258256 predicate : Predicate, optional
259257 Function that indicates which variables should be transformed.
260258 include : str or Sequence of str, optional
@@ -263,12 +261,6 @@ def apply(
263261 Names of variables to exclude from the transform.
264262 **kwargs : dict
265263 Additional keyword arguments passed to the transform.
266-
267- Notes
268- -----
269- Important: This is only serializable if the forward and inverse functions are serializable.
270- This most likely means you will have to pass the scope that the forward and inverse functions are contained in
271- to the `custom_objects` argument of the `deserialize` function when deserializing this class.
272264 """
273265 transform = FilterTransform (
274266 transform_constructor = NumpyTransform ,
@@ -388,6 +380,7 @@ def convert_dtype(
388380 exclude : str | Sequence [str ] = None ,
389381 ):
390382 """Append a :py:class:`~transforms.ConvertDType` transform to the adapter.
383+ See also :py:meth:`~bayesflow.adapters.Adapter.map_dtype`.
391384
392385 Parameters
393386 ----------
@@ -525,6 +518,24 @@ def log(self, keys: str | Sequence[str], *, p1: bool = False):
525518 self .transforms .append (transform )
526519 return self
527520
521+ def map_dtype (self , keys : str | Sequence [str ], to_dtype : str ):
522+ """Append a :py:class:`~transforms.ConvertDType` transform to the adapter.
523+ See also :py:meth:`~bayesflow.adapters.Adapter.convert_dtype`.
524+
525+ Parameters
526+ ----------
527+ keys : str or Sequence of str
528+ The names of the variables to transform.
529+ to_dtype : str
530+ Target dtype
531+ """
532+ if isinstance (keys , str ):
533+ keys = [keys ]
534+
535+ transform = MapTransform ({key : ConvertDType (to_dtype ) for key in keys })
536+ self .transforms .append (transform )
537+ return self
538+
528539 def one_hot (self , keys : str | Sequence [str ], num_classes : int ):
529540 """Append a :py:class:`~transforms.OneHot` transform to the adapter.
530541
@@ -555,6 +566,24 @@ def rename(self, from_key: str, to_key: str):
555566 self .transforms .append (Rename (from_key , to_key ))
556567 return self
557568
569+ def scale (self , keys : str | Sequence [str ], by : float | np .ndarray ):
570+ from .transforms import Scale
571+
572+ if isinstance (keys , str ):
573+ keys = [keys ]
574+
575+ self .transforms .append (MapTransform ({key : Scale (scale = by ) for key in keys }))
576+ return self
577+
578+ def shift (self , keys : str | Sequence [str ], by : float | np .ndarray ):
579+ from .transforms import Shift
580+
581+ if isinstance (keys , str ):
582+ keys = [keys ]
583+
584+ self .transforms .append (MapTransform ({key : Shift (shift = by ) for key in keys }))
585+ return self
586+
558587 def sqrt (self , keys : str | Sequence [str ]):
559588 """Append an :py:class:`~transforms.Sqrt` transform to the adapter.
560589
@@ -572,9 +601,9 @@ def sqrt(self, keys: str | Sequence[str]):
572601
573602 def standardize (
574603 self ,
604+ include : str | Sequence [str ] = None ,
575605 * ,
576606 predicate : Predicate = None ,
577- include : str | Sequence [str ] = None ,
578607 exclude : str | Sequence [str ] = None ,
579608 ** kwargs ,
580609 ):
@@ -603,9 +632,9 @@ def standardize(
603632
604633 def to_array (
605634 self ,
635+ include : str | Sequence [str ] = None ,
606636 * ,
607637 predicate : Predicate = None ,
608- include : str | Sequence [str ] = None ,
609638 exclude : str | Sequence [str ] = None ,
610639 ** kwargs ,
611640 ):
0 commit comments