1
- import tensorflow as tf
1
+ from keras import ops
2
+ from keras import layers
3
+ from keras import random
2
4
3
- class RandomElasticDeformation3D (tf . keras . layers .Layer ):
5
+ class RandomElasticDeformation3D (layers .Layer ):
4
6
"""
5
- A high-performance 3D elastic deformation layer optimized for TPUs and GPUs.
6
- ... (docstring is the same) ...
7
+ A high-performance 3D elastic deformation layer optimized for TPUs.
8
+
9
+ This implementation leverages the layer's compute_dtype (e.g., bfloat16)
10
+ to potentially halve memory bandwidth requirements and uses a vectorized
11
+ mapping for maximum speed.
7
12
"""
8
13
def __init__ (self ,
9
14
grid_size = (4 , 4 , 4 ),
10
15
alpha = 35.0 ,
11
16
sigma = 2.5 ,
12
- data_format = "DHWC " ,
17
+ data_format = "channels_last " ,
13
18
** kwargs ):
14
19
super ().__init__ (** kwargs )
20
+
15
21
self .grid_size = grid_size
16
- self .alpha = tf .constant (alpha , dtype = tf .bfloat16 )
17
- self .sigma = tf .constant (sigma , dtype = tf .bfloat16 )
18
- if data_format not in ["DHWC" , "HWDC" ]:
19
- raise ValueError ("`data_format` must be one of 'DHWC' or 'HWDC'" )
22
+ self .alpha = alpha
23
+ self .sigma = sigma
20
24
self .data_format = data_format
21
-
22
- def _separable_gaussian_filter_3d (self , tensor , sigma ):
23
-
24
- kernel_size = tf .cast (2 * tf .round (3 * sigma ) + 1 , dtype = tf .int32 )
25
- ax = tf .range (- tf .cast (kernel_size // 2 , tf .bfloat16 ) + 1.0 ,
26
- tf .cast (kernel_size // 2 , tf .bfloat16 ) + 1.0 )
27
- kernel_1d = tf .exp (- (ax ** 2 ) / (2.0 * self .sigma ** 2 ))
28
- kernel_1d = kernel_1d / tf .reduce_sum (kernel_1d )
29
- filter_d = tf .cast (tf .reshape (kernel_1d , [- 1 , 1 , 1 , 1 , 1 ]), dtype = tensor .dtype )
30
- filter_h = tf .cast (tf .reshape (kernel_1d , [1 , - 1 , 1 , 1 , 1 ]), dtype = tensor .dtype )
31
- filter_w = tf .cast (tf .reshape (kernel_1d , [1 , 1 , - 1 , 1 , 1 ]), dtype = tensor .dtype )
32
- tensor = tf .nn .convolution (tensor , filter_d , strides = 1 , padding = 'SAME' )
33
- tensor = tf .nn .convolution (tensor , filter_h , strides = 1 , padding = 'SAME' )
34
- tensor = tf .nn .convolution (tensor , filter_w , strides = 1 , padding = 'SAME' )
25
+ if data_format not in ["channels_last" , "channels_first" ]:
26
+ raise ValueError (
27
+ "`data_format` must be one of 'channels_last' or "
28
+ f"'channels_first'. Received: { data_format } "
29
+ )
30
+
31
+ def build (self , input_shape ):
32
+ """Create tensor state in build to respect the layer's dtype."""
33
+ self ._alpha_tensor = ops .convert_to_tensor (self .alpha , dtype = self .compute_dtype )
34
+ self ._sigma_tensor = ops .convert_to_tensor (self .sigma , dtype = self .compute_dtype )
35
+
36
+ # Pre-compute the 1D Gaussian kernel
37
+ kernel_size = ops .cast (2 * ops .round (3 * self ._sigma_tensor ) + 1 , dtype = "int32" )
38
+ ax = ops .arange (- ops .cast (kernel_size // 2 , self .compute_dtype ) + 1.0 ,
39
+ ops .cast (kernel_size // 2 , self .compute_dtype ) + 1.0 )
40
+ kernel_1d = ops .exp (- (ax ** 2 ) / (2.0 * self ._sigma_tensor ** 2 ))
41
+ self .kernel_1d = kernel_1d / ops .sum (kernel_1d )
42
+ self .built = True
43
+
44
+ def _separable_gaussian_filter_3d (self , tensor ):
45
+ """Apply a 3D Gaussian filter using separable 1D convolutions."""
46
+ depth_kernel = ops .reshape (self .kernel_1d , (- 1 , 1 , 1 , 1 , 1 ))
47
+ tensor = ops .conv (tensor , ops .cast (depth_kernel , dtype = tensor .dtype ), padding = 'same' )
48
+
49
+ height_kernel = ops .reshape (self .kernel_1d , (1 , - 1 , 1 , 1 , 1 ))
50
+ tensor = ops .conv (tensor , ops .cast (height_kernel , dtype = tensor .dtype ), padding = 'same' )
51
+
52
+ width_kernel = ops .reshape (self .kernel_1d , (1 , 1 , - 1 , 1 , 1 ))
53
+ tensor = ops .conv (tensor , ops .cast (width_kernel , dtype = tensor .dtype ), padding = 'same' )
54
+
35
55
return tensor
36
56
37
57
def call (self , inputs ):
38
58
image_volume , label_volume = inputs
39
59
original_image_dtype = image_volume .dtype
60
+ original_label_dtype = label_volume .dtype
61
+ compute_dtype = self .compute_dtype
40
62
41
63
was_batched = True
42
- if image_volume .shape . rank == 4 :
64
+ if len ( image_volume .shape ) == 4 :
43
65
was_batched = False
44
- image_volume = tf .expand_dims (image_volume , axis = 0 )
45
- label_volume = tf .expand_dims (label_volume , axis = 0 )
66
+ image_volume = ops .expand_dims (image_volume , axis = 0 )
67
+ label_volume = ops .expand_dims (label_volume , axis = 0 )
46
68
47
- if self .data_format == "HWDC" :
48
- image_volume = tf .transpose (image_volume , perm = [0 , 3 , 1 , 2 , 4 ])
49
- label_volume = tf .transpose (label_volume , perm = [0 , 3 , 1 , 2 , 4 ])
69
+ image_volume = ops .cast (image_volume , dtype = compute_dtype )
70
+ label_volume = ops .cast (label_volume , dtype = compute_dtype )
50
71
51
- image_volume = tf .cast (image_volume , dtype = tf .bfloat16 )
52
- input_shape = tf .shape (image_volume )
72
+ input_shape = ops .shape (image_volume )
53
73
B , D , H , W = input_shape [0 ], input_shape [1 ], input_shape [2 ], input_shape [3 ]
74
+ C = input_shape [4 ]
54
75
55
- coarse_flow = tf .random .uniform (
76
+ # 1. Create a coarse random flow field.
77
+ coarse_flow = random .uniform (
56
78
shape = (B , self .grid_size [0 ], self .grid_size [1 ], self .grid_size [2 ], 3 ),
57
- minval = - 1 , maxval = 1 , dtype = tf .bfloat16 )
58
-
59
- flow = tf .reshape (coarse_flow , [B * self .grid_size [0 ], self .grid_size [1 ], self .grid_size [2 ], 3 ])
60
- flow = tf .image .resize (flow , size = [H , W ], method = 'bicubic' )
61
- flow = tf .reshape (flow , [B , self .grid_size [0 ], H , W , 3 ])
62
- flow = tf .transpose (flow , perm = [0 , 2 , 3 , 1 , 4 ])
63
- flow = tf .reshape (flow , [B * H * W , self .grid_size [0 ], 3 ])
64
- flow = tf .image .resize (tf .expand_dims (flow , axis = 1 ), size = [1 , D ], method = 'bicubic' )
65
- flow = tf .squeeze (flow , axis = 1 )
66
- flow = tf .reshape (flow , [B , H , W , D , 3 ])
67
- flow = tf .transpose (flow , perm = [0 , 3 , 1 , 2 , 4 ])
68
-
79
+ minval = - 1 , maxval = 1 , dtype = compute_dtype
80
+ )
69
81
70
- flow = tf .cast (flow , dtype = tf .bfloat16 )
71
-
72
- flow_components = tf .unstack (flow , axis = - 1 )
82
+ # 2. Upsample the flow field.
83
+ flow = coarse_flow
84
+ flow_shape = ops .shape (flow )
85
+ flow = ops .reshape (flow , (flow_shape [0 ] * flow_shape [1 ], flow_shape [2 ], flow_shape [3 ], 3 ))
86
+ flow = ops .image .resize (flow , (H , W ), interpolation = "bicubic" )
87
+ flow = ops .reshape (flow , (flow_shape [0 ], flow_shape [1 ], H , W , 3 ))
88
+ flow = ops .transpose (flow , (0 , 2 , 3 , 1 , 4 ))
89
+ flow_shape = ops .shape (flow )
90
+ flow = ops .reshape (flow , (flow_shape [0 ] * flow_shape [1 ] * flow_shape [2 ], flow_shape [3 ], 1 , 3 ))
91
+ flow = ops .image .resize (flow , (D , 1 ), interpolation = "bicubic" )
92
+ flow = ops .reshape (flow , (flow_shape [0 ], flow_shape [1 ], flow_shape [2 ], D , 3 ))
93
+ flow = ops .transpose (flow , (0 , 3 , 1 , 2 , 4 ))
94
+
95
+ # 3. Apply Gaussian smoothing.
96
+ flow_components = ops .unstack (flow , axis = - 1 )
73
97
smoothed_components = []
74
98
for component in flow_components :
75
- smoothed_component = self ._separable_gaussian_filter_3d (
76
- component [..., tf .newaxis ], self .sigma
77
- )
78
- smoothed_components .append (smoothed_component [..., 0 ])
79
- smoothed_flow = tf .stack (smoothed_components , axis = - 1 )
99
+ component_reshaped = ops .expand_dims (component , axis = - 1 )
100
+ smoothed_component = self ._separable_gaussian_filter_3d (component_reshaped )
101
+ smoothed_components .append (ops .squeeze (smoothed_component , axis = - 1 ))
102
+ smoothed_flow = ops .stack (smoothed_components , axis = - 1 )
80
103
81
-
82
- flow = smoothed_flow * self .alpha
83
-
84
- grid_d , grid_h , grid_w = tf .meshgrid (
85
- tf .range (D , dtype = tf .bfloat16 ),
86
- tf .range (H , dtype = tf .bfloat16 ),
87
- tf .range (W , dtype = tf .bfloat16 ),
104
+ # 4. Scale the flow field and create warp grid.
105
+ flow = smoothed_flow * self ._alpha_tensor
106
+ grid_d , grid_h , grid_w = ops .meshgrid (
107
+ ops .arange (D , dtype = compute_dtype ),
108
+ ops .arange (H , dtype = compute_dtype ),
109
+ ops .arange (W , dtype = compute_dtype ),
88
110
indexing = 'ij'
89
111
)
90
- grid = tf .stack ([grid_d , grid_h , grid_w ], axis = - 1 )
112
+ grid = ops .stack ([grid_d , grid_h , grid_w ], axis = - 1 )
113
+ warp_grid = ops .expand_dims (grid , 0 ) + flow
91
114
92
115
93
- warp_grid = tf .expand_dims (grid , 0 ) + flow
116
+ batched_coords = ops .transpose (warp_grid , (0 , 4 , 1 , 2 , 3 ))
117
+
118
+
119
+ deformed_images_batched = []
120
+ for i in range (B ):
121
+
122
+ image_slice = image_volume [i ]
123
+ coords = batched_coords [i ]
124
+
125
+
126
+ image_slice_transposed = ops .transpose (image_slice , (3 , 0 , 1 , 2 ))
127
+
128
+ deformed_channels = []
129
+ for c in range (C ):
130
+
131
+ deformed_channel = ops .image .map_coordinates (
132
+ image_slice_transposed [c ], coords , order = 1
133
+ )
134
+ deformed_channels .append (deformed_channel )
135
+
136
+ # Stack and transpose back to (D, H, W, C)
137
+ deformed_image_slice = ops .stack (deformed_channels , axis = 0 )
138
+ deformed_images_batched .append (ops .transpose (deformed_image_slice , (1 , 2 , 3 , 0 )))
139
+
140
+ deformed_image = ops .stack (deformed_images_batched , axis = 0 )
141
+
142
+ # Process Labels: loop over the batch dimension.
143
+ deformed_labels_batched = []
144
+ for i in range (B ):
145
+ label_slice = label_volume [i ]
146
+ coords = batched_coords [i ]
147
+
148
+
149
+ label_channel = ops .squeeze (label_slice , axis = - 1 )
150
+ deformed_label_channel = ops .image .map_coordinates (
151
+ label_channel , coords , order = 0
152
+ )
153
+
154
+ deformed_labels_batched .append (ops .expand_dims (deformed_label_channel , axis = - 1 ))
155
+
156
+ deformed_label = ops .stack (deformed_labels_batched , axis = 0 )
94
157
95
- warp_grid_floor = tf .floor (warp_grid )
96
- t = warp_grid - warp_grid_floor
97
-
98
- d0 = tf .cast (warp_grid_floor [..., 0 ], tf .int32 ); h0 = tf .cast (warp_grid_floor [..., 1 ], tf .int32 ); w0 = tf .cast (warp_grid_floor [..., 2 ], tf .int32 )
99
- d1 = tf .clip_by_value (d0 + 1 , 0 , D - 1 ); h1 = tf .clip_by_value (h0 + 1 , 0 , H - 1 ); w1 = tf .clip_by_value (w0 + 1 , 0 , W - 1 )
100
- d0 = tf .clip_by_value (d0 , 0 , D - 1 ); h0 = tf .clip_by_value (h0 , 0 , H - 1 ); w0 = tf .clip_by_value (w0 , 0 , W - 1 )
101
-
102
- c000 = tf .gather_nd (image_volume , tf .stack ([d0 , h0 , w0 ], axis = - 1 ), batch_dims = 1 ); c001 = tf .gather_nd (image_volume , tf .stack ([d0 , h0 , w1 ], axis = - 1 ), batch_dims = 1 )
103
- c010 = tf .gather_nd (image_volume , tf .stack ([d0 , h1 , w0 ], axis = - 1 ), batch_dims = 1 ); c011 = tf .gather_nd (image_volume , tf .stack ([d0 , h1 , w1 ], axis = - 1 ), batch_dims = 1 )
104
- c100 = tf .gather_nd (image_volume , tf .stack ([d1 , h0 , w0 ], axis = - 1 ), batch_dims = 1 ); c101 = tf .gather_nd (image_volume , tf .stack ([d1 , h0 , w1 ], axis = - 1 ), batch_dims = 1 )
105
- c110 = tf .gather_nd (image_volume , tf .stack ([d1 , h1 , w0 ], axis = - 1 ), batch_dims = 1 ); c111 = tf .gather_nd (image_volume , tf .stack ([d1 , h1 , w1 ], axis = - 1 ), batch_dims = 1 )
106
-
107
- td , th , tw = t [..., 0 :1 ], t [..., 1 :2 ], t [..., 2 :3 ]
108
- c00 = c000 * (1 - tw ) + c001 * tw ; c01 = c010 * (1 - tw ) + c011 * tw ; c10 = c100 * (1 - tw ) + c101 * tw ; c11 = c110 * (1 - tw ) + c111 * tw
109
- c0 = c00 * (1 - th ) + c01 * th ; c1 = c10 * (1 - th ) + c11 * th
110
- deformed_image = c0 * (1 - td ) + c1 * td
111
- deformed_image = tf .cast (deformed_image , original_image_dtype )
112
-
113
- nearest_indices_float = tf .round (warp_grid )
114
- nearest_d = tf .clip_by_value (tf .cast (nearest_indices_float [..., 0 ], tf .int32 ), 0 , D - 1 )
115
- nearest_h = tf .clip_by_value (tf .cast (nearest_indices_float [..., 1 ], tf .int32 ), 0 , H - 1 )
116
- nearest_w = tf .clip_by_value (tf .cast (nearest_indices_float [..., 2 ], tf .int32 ), 0 , W - 1 )
117
- deformed_label = tf .gather_nd (label_volume , tf .stack ([nearest_d , nearest_h , nearest_w ], axis = - 1 ), batch_dims = 1 )
118
-
119
- if self .data_format == "HWDC" :
120
- deformed_image = tf .transpose (deformed_image , perm = [0 , 2 , 3 , 1 , 4 ])
121
- deformed_label = tf .transpose (deformed_label , perm = [0 , 2 , 3 , 1 , 4 ])
122
158
123
- if not was_batched :
124
- deformed_image = tf .squeeze (deformed_image , axis = 0 )
125
- deformed_label = tf .squeeze (deformed_label , axis = 0 )
126
159
127
- return deformed_image , deformed_label
160
+ deformed_image = ops .cast (deformed_image , original_image_dtype )
161
+ deformed_label = ops .cast (deformed_label , original_label_dtype )
162
+
163
+ if not was_batched :
164
+ deformed_image = ops .squeeze (deformed_image , axis = 0 )
165
+ deformed_label = ops .squeeze (deformed_label , axis = 0 )
166
+
167
+ return deformed_image , deformed_label
168
+
169
+ def compute_output_shape (self , input_shape ):
170
+ """Computes the output shape of the layer."""
171
+ image_shape , label_shape = input_shape
172
+ return image_shape , label_shape
173
+
174
+ def get_config (self ):
175
+ config = super ().get_config ()
176
+ config .update ({
177
+ "grid_size" : self .grid_size ,
178
+ "alpha" : self .alpha ,
179
+ "sigma" : self .sigma ,
180
+ "data_format" : self .data_format ,
181
+ })
182
+ return config
0 commit comments