 """A space-time Transformer with Cuboid Attention"""
 
 
-class InitialEncoder(paddle.nn.Layer):
+class InitialEncoder(nn.Layer):
     def __init__(
         self,
         dim,
@@ -38,39 +38,35 @@ def __init__(
         for i in range(num_conv_layers):
             if i == 0:
                 conv_block.append(
-                    paddle.nn.Conv2D(
+                    nn.Conv2D(
                         kernel_size=(3, 3),
                         padding=(1, 1),
                         in_channels=dim,
                         out_channels=out_dim,
                     )
                 )
-                conv_block.append(
-                    paddle.nn.GroupNorm(num_groups=16, num_channels=out_dim)
-                )
+                conv_block.append(nn.GroupNorm(num_groups=16, num_channels=out_dim))
                 conv_block.append(
                     act_mod.get_activation(activation)
                     if activation != "leaky_relu"
                     else nn.LeakyReLU(NEGATIVE_SLOPE)
                 )
             else:
                 conv_block.append(
-                    paddle.nn.Conv2D(
+                    nn.Conv2D(
                         kernel_size=(3, 3),
                         padding=(1, 1),
                         in_channels=out_dim,
                         out_channels=out_dim,
                     )
                 )
-                conv_block.append(
-                    paddle.nn.GroupNorm(num_groups=16, num_channels=out_dim)
-                )
+                conv_block.append(nn.GroupNorm(num_groups=16, num_channels=out_dim))
                 conv_block.append(
                     act_mod.get_activation(activation)
                     if activation != "leaky_relu"
                     else nn.LeakyReLU(NEGATIVE_SLOPE)
                 )
-        self.conv_block = paddle.nn.Sequential(*conv_block)
+        self.conv_block = nn.Sequential(*conv_block)
         if isinstance(downsample_scale, int):
             patch_merge_downsample = (1, downsample_scale, downsample_scale)
         elif len(downsample_scale) == 2:
@@ -121,7 +117,7 @@ def forward(self, x):
         return x
 
 
-class FinalDecoder(paddle.nn.Layer):
+class FinalDecoder(nn.Layer):
     def __init__(
         self,
         target_thw: Tuple[int, ...],
@@ -142,20 +138,20 @@ def __init__(
         conv_block = []
         for i in range(num_conv_layers):
             conv_block.append(
-                paddle.nn.Conv2D(
+                nn.Conv2D(
                     kernel_size=(3, 3),
                     padding=(1, 1),
                     in_channels=dim,
                     out_channels=dim,
                 )
             )
-            conv_block.append(paddle.nn.GroupNorm(num_groups=16, num_channels=dim))
+            conv_block.append(nn.GroupNorm(num_groups=16, num_channels=dim))
             conv_block.append(
                 act_mod.get_activation(activation)
                 if activation != "leaky_relu"
                 else nn.LeakyReLU(NEGATIVE_SLOPE)
             )
-        self.conv_block = paddle.nn.Sequential(*conv_block)
+        self.conv_block = nn.Sequential(*conv_block)
         self.upsample = cuboid_decoder.Upsample3DLayer(
             dim=dim,
             out_dim=dim,
@@ -196,7 +192,7 @@ def forward(self, x):
         return x
 
 
-class InitialStackPatchMergingEncoder(paddle.nn.Layer):
+class InitialStackPatchMergingEncoder(nn.Layer):
     def __init__(
         self,
         num_merge: int,
@@ -220,8 +216,8 @@ def __init__(
         self.downsample_scale_list = downsample_scale_list[:num_merge]
         self.num_conv_per_merge_list = num_conv_per_merge_list
         self.num_group_list = [max(1, out_dim // 4) for out_dim in self.out_dim_list]
-        self.conv_block_list = paddle.nn.LayerList()
-        self.patch_merge_list = paddle.nn.LayerList()
+        self.conv_block_list = nn.LayerList()
+        self.patch_merge_list = nn.LayerList()
         for i in range(num_merge):
             if i == 0:
                 in_dim = in_dim
@@ -236,15 +232,15 @@ def __init__(
                 else:
                     conv_in_dim = out_dim
                 conv_block.append(
-                    paddle.nn.Conv2D(
+                    nn.Conv2D(
                         kernel_size=(3, 3),
                         padding=(1, 1),
                         in_channels=conv_in_dim,
                         out_channels=out_dim,
                     )
                 )
                 conv_block.append(
-                    paddle.nn.GroupNorm(
+                    nn.GroupNorm(
                         num_groups=self.num_group_list[i], num_channels=out_dim
                     )
                 )
@@ -253,7 +249,7 @@ def __init__(
                     if activation != "leaky_relu"
                     else nn.LeakyReLU(NEGATIVE_SLOPE)
                 )
-            conv_block = paddle.nn.Sequential(*conv_block)
+            conv_block = nn.Sequential(*conv_block)
             self.conv_block_list.append(conv_block)
             patch_merge = cuboid_encoder.PatchMerging3D(
                 dim=out_dim,
@@ -303,7 +299,7 @@ def forward(self, x):
         return x
 
 
-class FinalStackUpsamplingDecoder(paddle.nn.Layer):
+class FinalStackUpsamplingDecoder(nn.Layer):
     def __init__(
         self,
         target_shape_list: Tuple[Tuple[int, ...]],
@@ -326,8 +322,8 @@ def __init__(
         self.in_dim = in_dim
         self.num_conv_per_up_list = num_conv_per_up_list
         self.num_group_list = [max(1, out_dim // 4) for out_dim in self.out_dim_list]
-        self.conv_block_list = paddle.nn.LayerList()
-        self.upsample_list = paddle.nn.LayerList()
+        self.conv_block_list = nn.LayerList()
+        self.upsample_list = nn.LayerList()
         for i in range(self.num_upsample):
             if i == 0:
                 in_dim = in_dim
@@ -349,15 +345,15 @@ def __init__(
                 else:
                     conv_in_dim = out_dim
                 conv_block.append(
-                    paddle.nn.Conv2D(
+                    nn.Conv2D(
                         kernel_size=(3, 3),
                         padding=(1, 1),
                         in_channels=conv_in_dim,
                         out_channels=out_dim,
                     )
                 )
                 conv_block.append(
-                    paddle.nn.GroupNorm(
+                    nn.GroupNorm(
                         num_groups=self.num_group_list[i], num_channels=out_dim
                     )
                 )
@@ -366,7 +362,7 @@ def __init__(
                     if activation != "leaky_relu"
                     else nn.LeakyReLU(NEGATIVE_SLOPE)
                 )
-            conv_block = paddle.nn.Sequential(*conv_block)
+            conv_block = nn.Sequential(*conv_block)
             self.conv_block_list.append(conv_block)
         self.reset_parameters()
 
@@ -686,7 +682,7 @@ def __init__(
             embed_dim=base_units, typ=pos_embed_type, maxH=H_in, maxW=W_in, maxT=T_in
         )
         mem_shapes = self.encoder.get_mem_shapes()
-        self.z_proj = paddle.nn.Linear(
+        self.z_proj = nn.Linear(
             in_features=mem_shapes[-1][-1], out_features=mem_shapes[-1][-1]
         )
         self.dec_pos_embed = cuboid_decoder.PosEmbed(
@@ -799,7 +795,7 @@ def get_initial_encoder_final_decoder(
             new_input_shape = self.initial_encoder.patch_merge.get_out_shape(
                 self.input_shape
             )
-            self.dec_final_proj = paddle.nn.Linear(
+            self.dec_final_proj = nn.Linear(
                 in_features=self.base_units, out_features=C_out
             )
         elif self.initial_downsample_type == "stack_conv":
@@ -839,7 +835,7 @@ def get_initial_encoder_final_decoder(
                 linear_init_mode=self.down_up_linear_init_mode,
                 norm_init_mode=self.norm_init_mode,
             )
-            self.dec_final_proj = paddle.nn.Linear(
+            self.dec_final_proj = nn.Linear(
                 in_features=dec_target_shape_list[-1][-1], out_features=C_out
             )
             new_input_shape = self.initial_encoder.get_out_shape_list(self.input_shape)[
@@ -892,7 +888,7 @@ def get_initial_z(self, final_mem, T_out):
                 shape=[B, -1, -1, -1, -1]
             )
         elif self.z_init_method == "nearest_interp":
-            initial_z = paddle.nn.functional.interpolate(
+            initial_z = nn.functional.interpolate(
                 x=final_mem.transpose(perm=[0, 4, 1, 2, 3]),
                 size=(T_out, final_mem.shape[2], final_mem.shape[3]),
             ).transpose(perm=[0, 2, 3, 4, 1])
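Note on the change pattern: every hunk above applies the same mechanical edit, replacing fully qualified `paddle.nn.*` references with the shorter `nn.*` alias (and, where a call then fits on one line, folding the multi-line `GroupNorm` appends). This presumes a module-level alias import such as `from paddle import nn`; that import block is outside the hunks shown here, so treat it as an assumption. A minimal sketch of the pattern, not code from this PR:

```python
# Minimal sketch of the aliasing this diff relies on. It assumes the module
# already has a top-level `from paddle import nn` import, so `nn.Conv2D` and
# `paddle.nn.Conv2D` name the same class.
import paddle
from paddle import nn

# The same conv + norm + activation stack the encoder builds, written with
# the alias; behavior is identical to the paddle.nn.* spelling.
block = nn.Sequential(
    nn.Conv2D(in_channels=8, out_channels=16, kernel_size=(3, 3), padding=(1, 1)),
    nn.GroupNorm(num_groups=16, num_channels=16),
    nn.LeakyReLU(0.1),  # NEGATIVE_SLOPE is assumed here to be a small constant like 0.1
)

assert nn.Conv2D is paddle.nn.Conv2D  # the alias binds the identical class
x = paddle.randn([2, 8, 32, 32])      # (batch, channels, height, width)
y = block(x)
print(y.shape)                        # [2, 16, 32, 32]
```

Since `nn` is only a shorter binding to the same `paddle.nn` module, the refactor changes spelling, not behavior.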