1
1
# Add this import
2
2
from keras import backend
3
- from keras import ops
4
3
from keras import layers
4
+ from keras import ops
5
5
from keras import random
6
6
7
+
7
8
class RandomElasticDeformation3D (layers .Layer ):
8
9
"""
9
10
A high-performance 3D elastic deformation layer optimized for TPUs.
10
11
"""
11
12
12
- def __init__ (self ,
13
- grid_size = (4 , 4 , 4 ),
14
- alpha = 35.0 ,
15
- sigma = 2.5 ,
16
- data_format = "channels_last" ,
17
- ** kwargs ):
13
+ def __init__ (
14
+ self ,
15
+ grid_size = (4 , 4 , 4 ),
16
+ alpha = 35.0 ,
17
+ sigma = 2.5 ,
18
+ data_format = "channels_last" ,
19
+ seed = None ,
20
+ ** kwargs ,
21
+ ):
18
22
super ().__init__ (** kwargs )
19
23
self .grid_size = grid_size
24
+ self .seed = seed
20
25
self .alpha = alpha
21
26
self .sigma = sigma
22
27
self .data_format = data_format
28
+ self ._rng = random .SeedGenerator (seed ) if seed is not None else None
23
29
if data_format not in ["channels_last" , "channels_first" ]:
24
30
message = (
25
31
"`data_format` must be one of 'channels_last' or "
@@ -28,21 +34,36 @@ def __init__(self,
28
34
raise ValueError (message )
29
35
30
36
def build (self , input_shape ):
31
- self ._alpha_tensor = ops .convert_to_tensor (self .alpha , dtype = self .compute_dtype )
32
- self ._sigma_tensor = ops .convert_to_tensor (self .sigma , dtype = self .compute_dtype )
33
- kernel_size = ops .cast (2 * ops .round (3 * self ._sigma_tensor ) + 1 , dtype = "int32" )
34
- ax = ops .arange (- ops .cast (kernel_size // 2 , self .compute_dtype ) + 1.0 , ops .cast (kernel_size // 2 , self .compute_dtype ) + 1.0 )
37
+ self ._alpha_tensor = ops .convert_to_tensor (
38
+ self .alpha , dtype = self .compute_dtype
39
+ )
40
+ self ._sigma_tensor = ops .convert_to_tensor (
41
+ self .sigma , dtype = self .compute_dtype
42
+ )
43
+ kernel_size = ops .cast (
44
+ 2 * ops .round (3 * self ._sigma_tensor ) + 1 , dtype = "int32"
45
+ )
46
+ ax = ops .arange (
47
+ - ops .cast (kernel_size // 2 , self .compute_dtype ) + 1.0 ,
48
+ ops .cast (kernel_size // 2 , self .compute_dtype ) + 1.0 ,
49
+ )
35
50
kernel_1d = ops .exp (- (ax ** 2 ) / (2.0 * self ._sigma_tensor ** 2 ))
36
51
self .kernel_1d = kernel_1d / ops .sum (kernel_1d )
37
52
self .built = True
38
53
39
54
def _separable_gaussian_filter_3d (self , tensor ):
40
55
depth_kernel = ops .reshape (self .kernel_1d , (- 1 , 1 , 1 , 1 , 1 ))
41
- tensor = ops .conv (tensor , ops .cast (depth_kernel , dtype = tensor .dtype ), padding = 'same' )
56
+ tensor = ops .conv (
57
+ tensor , ops .cast (depth_kernel , dtype = tensor .dtype ), padding = "same"
58
+ )
42
59
height_kernel = ops .reshape (self .kernel_1d , (1 , - 1 , 1 , 1 , 1 ))
43
- tensor = ops .conv (tensor , ops .cast (height_kernel , dtype = tensor .dtype ), padding = 'same' )
60
+ tensor = ops .conv (
61
+ tensor , ops .cast (height_kernel , dtype = tensor .dtype ), padding = "same"
62
+ )
44
63
width_kernel = ops .reshape (self .kernel_1d , (1 , 1 , - 1 , 1 , 1 ))
45
- tensor = ops .conv (tensor , ops .cast (width_kernel , dtype = tensor .dtype ), padding = 'same' )
64
+ tensor = ops .conv (
65
+ tensor , ops .cast (width_kernel , dtype = tensor .dtype ), padding = "same"
66
+ )
46
67
return tensor
47
68
48
69
def call (self , inputs ):
@@ -61,33 +82,90 @@ def call(self, inputs):
61
82
label_volume = ops .cast (label_volume , dtype = compute_dtype )
62
83
63
84
input_shape = ops .shape (image_volume )
64
- B , D , H , W , C = input_shape [0 ], input_shape [1 ], input_shape [2 ], input_shape [3 ], input_shape [4 ]
65
-
66
- coarse_flow = random .uniform (shape = (B , self .grid_size [0 ], self .grid_size [1 ], self .grid_size [2 ], 3 ), minval = - 1 , maxval = 1 , dtype = compute_dtype )
67
-
85
+ B , D , H , W , C = (
86
+ input_shape [0 ],
87
+ input_shape [1 ],
88
+ input_shape [2 ],
89
+ input_shape [3 ],
90
+ input_shape [4 ],
91
+ )
92
+
93
+ if self ._rng is not None :
94
+ coarse_flow = random .uniform (
95
+ shape = (
96
+ B ,
97
+ self .grid_size [0 ],
98
+ self .grid_size [1 ],
99
+ self .grid_size [2 ],
100
+ 3 ,
101
+ ),
102
+ minval = - 1 ,
103
+ maxval = 1 ,
104
+ dtype = compute_dtype ,
105
+ seed = self ._rng ,
106
+ )
107
+ else :
108
+ coarse_flow = random .uniform (
109
+ shape = (
110
+ B ,
111
+ self .grid_size [0 ],
112
+ self .grid_size [1 ],
113
+ self .grid_size [2 ],
114
+ 3 ,
115
+ ),
116
+ minval = - 1 ,
117
+ maxval = 1 ,
118
+ dtype = compute_dtype ,
119
+ )
120
+
68
121
flow = coarse_flow
69
122
flow_shape = ops .shape (flow )
70
- flow = ops .reshape (flow , (flow_shape [0 ] * flow_shape [1 ], flow_shape [2 ], flow_shape [3 ], 3 ))
123
+ flow = ops .reshape (
124
+ flow ,
125
+ (flow_shape [0 ] * flow_shape [1 ], flow_shape [2 ], flow_shape [3 ], 3 ),
126
+ )
71
127
flow = ops .image .resize (flow , (H , W ), interpolation = "bicubic" )
72
128
flow = ops .reshape (flow , (flow_shape [0 ], flow_shape [1 ], H , W , 3 ))
73
129
flow = ops .transpose (flow , (0 , 2 , 3 , 1 , 4 ))
74
130
flow_shape = ops .shape (flow )
75
- flow = ops .reshape (flow , (flow_shape [0 ] * flow_shape [1 ] * flow_shape [2 ], flow_shape [3 ], 1 , 3 ))
131
+ flow = ops .reshape (
132
+ flow ,
133
+ (
134
+ flow_shape [0 ] * flow_shape [1 ] * flow_shape [2 ],
135
+ flow_shape [3 ],
136
+ 1 ,
137
+ 3 ,
138
+ ),
139
+ )
76
140
flow = ops .image .resize (flow , (D , 1 ), interpolation = "bicubic" )
77
- flow = ops .reshape (flow , (flow_shape [0 ], flow_shape [1 ], flow_shape [2 ], D , 3 ))
141
+ flow = ops .reshape (
142
+ flow , (flow_shape [0 ], flow_shape [1 ], flow_shape [2 ], D , 3 )
143
+ )
78
144
flow = ops .transpose (flow , (0 , 3 , 1 , 2 , 4 ))
79
-
145
+
80
146
flow_components = ops .unstack (flow , axis = - 1 )
81
147
smoothed_components = []
82
148
for component in flow_components :
83
- smoothed_components .append (ops .squeeze (self ._separable_gaussian_filter_3d (ops .expand_dims (component , axis = - 1 )), axis = - 1 ))
149
+ smoothed_components .append (
150
+ ops .squeeze (
151
+ self ._separable_gaussian_filter_3d (
152
+ ops .expand_dims (component , axis = - 1 )
153
+ ),
154
+ axis = - 1 ,
155
+ )
156
+ )
84
157
smoothed_flow = ops .stack (smoothed_components , axis = - 1 )
85
-
158
+
86
159
flow = smoothed_flow * self ._alpha_tensor
87
- grid_d , grid_h , grid_w = ops .meshgrid (ops .arange (D , dtype = compute_dtype ), ops .arange (H , dtype = compute_dtype ), ops .arange (W , dtype = compute_dtype ), indexing = 'ij' )
160
+ grid_d , grid_h , grid_w = ops .meshgrid (
161
+ ops .arange (D , dtype = compute_dtype ),
162
+ ops .arange (H , dtype = compute_dtype ),
163
+ ops .arange (W , dtype = compute_dtype ),
164
+ indexing = "ij" ,
165
+ )
88
166
grid = ops .stack ([grid_d , grid_h , grid_w ], axis = - 1 )
89
167
warp_grid = ops .expand_dims (grid , 0 ) + flow
90
-
168
+
91
169
batched_coords = ops .transpose (warp_grid , (0 , 4 , 1 , 2 , 3 ))
92
170
93
171
def perform_map (elems ):
@@ -96,25 +174,45 @@ def perform_map(elems):
96
174
image_slice_transposed = ops .transpose (image_slice , (3 , 0 , 1 , 2 ))
97
175
# The channel dimension C is a static value when the graph is built
98
176
for c in range (C ):
99
- deformed_channels .append (ops .image .map_coordinates (image_slice_transposed [c ], coords , order = 1 ))
177
+ deformed_channels .append (
178
+ ops .image .map_coordinates (
179
+ image_slice_transposed [c ], coords , order = 1
180
+ )
181
+ )
100
182
deformed_image_slice = ops .stack (deformed_channels , axis = 0 )
101
- deformed_image_slice = ops .transpose (deformed_image_slice , (1 , 2 , 3 , 0 ))
183
+ deformed_image_slice = ops .transpose (
184
+ deformed_image_slice , (1 , 2 , 3 , 0 )
185
+ )
102
186
label_channel = ops .squeeze (label_slice , axis = - 1 )
103
- deformed_label_channel = ops .image .map_coordinates (label_channel , coords , order = 0 )
104
- deformed_label_slice = ops .expand_dims (deformed_label_channel , axis = - 1 )
187
+ deformed_label_channel = ops .image .map_coordinates (
188
+ label_channel , coords , order = 0
189
+ )
190
+ deformed_label_slice = ops .expand_dims (
191
+ deformed_label_channel , axis = - 1
192
+ )
105
193
return deformed_image_slice , deformed_label_slice
106
194
107
195
if backend .backend () == "tensorflow" :
108
196
import tensorflow as tf
109
- deformed_image , deformed_label = tf .map_fn (perform_map , elems = (image_volume , label_volume , batched_coords ), dtype = (compute_dtype , compute_dtype ))
197
+
198
+ deformed_image , deformed_label = tf .map_fn (
199
+ perform_map ,
200
+ elems = (image_volume , label_volume , batched_coords ),
201
+ dtype = (compute_dtype , compute_dtype ),
202
+ )
110
203
elif backend .backend () == "jax" :
111
204
import jax
112
- deformed_image , deformed_label = jax .lax .map (perform_map , xs = (image_volume , label_volume , batched_coords ))
205
+
206
+ deformed_image , deformed_label = jax .lax .map (
207
+ perform_map , xs = (image_volume , label_volume , batched_coords )
208
+ )
113
209
else :
114
210
deformed_images_list = []
115
211
deformed_labels_list = []
116
212
for i in range (B ):
117
- img_slice , lbl_slice = perform_map ((image_volume [i ], label_volume [i ], batched_coords [i ]))
213
+ img_slice , lbl_slice = perform_map (
214
+ (image_volume [i ], label_volume [i ], batched_coords [i ])
215
+ )
118
216
deformed_images_list .append (img_slice )
119
217
deformed_labels_list .append (lbl_slice )
120
218
deformed_image = ops .stack (deformed_images_list , axis = 0 )
@@ -135,5 +233,13 @@ def compute_output_shape(self, input_shape):
135
233
136
234
def get_config (self ):
137
235
config = super ().get_config ()
138
- config .update ({"grid_size" : self .grid_size , "alpha" : self .alpha , "sigma" : self .sigma , "data_format" : self .data_format })
139
- return config
236
+ config .update (
237
+ {
238
+ "grid_size" : self .grid_size ,
239
+ "alpha" : self .alpha ,
240
+ "sigma" : self .sigma ,
241
+ "data_format" : self .data_format ,
242
+ "seed" : self .seed ,
243
+ }
244
+ )
245
+ return config
0 commit comments