2424import torch
2525
2626from monai .config import DtypeLike , KeysCollection , NdarrayTensor
27+ from monai .data .utils import no_collation
2728from monai .transforms .inverse import InvertibleTransform
28- from monai .transforms .transform import MapTransform , Randomizable
29+ from monai .transforms .transform import MapTransform , Randomizable , RandomizableTransform
2930from monai .transforms .utility .array import (
3031 AddChannel ,
3132 AsChannelFirst ,
@@ -833,7 +834,7 @@ def __call__(self, data):
833834 return d
834835
835836
836- class Lambdad (MapTransform ):
837+ class Lambdad (MapTransform , InvertibleTransform ):
837838 """
838839 Dictionary-based wrapper of :py:class:`monai.transforms.Lambda`.
839840
@@ -852,51 +853,110 @@ class Lambdad(MapTransform):
852853 See also: :py:class:`monai.transforms.compose.MapTransform`
853854 func: Lambda/function to be applied. It also can be a sequence of Callable,
854855 each element corresponds to a key in ``keys``.
856+ inv_func: Lambda/function of inverse operation if want to invert transforms, default to `lambda x: x`.
857+ It also can be a sequence of Callable, each element corresponds to a key in ``keys``.
855858 overwrite: whether to overwrite the original data in the input dictionary with lamdbda function output.
856859 default to True. it also can be a sequence of bool, each element corresponds to a key in ``keys``.
857860 allow_missing_keys: don't raise exception if key is missing.
861+
862+ Note: The inverse operation doesn't allow to define `extra_info` or access other information, such as the
863+ image's original size. If need these complicated information, please write a new InvertibleTransform directly.
864+
858865 """
859866
860867 def __init__ (
861868 self ,
862869 keys : KeysCollection ,
863870 func : Union [Sequence [Callable ], Callable ],
871+ inv_func : Union [Sequence [Callable ], Callable ] = no_collation ,
864872 overwrite : Union [Sequence [bool ], bool ] = True ,
865873 allow_missing_keys : bool = False ,
866874 ) -> None :
867875 super ().__init__ (keys , allow_missing_keys )
868876 self .func = ensure_tuple_rep (func , len (self .keys ))
877+ self .inv_func = ensure_tuple_rep (inv_func , len (self .keys ))
869878 self .overwrite = ensure_tuple_rep (overwrite , len (self .keys ))
870879 self ._lambd = Lambda ()
871880
881+ def _transform (self , data : Any , func : Callable ):
882+ return self ._lambd (data , func = func )
883+
872884 def __call__ (self , data ):
873885 d = dict (data )
874886 for key , func , overwrite in self .key_iterator (d , self .func , self .overwrite ):
875- ret = self ._lambd (d [key ], func = func )
887+ ret = self ._transform (data = d [key ], func = func )
888+ if overwrite :
889+ d [key ] = ret
890+ self .push_transform (d , key )
891+ return d
892+
893+ def _inverse_transform (self , transform_info : Dict , data : Any , func : Callable ):
894+ return self ._lambd (data , func = func )
895+
896+ def inverse (self , data ):
897+ d = deepcopy (dict (data ))
898+ for key , inv_func , overwrite in self .key_iterator (d , self .inv_func , self .overwrite ):
899+ transform = self .get_most_recent_transform (d , key )
900+ ret = self ._inverse_transform (transform_info = transform , data = d [key ], func = inv_func )
876901 if overwrite :
877902 d [key ] = ret
903+ self .pop_transform (d , key )
878904 return d
879905
880906
881- class RandLambdad (Lambdad , Randomizable ):
907+ class RandLambdad (Lambdad , RandomizableTransform ):
882908 """
883- Randomizable version :py:class:`monai.transforms.Lambdad`, the input `func` contains random logic.
884- It's a randomizable transform so `CacheDataset` will not execute it and cache the results.
909+ Randomizable version :py:class:`monai.transforms.Lambdad`, the input `func` may contain random logic,
910+ or randomly execute the function based on `prob`. so `CacheDataset` will not execute it and cache the results.
885911
886912 Args:
887913 keys: keys of the corresponding items to be transformed.
888914 See also: :py:class:`monai.transforms.compose.MapTransform`
889915 func: Lambda/function to be applied. It also can be a sequence of Callable,
890916 each element corresponds to a key in ``keys``.
917+ inv_func: Lambda/function of inverse operation if want to invert transforms, default to `lambda x: x`.
918+ It also can be a sequence of Callable, each element corresponds to a key in ``keys``.
891919 overwrite: whether to overwrite the original data in the input dictionary with lamdbda function output.
892920 default to True. it also can be a sequence of bool, each element corresponds to a key in ``keys``.
921+ prob: probability of executing the random function, default to 1.0, with 100% probability to execute.
922+ note that all the data specified by `keys` will share the same random probability to execute or not.
923+ allow_missing_keys: don't raise exception if key is missing.
893924
894925 For more details, please check :py:class:`monai.transforms.Lambdad`.
895926
927+ Note: The inverse operation doesn't allow to define `extra_info` or access other information, such as the
928+ image's original size. If need these complicated information, please write a new InvertibleTransform directly.
929+
896930 """
897931
898- def randomize (self , data : Any ) -> None :
899- pass
932+ def __init__ (
933+ self ,
934+ keys : KeysCollection ,
935+ func : Union [Sequence [Callable ], Callable ],
936+ inv_func : Union [Sequence [Callable ], Callable ] = no_collation ,
937+ overwrite : Union [Sequence [bool ], bool ] = True ,
938+ prob : float = 1.0 ,
939+ allow_missing_keys : bool = False ,
940+ ) -> None :
941+ Lambdad .__init__ (
942+ self = self ,
943+ keys = keys ,
944+ func = func ,
945+ inv_func = inv_func ,
946+ overwrite = overwrite ,
947+ allow_missing_keys = allow_missing_keys ,
948+ )
949+ RandomizableTransform .__init__ (self = self , prob = prob , do_transform = True )
950+
951+ def _transform (self , data : Any , func : Callable ):
952+ return self ._lambd (data , func = func ) if self ._do_transform else data
953+
954+ def __call__ (self , data ):
955+ self .randomize (data )
956+ return super ().__call__ (data )
957+
958+ def _inverse_transform (self , transform_info : Dict , data : Any , func : Callable ):
959+ return self ._lambd (data , func = func ) if transform_info [InverseKeys .DO_TRANSFORM ] else data
900960
901961
902962class LabelToMaskd (MapTransform ):
0 commit comments