4949 RemoveRepeatedChannel ,
5050 RepeatChannel ,
5151 SimulateDelay ,
52- SplitChannel ,
52+ SplitDim ,
5353 SqueezeDim ,
5454 ToCupy ,
5555 ToDevice ,
6161)
6262from monai .transforms .utils import extreme_points_to_image , get_extreme_points
6363from monai .transforms .utils_pytorch_numpy_unification import concatenate
64- from monai .utils import convert_to_numpy , deprecated_arg , ensure_tuple , ensure_tuple_rep
64+ from monai .utils import convert_to_numpy , deprecated , deprecated_arg , ensure_tuple , ensure_tuple_rep
6565from monai .utils .enums import PostFix , TraceKeys , TransformBackends
6666from monai .utils .type_conversion import convert_to_dst_type
6767
150150 "SplitChannelD" ,
151151 "SplitChannelDict" ,
152152 "SplitChanneld" ,
153+ "SplitDimD" ,
154+ "SplitDimDict" ,
155+ "SplitDimd" ,
153156 "SqueezeDimD" ,
154157 "SqueezeDimDict" ,
155158 "SqueezeDimd" ,
@@ -372,19 +375,14 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
372375 return d
373376
374377
375- class SplitChanneld (MapTransform ):
376- """
377- Dictionary-based wrapper of :py:class:`monai.transforms.SplitChannel`.
378- All the input specified by `keys` should be split into same count of data.
379- """
380-
381- backend = SplitChannel .backend
382-
378+ class SplitDimd (MapTransform ):
383379 def __init__ (
384380 self ,
385381 keys : KeysCollection ,
386382 output_postfixes : Optional [Sequence [str ]] = None ,
387- channel_dim : int = 0 ,
383+ dim : int = 0 ,
384+ keepdim : bool = True ,
385+ update_meta : bool = True ,
388386 allow_missing_keys : bool = False ,
389387 ) -> None :
390388 """
@@ -395,13 +393,17 @@ def __init__(
395393 for example: if the key of input data is `pred` and split 2 classes, the output
396394 data keys will be: pred_(output_postfixes[0]), pred_(output_postfixes[1])
397395 if None, using the index number: `pred_0`, `pred_1`, ... `pred_N`.
398- channel_dim: which dimension of input image is the channel, default to 0.
396+ dim: which dimension of input image is the channel, default to 0.
397+ keepdim: if `True`, output will have singleton in the split dimension. If `False`, this
398+ dimension will be squeezed.
399+ update_meta: if `True`, copy `[key]_meta_dict` for each output and update affine to
400+ reflect the cropped image
399401 allow_missing_keys: don't raise exception if key is missing.
400-
401402 """
402403 super ().__init__ (keys , allow_missing_keys )
403404 self .output_postfixes = output_postfixes
404- self .splitter = SplitChannel (channel_dim = channel_dim )
405+ self .splitter = SplitDim (dim , keepdim )
406+ self .update_meta = update_meta
405407
406408 def __call__ (self , data : Mapping [Hashable , NdarrayOrTensor ]) -> Dict [Hashable , NdarrayOrTensor ]:
407409 d = dict (data )
@@ -415,9 +417,44 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
415417 if split_key in d :
416418 raise RuntimeError (f"input data already contains key { split_key } ." )
417419 d [split_key ] = r
420+
421+ if self .update_meta :
422+ orig_meta = d .get (PostFix .meta (key ), None )
423+ if orig_meta is not None :
424+ split_meta_key = PostFix .meta (split_key )
425+ d [split_meta_key ] = deepcopy (orig_meta )
426+ dim = self .splitter .dim
427+ if dim > 0 : # don't update affine if channel dim
428+ shift = np .eye (len (d [split_meta_key ]["affine" ])) # type: ignore
429+ shift [dim - 1 , - 1 ] = i # type: ignore
430+ d [split_meta_key ]["affine" ] = d [split_meta_key ]["affine" ] @ shift # type: ignore
431+
418432 return d
419433
420434
435+ @deprecated (since = "0.8" , msg_suffix = "please use `SplitDimd` instead." )
436+ class SplitChanneld (SplitDimd ):
437+ """
438+ Dictionary-based wrapper of :py:class:`monai.transforms.SplitChannel`.
439+ All the input specified by `keys` should be split into same count of data.
440+ """
441+
442+ def __init__ (
443+ self ,
444+ keys : KeysCollection ,
445+ output_postfixes : Optional [Sequence [str ]] = None ,
446+ channel_dim : int = 0 ,
447+ allow_missing_keys : bool = False ,
448+ ) -> None :
449+ super ().__init__ (
450+ keys ,
451+ output_postfixes = output_postfixes ,
452+ dim = channel_dim ,
453+ update_meta = False , # for backwards compatibility
454+ allow_missing_keys = allow_missing_keys ,
455+ )
456+
457+
421458class CastToTyped (MapTransform ):
422459 """
423460 Dictionary-based wrapper of :py:class:`monai.transforms.CastToType`.
@@ -1637,6 +1674,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
16371674RemoveRepeatedChannelD = RemoveRepeatedChannelDict = RemoveRepeatedChanneld
16381675RepeatChannelD = RepeatChannelDict = RepeatChanneld
16391676SplitChannelD = SplitChannelDict = SplitChanneld
1677+ SplitDimD = SplitDimDict = SplitDimd
16401678CastToTypeD = CastToTypeDict = CastToTyped
16411679ToTensorD = ToTensorDict = ToTensord
16421680EnsureTypeD = EnsureTypeDict = EnsureTyped
0 commit comments