1
1
# Add this import
2
2
from keras import backend
3
- from keras import ops
4
3
from keras import layers
4
+ from keras import ops
5
5
from keras import random
6
6
7
+
7
8
class RandomElasticDeformation3D (layers .Layer ):
8
9
"""
9
10
A high-performance 3D elastic deformation layer optimized for TPUs.
10
11
"""
11
12
12
- def __init__ (self ,
13
- grid_size = (4 , 4 , 4 ),
14
- alpha = 35.0 ,
15
- sigma = 2.5 ,
16
- data_format = "channels_last" ,
17
- ** kwargs ):
13
+ def __init__ (
14
+ self ,
15
+ grid_size = (4 , 4 , 4 ),
16
+ alpha = 35.0 ,
17
+ sigma = 2.5 ,
18
+ data_format = "channels_last" ,
19
+ ** kwargs ,
20
+ ):
18
21
super ().__init__ (** kwargs )
19
22
self .grid_size = grid_size
20
23
self .alpha = alpha
@@ -28,21 +31,36 @@ def __init__(self,
28
31
raise ValueError (message )
29
32
30
33
def build (self , input_shape ):
31
- self ._alpha_tensor = ops .convert_to_tensor (self .alpha , dtype = self .compute_dtype )
32
- self ._sigma_tensor = ops .convert_to_tensor (self .sigma , dtype = self .compute_dtype )
33
- kernel_size = ops .cast (2 * ops .round (3 * self ._sigma_tensor ) + 1 , dtype = "int32" )
34
- ax = ops .arange (- ops .cast (kernel_size // 2 , self .compute_dtype ) + 1.0 , ops .cast (kernel_size // 2 , self .compute_dtype ) + 1.0 )
34
+ self ._alpha_tensor = ops .convert_to_tensor (
35
+ self .alpha , dtype = self .compute_dtype
36
+ )
37
+ self ._sigma_tensor = ops .convert_to_tensor (
38
+ self .sigma , dtype = self .compute_dtype
39
+ )
40
+ kernel_size = ops .cast (
41
+ 2 * ops .round (3 * self ._sigma_tensor ) + 1 , dtype = "int32"
42
+ )
43
+ ax = ops .arange (
44
+ - ops .cast (kernel_size // 2 , self .compute_dtype ) + 1.0 ,
45
+ ops .cast (kernel_size // 2 , self .compute_dtype ) + 1.0 ,
46
+ )
35
47
kernel_1d = ops .exp (- (ax ** 2 ) / (2.0 * self ._sigma_tensor ** 2 ))
36
48
self .kernel_1d = kernel_1d / ops .sum (kernel_1d )
37
49
self .built = True
38
50
39
51
def _separable_gaussian_filter_3d (self , tensor ):
40
52
depth_kernel = ops .reshape (self .kernel_1d , (- 1 , 1 , 1 , 1 , 1 ))
41
- tensor = ops .conv (tensor , ops .cast (depth_kernel , dtype = tensor .dtype ), padding = 'same' )
53
+ tensor = ops .conv (
54
+ tensor , ops .cast (depth_kernel , dtype = tensor .dtype ), padding = "same"
55
+ )
42
56
height_kernel = ops .reshape (self .kernel_1d , (1 , - 1 , 1 , 1 , 1 ))
43
- tensor = ops .conv (tensor , ops .cast (height_kernel , dtype = tensor .dtype ), padding = 'same' )
57
+ tensor = ops .conv (
58
+ tensor , ops .cast (height_kernel , dtype = tensor .dtype ), padding = "same"
59
+ )
44
60
width_kernel = ops .reshape (self .kernel_1d , (1 , 1 , - 1 , 1 , 1 ))
45
- tensor = ops .conv (tensor , ops .cast (width_kernel , dtype = tensor .dtype ), padding = 'same' )
61
+ tensor = ops .conv (
62
+ tensor , ops .cast (width_kernel , dtype = tensor .dtype ), padding = "same"
63
+ )
46
64
return tensor
47
65
48
66
def call (self , inputs ):
@@ -61,33 +79,75 @@ def call(self, inputs):
61
79
label_volume = ops .cast (label_volume , dtype = compute_dtype )
62
80
63
81
input_shape = ops .shape (image_volume )
64
- B , D , H , W , C = input_shape [0 ], input_shape [1 ], input_shape [2 ], input_shape [3 ], input_shape [4 ]
65
-
66
- coarse_flow = random .uniform (shape = (B , self .grid_size [0 ], self .grid_size [1 ], self .grid_size [2 ], 3 ), minval = - 1 , maxval = 1 , dtype = compute_dtype )
67
-
82
+ B , D , H , W , C = (
83
+ input_shape [0 ],
84
+ input_shape [1 ],
85
+ input_shape [2 ],
86
+ input_shape [3 ],
87
+ input_shape [4 ],
88
+ )
89
+
90
+ coarse_flow = random .uniform (
91
+ shape = (
92
+ B ,
93
+ self .grid_size [0 ],
94
+ self .grid_size [1 ],
95
+ self .grid_size [2 ],
96
+ 3 ,
97
+ ),
98
+ minval = - 1 ,
99
+ maxval = 1 ,
100
+ dtype = compute_dtype ,
101
+ )
102
+
68
103
flow = coarse_flow
69
104
flow_shape = ops .shape (flow )
70
- flow = ops .reshape (flow , (flow_shape [0 ] * flow_shape [1 ], flow_shape [2 ], flow_shape [3 ], 3 ))
105
+ flow = ops .reshape (
106
+ flow ,
107
+ (flow_shape [0 ] * flow_shape [1 ], flow_shape [2 ], flow_shape [3 ], 3 ),
108
+ )
71
109
flow = ops .image .resize (flow , (H , W ), interpolation = "bicubic" )
72
110
flow = ops .reshape (flow , (flow_shape [0 ], flow_shape [1 ], H , W , 3 ))
73
111
flow = ops .transpose (flow , (0 , 2 , 3 , 1 , 4 ))
74
112
flow_shape = ops .shape (flow )
75
- flow = ops .reshape (flow , (flow_shape [0 ] * flow_shape [1 ] * flow_shape [2 ], flow_shape [3 ], 1 , 3 ))
113
+ flow = ops .reshape (
114
+ flow ,
115
+ (
116
+ flow_shape [0 ] * flow_shape [1 ] * flow_shape [2 ],
117
+ flow_shape [3 ],
118
+ 1 ,
119
+ 3 ,
120
+ ),
121
+ )
76
122
flow = ops .image .resize (flow , (D , 1 ), interpolation = "bicubic" )
77
- flow = ops .reshape (flow , (flow_shape [0 ], flow_shape [1 ], flow_shape [2 ], D , 3 ))
123
+ flow = ops .reshape (
124
+ flow , (flow_shape [0 ], flow_shape [1 ], flow_shape [2 ], D , 3 )
125
+ )
78
126
flow = ops .transpose (flow , (0 , 3 , 1 , 2 , 4 ))
79
-
127
+
80
128
flow_components = ops .unstack (flow , axis = - 1 )
81
129
smoothed_components = []
82
130
for component in flow_components :
83
- smoothed_components .append (ops .squeeze (self ._separable_gaussian_filter_3d (ops .expand_dims (component , axis = - 1 )), axis = - 1 ))
131
+ smoothed_components .append (
132
+ ops .squeeze (
133
+ self ._separable_gaussian_filter_3d (
134
+ ops .expand_dims (component , axis = - 1 )
135
+ ),
136
+ axis = - 1 ,
137
+ )
138
+ )
84
139
smoothed_flow = ops .stack (smoothed_components , axis = - 1 )
85
-
140
+
86
141
flow = smoothed_flow * self ._alpha_tensor
87
- grid_d , grid_h , grid_w = ops .meshgrid (ops .arange (D , dtype = compute_dtype ), ops .arange (H , dtype = compute_dtype ), ops .arange (W , dtype = compute_dtype ), indexing = 'ij' )
142
+ grid_d , grid_h , grid_w = ops .meshgrid (
143
+ ops .arange (D , dtype = compute_dtype ),
144
+ ops .arange (H , dtype = compute_dtype ),
145
+ ops .arange (W , dtype = compute_dtype ),
146
+ indexing = "ij" ,
147
+ )
88
148
grid = ops .stack ([grid_d , grid_h , grid_w ], axis = - 1 )
89
149
warp_grid = ops .expand_dims (grid , 0 ) + flow
90
-
150
+
91
151
batched_coords = ops .transpose (warp_grid , (0 , 4 , 1 , 2 , 3 ))
92
152
93
153
def perform_map (elems ):
@@ -96,25 +156,45 @@ def perform_map(elems):
96
156
image_slice_transposed = ops .transpose (image_slice , (3 , 0 , 1 , 2 ))
97
157
# The channel dimension C is a static value when the graph is built
98
158
for c in range (C ):
99
- deformed_channels .append (ops .image .map_coordinates (image_slice_transposed [c ], coords , order = 1 ))
159
+ deformed_channels .append (
160
+ ops .image .map_coordinates (
161
+ image_slice_transposed [c ], coords , order = 1
162
+ )
163
+ )
100
164
deformed_image_slice = ops .stack (deformed_channels , axis = 0 )
101
- deformed_image_slice = ops .transpose (deformed_image_slice , (1 , 2 , 3 , 0 ))
165
+ deformed_image_slice = ops .transpose (
166
+ deformed_image_slice , (1 , 2 , 3 , 0 )
167
+ )
102
168
label_channel = ops .squeeze (label_slice , axis = - 1 )
103
- deformed_label_channel = ops .image .map_coordinates (label_channel , coords , order = 0 )
104
- deformed_label_slice = ops .expand_dims (deformed_label_channel , axis = - 1 )
169
+ deformed_label_channel = ops .image .map_coordinates (
170
+ label_channel , coords , order = 0
171
+ )
172
+ deformed_label_slice = ops .expand_dims (
173
+ deformed_label_channel , axis = - 1
174
+ )
105
175
return deformed_image_slice , deformed_label_slice
106
176
107
177
if backend .backend () == "tensorflow" :
108
178
import tensorflow as tf
109
- deformed_image , deformed_label = tf .map_fn (perform_map , elems = (image_volume , label_volume , batched_coords ), dtype = (compute_dtype , compute_dtype ))
179
+
180
+ deformed_image , deformed_label = tf .map_fn (
181
+ perform_map ,
182
+ elems = (image_volume , label_volume , batched_coords ),
183
+ dtype = (compute_dtype , compute_dtype ),
184
+ )
110
185
elif backend .backend () == "jax" :
111
186
import jax
112
- deformed_image , deformed_label = jax .lax .map (perform_map , xs = (image_volume , label_volume , batched_coords ))
187
+
188
+ deformed_image , deformed_label = jax .lax .map (
189
+ perform_map , xs = (image_volume , label_volume , batched_coords )
190
+ )
113
191
else :
114
192
deformed_images_list = []
115
193
deformed_labels_list = []
116
194
for i in range (B ):
117
- img_slice , lbl_slice = perform_map ((image_volume [i ], label_volume [i ], batched_coords [i ]))
195
+ img_slice , lbl_slice = perform_map (
196
+ (image_volume [i ], label_volume [i ], batched_coords [i ])
197
+ )
118
198
deformed_images_list .append (img_slice )
119
199
deformed_labels_list .append (lbl_slice )
120
200
deformed_image = ops .stack (deformed_images_list , axis = 0 )
@@ -135,5 +215,12 @@ def compute_output_shape(self, input_shape):
135
215
136
216
def get_config (self ):
137
217
config = super ().get_config ()
138
- config .update ({"grid_size" : self .grid_size , "alpha" : self .alpha , "sigma" : self .sigma , "data_format" : self .data_format })
139
- return config
218
+ config .update (
219
+ {
220
+ "grid_size" : self .grid_size ,
221
+ "alpha" : self .alpha ,
222
+ "sigma" : self .sigma ,
223
+ "data_format" : self .data_format ,
224
+ }
225
+ )
226
+ return config
0 commit comments