1
+ # Add this import
2
+ from keras import backend
1
3
from keras import ops
2
4
from keras import layers
3
5
from keras import random
4
6
5
7
class RandomElasticDeformation3D (layers .Layer ):
6
8
"""
7
9
A high-performance 3D elastic deformation layer optimized for TPUs.
8
-
9
- This implementation leverages the layer's compute_dtype (e.g., bfloat16)
10
- to potentially halve memory bandwidth requirements and uses a vectorized
11
- mapping for maximum speed.
12
10
"""
13
11
def __init__ (self ,
14
12
grid_size = (4 , 4 , 4 ),
@@ -17,41 +15,29 @@ def __init__(self,
17
15
data_format = "channels_last" ,
18
16
** kwargs ):
19
17
super ().__init__ (** kwargs )
20
-
21
18
self .grid_size = grid_size
22
19
self .alpha = alpha
23
20
self .sigma = sigma
24
21
self .data_format = data_format
25
22
if data_format not in ["channels_last" , "channels_first" ]:
26
- raise ValueError (
27
- "`data_format` must be one of 'channels_last' or "
28
- f"'channels_first'. Received: { data_format } "
29
- )
30
-
23
+ raise ValueError (f"`data_format` must be one of 'channels_last' or 'channels_first'. Received: { data_format } " )
24
+
31
25
def build (self , input_shape ):
32
- """Create tensor state in build to respect the layer's dtype."""
33
26
self ._alpha_tensor = ops .convert_to_tensor (self .alpha , dtype = self .compute_dtype )
34
27
self ._sigma_tensor = ops .convert_to_tensor (self .sigma , dtype = self .compute_dtype )
35
-
36
- # Pre-compute the 1D Gaussian kernel
37
28
kernel_size = ops .cast (2 * ops .round (3 * self ._sigma_tensor ) + 1 , dtype = "int32" )
38
- ax = ops .arange (- ops .cast (kernel_size // 2 , self .compute_dtype ) + 1.0 ,
39
- ops .cast (kernel_size // 2 , self .compute_dtype ) + 1.0 )
29
+ ax = ops .arange (- ops .cast (kernel_size // 2 , self .compute_dtype ) + 1.0 , ops .cast (kernel_size // 2 , self .compute_dtype ) + 1.0 )
40
30
kernel_1d = ops .exp (- (ax ** 2 ) / (2.0 * self ._sigma_tensor ** 2 ))
41
31
self .kernel_1d = kernel_1d / ops .sum (kernel_1d )
42
32
self .built = True
43
33
44
34
def _separable_gaussian_filter_3d (self , tensor ):
45
- """Apply a 3D Gaussian filter using separable 1D convolutions."""
46
35
depth_kernel = ops .reshape (self .kernel_1d , (- 1 , 1 , 1 , 1 , 1 ))
47
36
tensor = ops .conv (tensor , ops .cast (depth_kernel , dtype = tensor .dtype ), padding = 'same' )
48
-
49
37
height_kernel = ops .reshape (self .kernel_1d , (1 , - 1 , 1 , 1 , 1 ))
50
38
tensor = ops .conv (tensor , ops .cast (height_kernel , dtype = tensor .dtype ), padding = 'same' )
51
-
52
39
width_kernel = ops .reshape (self .kernel_1d , (1 , 1 , - 1 , 1 , 1 ))
53
40
tensor = ops .conv (tensor , ops .cast (width_kernel , dtype = tensor .dtype ), padding = 'same' )
54
-
55
41
return tensor
56
42
57
43
def call (self , inputs ):
@@ -70,16 +56,10 @@ def call(self, inputs):
70
56
label_volume = ops .cast (label_volume , dtype = compute_dtype )
71
57
72
58
input_shape = ops .shape (image_volume )
73
- B , D , H , W = input_shape [0 ], input_shape [1 ], input_shape [2 ], input_shape [3 ]
74
- C = input_shape [4 ]
75
-
76
- # 1. Create a coarse random flow field.
77
- coarse_flow = random .uniform (
78
- shape = (B , self .grid_size [0 ], self .grid_size [1 ], self .grid_size [2 ], 3 ),
79
- minval = - 1 , maxval = 1 , dtype = compute_dtype
80
- )
81
-
82
- # 2. Upsample the flow field.
59
+ B , D , H , W , C = input_shape [0 ], input_shape [1 ], input_shape [2 ], input_shape [3 ], input_shape [4 ]
60
+
61
+ coarse_flow = random .uniform (shape = (B , self .grid_size [0 ], self .grid_size [1 ], self .grid_size [2 ], 3 ), minval = - 1 , maxval = 1 , dtype = compute_dtype )
62
+
83
63
flow = coarse_flow
84
64
flow_shape = ops .shape (flow )
85
65
flow = ops .reshape (flow , (flow_shape [0 ] * flow_shape [1 ], flow_shape [2 ], flow_shape [3 ], 3 ))
@@ -91,71 +71,49 @@ def call(self, inputs):
91
71
flow = ops .image .resize (flow , (D , 1 ), interpolation = "bicubic" )
92
72
flow = ops .reshape (flow , (flow_shape [0 ], flow_shape [1 ], flow_shape [2 ], D , 3 ))
93
73
flow = ops .transpose (flow , (0 , 3 , 1 , 2 , 4 ))
94
-
95
- # 3. Apply Gaussian smoothing.
74
+
96
75
flow_components = ops .unstack (flow , axis = - 1 )
97
76
smoothed_components = []
98
77
for component in flow_components :
99
- component_reshaped = ops .expand_dims (component , axis = - 1 )
100
- smoothed_component = self ._separable_gaussian_filter_3d (component_reshaped )
101
- smoothed_components .append (ops .squeeze (smoothed_component , axis = - 1 ))
78
+ smoothed_components .append (ops .squeeze (self ._separable_gaussian_filter_3d (ops .expand_dims (component , axis = - 1 )), axis = - 1 ))
102
79
smoothed_flow = ops .stack (smoothed_components , axis = - 1 )
103
80
104
- # 4. Scale the flow field and create warp grid.
105
81
flow = smoothed_flow * self ._alpha_tensor
106
- grid_d , grid_h , grid_w = ops .meshgrid (
107
- ops .arange (D , dtype = compute_dtype ),
108
- ops .arange (H , dtype = compute_dtype ),
109
- ops .arange (W , dtype = compute_dtype ),
110
- indexing = 'ij'
111
- )
82
+ grid_d , grid_h , grid_w = ops .meshgrid (ops .arange (D , dtype = compute_dtype ), ops .arange (H , dtype = compute_dtype ), ops .arange (W , dtype = compute_dtype ), indexing = 'ij' )
112
83
grid = ops .stack ([grid_d , grid_h , grid_w ], axis = - 1 )
113
84
warp_grid = ops .expand_dims (grid , 0 ) + flow
114
85
115
-
116
86
batched_coords = ops .transpose (warp_grid , (0 , 4 , 1 , 2 , 3 ))
117
87
118
-
119
- deformed_images_batched = []
120
- for i in range (B ):
121
-
122
- image_slice = image_volume [i ]
123
- coords = batched_coords [i ]
124
-
125
-
126
- image_slice_transposed = ops .transpose (image_slice , (3 , 0 , 1 , 2 ))
127
-
88
+ def perform_map (elems ):
89
+ image_slice , label_slice , coords = elems
128
90
deformed_channels = []
91
+ image_slice_transposed = ops .transpose (image_slice , (3 , 0 , 1 , 2 ))
92
+ # The channel dimension C is a static value when the graph is built
129
93
for c in range (C ):
130
-
131
- deformed_channel = ops .image .map_coordinates (
132
- image_slice_transposed [c ], coords , order = 1
133
- )
134
- deformed_channels .append (deformed_channel )
135
-
136
- # Stack and transpose back to (D, H, W, C)
94
+ deformed_channels .append (ops .image .map_coordinates (image_slice_transposed [c ], coords , order = 1 ))
137
95
deformed_image_slice = ops .stack (deformed_channels , axis = 0 )
138
- deformed_images_batched .append (ops .transpose (deformed_image_slice , (1 , 2 , 3 , 0 )))
139
-
140
- deformed_image = ops .stack (deformed_images_batched , axis = 0 )
141
-
142
- # Process Labels: loop over the batch dimension.
143
- deformed_labels_batched = []
144
- for i in range (B ):
145
- label_slice = label_volume [i ]
146
- coords = batched_coords [i ]
147
-
148
-
96
+ deformed_image_slice = ops .transpose (deformed_image_slice , (1 , 2 , 3 , 0 ))
149
97
label_channel = ops .squeeze (label_slice , axis = - 1 )
150
- deformed_label_channel = ops .image .map_coordinates (
151
- label_channel , coords , order = 0
152
- )
153
-
154
- deformed_labels_batched .append (ops .expand_dims (deformed_label_channel , axis = - 1 ))
155
-
156
- deformed_label = ops .stack (deformed_labels_batched , axis = 0 )
157
-
158
-
98
+ deformed_label_channel = ops .image .map_coordinates (label_channel , coords , order = 0 )
99
+ deformed_label_slice = ops .expand_dims (deformed_label_channel , axis = - 1 )
100
+ return deformed_image_slice , deformed_label_slice
101
+
102
+ if backend .backend () == "tensorflow" :
103
+ import tensorflow as tf
104
+ deformed_image , deformed_label = tf .map_fn (perform_map , elems = (image_volume , label_volume , batched_coords ), dtype = (compute_dtype , compute_dtype ))
105
+ elif backend .backend () == "jax" :
106
+ import jax
107
+ deformed_image , deformed_label = jax .lax .map (perform_map , xs = (image_volume , label_volume , batched_coords ))
108
+ else :
109
+ deformed_images_list = []
110
+ deformed_labels_list = []
111
+ for i in range (B ):
112
+ img_slice , lbl_slice = perform_map ((image_volume [i ], label_volume [i ], batched_coords [i ]))
113
+ deformed_images_list .append (img_slice )
114
+ deformed_labels_list .append (lbl_slice )
115
+ deformed_image = ops .stack (deformed_images_list , axis = 0 )
116
+ deformed_label = ops .stack (deformed_labels_list , axis = 0 )
159
117
160
118
deformed_image = ops .cast (deformed_image , original_image_dtype )
161
119
deformed_label = ops .cast (deformed_label , original_label_dtype )
@@ -167,16 +125,10 @@ def call(self, inputs):
167
125
return deformed_image , deformed_label
168
126
169
127
def compute_output_shape (self , input_shape ):
170
- """Computes the output shape of the layer."""
171
128
image_shape , label_shape = input_shape
172
129
return image_shape , label_shape
173
130
174
131
def get_config (self ):
175
132
config = super ().get_config ()
176
- config .update ({
177
- "grid_size" : self .grid_size ,
178
- "alpha" : self .alpha ,
179
- "sigma" : self .sigma ,
180
- "data_format" : self .data_format ,
181
- })
133
+ config .update ({"grid_size" : self .grid_size , "alpha" : self .alpha , "sigma" : self .sigma , "data_format" : self .data_format })
182
134
return config
0 commit comments