import torch
import torch.nn as nn
import torch.nn.functional as F
import time as tm

class DecoderSTCNN(nn.Module):
    """Causal spatio-temporal CNN decoder.

    Stacks temporal Conv3d blocks (BatchNorm3d + ReLU, optional Dropout),
    optionally upsamples along time with a transposed convolution, and
    projects down to a single output channel.
    """

    def __init__(self, layer_size, kernel_size, initial_filter_size, channels, dropout_rate, upsample=False):
        super().__init__()
        self.padding = kernel_size - 1
        self.upsample = upsample
        self.dropout_rate = dropout_rate
        self.conv_layers = nn.ModuleList()
        self.relu_layers = nn.ModuleList()
        self.batch_layers = nn.ModuleList()
        self.dropout_layers = nn.ModuleList()

        # Convolve along the temporal dimension only; pad by kernel_size - 1
        # so the trailing frames can be trimmed to keep the stack causal.
        temporal_kernel_size = [kernel_size, 1, 1]
        temporal_padding = [self.padding, 0, 0]
        out_channels = initial_filter_size
        in_channels = channels
        for _ in range(layer_size):
            self.conv_layers.append(
                nn.Conv3d(in_channels=in_channels, out_channels=out_channels,
                          kernel_size=temporal_kernel_size, padding=temporal_padding, bias=False)
            )
            self.relu_layers.append(nn.ReLU())
            self.batch_layers.append(nn.BatchNorm3d(out_channels))
            self.dropout_layers.append(nn.Dropout(dropout_rate))
            in_channels = out_channels

        # Optional temporal upsampling (stride 3 along time) plus a final
        # single-channel projection with "same" temporal padding.
        self.upsample_conv = nn.ConvTranspose3d(out_channels, out_channels, kernel_size=temporal_kernel_size,
                                                stride=[3, 1, 1], padding=[1, 0, 0])
        padding_final = [kernel_size // 2, 0, 0]
        self.conv_final = nn.Conv3d(in_channels=out_channels, out_channels=1, kernel_size=temporal_kernel_size,
                                    padding=padding_final, bias=True)

    def learning_with_dropout(self, x):
        for conv, relu, batch, drop in zip(self.conv_layers, self.relu_layers,
                                           self.batch_layers, self.dropout_layers):
            # Trim the trailing padded frames so the convolution stays causal
            # (assumes kernel_size > 1, i.e. self.padding > 0).
            x = conv(x)[:, :, :-self.padding, :, :]
            x = drop(relu(batch(x)))

        return x

    def learning_without_dropout(self, x):
        # Same causal convolution stack, but without dropout.
        for conv, relu, batch in zip(self.conv_layers, self.relu_layers, self.batch_layers):
            x = conv(x)[:, :, :-self.padding, :, :]
            x = relu(batch(x))

        return x

    def forward(self, input_):
        if self.dropout_rate > 0.:
            output = self.learning_with_dropout(input_)
        else:
            output = self.learning_without_dropout(input_)
        if self.upsample:
            # Ask the transposed convolution for a temporal length of
            # input length + 10; building the target size directly avoids
            # allocating a dummy tensor just to read its shape.
            output_size = [input_.shape[2] + 10, input_.shape[3], input_.shape[4]]
            output = self.upsample_conv(output, output_size=output_size)

        output = self.conv_final(output)
        return output
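

# A minimal usage sketch, not part of the original module: the layer counts,
# shapes, and hyperparameters below are illustrative assumptions. Input is
# expected as a 5D tensor (batch, channels, time, height, width).
if __name__ == "__main__":
    decoder = DecoderSTCNN(layer_size=2, kernel_size=3, initial_filter_size=16,
                           channels=8, dropout_rate=0.1, upsample=False)
    x = torch.randn(4, 8, 12, 32, 32)  # (N, C, T, H, W)
    y = decoder(x)
    print(y.shape)  # temporal length is preserved: torch.Size([4, 1, 12, 32, 32])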