1
+ import tensorflow as tf
2
+
3
+ class RandomElasticDeformation3D (tf .keras .layers .Layer ):
4
+ """
5
+ A high-performance 3D elastic deformation layer optimized for TPUs and GPUs.
6
+ ... (docstring is the same) ...
7
+ """
8
+ def __init__ (self ,
9
+ grid_size = (4 , 4 , 4 ),
10
+ alpha = 35.0 ,
11
+ sigma = 2.5 ,
12
+ data_format = "DHWC" ,
13
+ ** kwargs ):
14
+ super ().__init__ (** kwargs )
15
+ self .grid_size = grid_size
16
+ self .alpha = tf .constant (alpha , dtype = tf .bfloat16 )
17
+ self .sigma = tf .constant (sigma , dtype = tf .bfloat16 )
18
+ if data_format not in ["DHWC" , "HWDC" ]:
19
+ raise ValueError ("`data_format` must be one of 'DHWC' or 'HWDC'" )
20
+ self .data_format = data_format
21
+
22
+ def _separable_gaussian_filter_3d (self , tensor , sigma ):
23
+
24
+ kernel_size = tf .cast (2 * tf .round (3 * sigma ) + 1 , dtype = tf .int32 )
25
+ ax = tf .range (- tf .cast (kernel_size // 2 , tf .bfloat16 ) + 1.0 ,
26
+ tf .cast (kernel_size // 2 , tf .bfloat16 ) + 1.0 )
27
+ kernel_1d = tf .exp (- (ax ** 2 ) / (2.0 * self .sigma ** 2 ))
28
+ kernel_1d = kernel_1d / tf .reduce_sum (kernel_1d )
29
+ filter_d = tf .cast (tf .reshape (kernel_1d , [- 1 , 1 , 1 , 1 , 1 ]), dtype = tensor .dtype )
30
+ filter_h = tf .cast (tf .reshape (kernel_1d , [1 , - 1 , 1 , 1 , 1 ]), dtype = tensor .dtype )
31
+ filter_w = tf .cast (tf .reshape (kernel_1d , [1 , 1 , - 1 , 1 , 1 ]), dtype = tensor .dtype )
32
+ tensor = tf .nn .convolution (tensor , filter_d , strides = 1 , padding = 'SAME' )
33
+ tensor = tf .nn .convolution (tensor , filter_h , strides = 1 , padding = 'SAME' )
34
+ tensor = tf .nn .convolution (tensor , filter_w , strides = 1 , padding = 'SAME' )
35
+ return tensor
36
+
37
+ def call (self , inputs ):
38
+ image_volume , label_volume = inputs
39
+ original_image_dtype = image_volume .dtype
40
+
41
+ was_batched = True
42
+ if image_volume .shape .rank == 4 :
43
+ was_batched = False
44
+ image_volume = tf .expand_dims (image_volume , axis = 0 )
45
+ label_volume = tf .expand_dims (label_volume , axis = 0 )
46
+
47
+ if self .data_format == "HWDC" :
48
+ image_volume = tf .transpose (image_volume , perm = [0 , 3 , 1 , 2 , 4 ])
49
+ label_volume = tf .transpose (label_volume , perm = [0 , 3 , 1 , 2 , 4 ])
50
+
51
+ image_volume = tf .cast (image_volume , dtype = tf .bfloat16 )
52
+ input_shape = tf .shape (image_volume )
53
+ B , D , H , W = input_shape [0 ], input_shape [1 ], input_shape [2 ], input_shape [3 ]
54
+
55
+ coarse_flow = tf .random .uniform (
56
+ shape = (B , self .grid_size [0 ], self .grid_size [1 ], self .grid_size [2 ], 3 ),
57
+ minval = - 1 , maxval = 1 , dtype = tf .bfloat16 )
58
+
59
+ flow = tf .reshape (coarse_flow , [B * self .grid_size [0 ], self .grid_size [1 ], self .grid_size [2 ], 3 ])
60
+ flow = tf .image .resize (flow , size = [H , W ], method = 'bicubic' )
61
+ flow = tf .reshape (flow , [B , self .grid_size [0 ], H , W , 3 ])
62
+ flow = tf .transpose (flow , perm = [0 , 2 , 3 , 1 , 4 ])
63
+ flow = tf .reshape (flow , [B * H * W , self .grid_size [0 ], 3 ])
64
+ flow = tf .image .resize (tf .expand_dims (flow , axis = 1 ), size = [1 , D ], method = 'bicubic' )
65
+ flow = tf .squeeze (flow , axis = 1 )
66
+ flow = tf .reshape (flow , [B , H , W , D , 3 ])
67
+ flow = tf .transpose (flow , perm = [0 , 3 , 1 , 2 , 4 ])
68
+
69
+
70
+ flow = tf .cast (flow , dtype = tf .bfloat16 )
71
+
72
+ flow_components = tf .unstack (flow , axis = - 1 )
73
+ smoothed_components = []
74
+ for component in flow_components :
75
+ smoothed_component = self ._separable_gaussian_filter_3d (
76
+ component [..., tf .newaxis ], self .sigma
77
+ )
78
+ smoothed_components .append (smoothed_component [..., 0 ])
79
+ smoothed_flow = tf .stack (smoothed_components , axis = - 1 )
80
+
81
+
82
+ flow = smoothed_flow * self .alpha
83
+
84
+ grid_d , grid_h , grid_w = tf .meshgrid (
85
+ tf .range (D , dtype = tf .bfloat16 ),
86
+ tf .range (H , dtype = tf .bfloat16 ),
87
+ tf .range (W , dtype = tf .bfloat16 ),
88
+ indexing = 'ij'
89
+ )
90
+ grid = tf .stack ([grid_d , grid_h , grid_w ], axis = - 1 )
91
+
92
+
93
+ warp_grid = tf .expand_dims (grid , 0 ) + flow
94
+
95
+ warp_grid_floor = tf .floor (warp_grid )
96
+ t = warp_grid - warp_grid_floor
97
+
98
+ d0 = tf .cast (warp_grid_floor [..., 0 ], tf .int32 ); h0 = tf .cast (warp_grid_floor [..., 1 ], tf .int32 ); w0 = tf .cast (warp_grid_floor [..., 2 ], tf .int32 )
99
+ d1 = tf .clip_by_value (d0 + 1 , 0 , D - 1 ); h1 = tf .clip_by_value (h0 + 1 , 0 , H - 1 ); w1 = tf .clip_by_value (w0 + 1 , 0 , W - 1 )
100
+ d0 = tf .clip_by_value (d0 , 0 , D - 1 ); h0 = tf .clip_by_value (h0 , 0 , H - 1 ); w0 = tf .clip_by_value (w0 , 0 , W - 1 )
101
+
102
+ c000 = tf .gather_nd (image_volume , tf .stack ([d0 , h0 , w0 ], axis = - 1 ), batch_dims = 1 ); c001 = tf .gather_nd (image_volume , tf .stack ([d0 , h0 , w1 ], axis = - 1 ), batch_dims = 1 )
103
+ c010 = tf .gather_nd (image_volume , tf .stack ([d0 , h1 , w0 ], axis = - 1 ), batch_dims = 1 ); c011 = tf .gather_nd (image_volume , tf .stack ([d0 , h1 , w1 ], axis = - 1 ), batch_dims = 1 )
104
+ c100 = tf .gather_nd (image_volume , tf .stack ([d1 , h0 , w0 ], axis = - 1 ), batch_dims = 1 ); c101 = tf .gather_nd (image_volume , tf .stack ([d1 , h0 , w1 ], axis = - 1 ), batch_dims = 1 )
105
+ c110 = tf .gather_nd (image_volume , tf .stack ([d1 , h1 , w0 ], axis = - 1 ), batch_dims = 1 ); c111 = tf .gather_nd (image_volume , tf .stack ([d1 , h1 , w1 ], axis = - 1 ), batch_dims = 1 )
106
+
107
+ td , th , tw = t [..., 0 :1 ], t [..., 1 :2 ], t [..., 2 :3 ]
108
+ c00 = c000 * (1 - tw ) + c001 * tw ; c01 = c010 * (1 - tw ) + c011 * tw ; c10 = c100 * (1 - tw ) + c101 * tw ; c11 = c110 * (1 - tw ) + c111 * tw
109
+ c0 = c00 * (1 - th ) + c01 * th ; c1 = c10 * (1 - th ) + c11 * th
110
+ deformed_image = c0 * (1 - td ) + c1 * td
111
+ deformed_image = tf .cast (deformed_image , original_image_dtype )
112
+
113
+ nearest_indices_float = tf .round (warp_grid )
114
+ nearest_d = tf .clip_by_value (tf .cast (nearest_indices_float [..., 0 ], tf .int32 ), 0 , D - 1 )
115
+ nearest_h = tf .clip_by_value (tf .cast (nearest_indices_float [..., 1 ], tf .int32 ), 0 , H - 1 )
116
+ nearest_w = tf .clip_by_value (tf .cast (nearest_indices_float [..., 2 ], tf .int32 ), 0 , W - 1 )
117
+ deformed_label = tf .gather_nd (label_volume , tf .stack ([nearest_d , nearest_h , nearest_w ], axis = - 1 ), batch_dims = 1 )
118
+
119
+ if self .data_format == "HWDC" :
120
+ deformed_image = tf .transpose (deformed_image , perm = [0 , 2 , 3 , 1 , 4 ])
121
+ deformed_label = tf .transpose (deformed_label , perm = [0 , 2 , 3 , 1 , 4 ])
122
+
123
+ if not was_batched :
124
+ deformed_image = tf .squeeze (deformed_image , axis = 0 )
125
+ deformed_label = tf .squeeze (deformed_label , axis = 0 )
126
+
127
+ return deformed_image , deformed_label
0 commit comments