5
5
6
6
7
7
class BatchedRand3DElasticd (MapTransform , RandomizableTransform ):
8
- """Batched 3D elastic deformation for biological structures."""
8
+ """Apply random 3D elastic deformation to input data.
9
+
10
+ Uses Gaussian-smoothed displacement fields to simulate deformation.
11
+
12
+ Parameters
13
+ ----------
14
+ keys : str or Iterable[str]
15
+ Keys of the corresponding items to be transformed.
16
+ sigma_range : tuple[float, float]
17
+ Range for random sigma values used in Gaussian smoothing.
18
+ magnitude_range : tuple[float, float]
19
+ Range for random displacement magnitude values.
20
+ spatial_size : tuple[int, int, int] or int or None, optional
21
+ Expected spatial size of input data.
22
+ prob : float, optional
23
+ Probability of applying the transform, by default 0.1.
24
+ mode : str, optional
25
+ Interpolation mode for grid sampling, by default "bilinear".
26
+ padding_mode : str, optional
27
+ Padding mode for grid sampling, by default "reflection".
28
+ allow_missing_keys : bool, optional
29
+ Whether to ignore missing keys, by default False.
30
+ """
9
31
10
32
def __init__ (
11
33
self ,
@@ -29,7 +51,6 @@ def __init__(
29
51
def _generate_elastic_field (
30
52
self , shape : torch .Size , device : torch .device
31
53
) -> Tensor :
32
- """Generate batched elastic deformation field."""
33
54
batch_size = shape [0 ]
34
55
spatial_dims = shape [2 :] # Skip batch and channel
35
56
@@ -76,6 +97,18 @@ def _generate_elastic_field(
76
97
return torch .stack (displacement_fields )
77
98
78
99
def __call__ (self , sample : dict [str , Tensor ]) -> dict [str , Tensor ]:
100
+ """Apply elastic deformation to sample data.
101
+
102
+ Parameters
103
+ ----------
104
+ sample : dict[str, Tensor]
105
+ Dictionary containing image tensors to transform.
106
+
107
+ Returns
108
+ -------
109
+ dict[str, Tensor]
110
+ Dictionary with transformed tensors.
111
+ """
79
112
self .randomize (None )
80
113
d = dict (sample )
81
114
0 commit comments