Skip to content

Commit eb384fb

Browse files
Fix model name
1 parent 612781d commit eb384fb

File tree

1 file changed

+65
-0
lines changed

1 file changed

+65
-0
lines changed

model/decoder.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
import time as tm
5+
6+
class DecoderSTCNN(nn.Module):
7+
8+
def __init__(self, layer_size, kernel_size, initial_filter_size, channels, dropout_rate, upsample=False):
9+
super(DecoderSTCNN, self).__init__()
10+
self.padding = kernel_size - 1
11+
self.upsample = upsample
12+
self.dropout_rate = dropout_rate
13+
self.conv_layers = nn.ModuleList()
14+
self.relu_layers = nn.ModuleList()
15+
self.batch_layers = nn.ModuleList()
16+
self.dropout_layers = nn.ModuleList()
17+
18+
temporal_kernel_size = [kernel_size, 1, 1]
19+
temporal_padding = [self.padding, 0, 0]
20+
out_channels = initial_filter_size
21+
in_channels = channels
22+
for i in range(layer_size):
23+
self.conv_layers.append(
24+
nn.Conv3d(in_channels=in_channels, out_channels=out_channels,
25+
kernel_size=temporal_kernel_size, padding=temporal_padding, bias=False)
26+
)
27+
self.relu_layers.append(nn.ReLU())
28+
self.batch_layers.append(nn.BatchNorm3d(out_channels))
29+
self.dropout_layers.append(nn.Dropout(dropout_rate))
30+
in_channels = out_channels
31+
32+
self.upsample_conv = nn.ConvTranspose3d(out_channels, out_channels, kernel_size=temporal_kernel_size,
33+
stride=[3,1,1], padding=[1,0,0])
34+
padding_final = [kernel_size // 2, 0, 0]
35+
self.conv_final = nn.Conv3d(in_channels=out_channels, out_channels=1, kernel_size=temporal_kernel_size,
36+
padding=padding_final, bias=True)
37+
38+
39+
def learning_with_dropout(self, x):
40+
for conv, relu, batch, drop in zip(self.conv_layers, self.relu_layers,
41+
self.batch_layers, self.dropout_layers):
42+
x = conv(x)[:,:,:-self.padding,:,:]
43+
x = drop(relu(batch(x)))
44+
45+
return x
46+
47+
def learning_without_dropout(self, x):
48+
for conv, relu, batch in zip(self.conv_layers, self.relu_layers, self.batch_layers):
49+
x = conv(x)[:,:,:-self.padding,:,:]
50+
x = relu(batch(x))
51+
52+
return x
53+
54+
def forward(self, input_):
55+
if self.dropout_rate > 0.:
56+
output = self.learning_with_dropout(input_)
57+
else:
58+
output = self.learning_without_dropout(input_)
59+
if (self.upsample):
60+
output_size = torch.randn(input_.shape[0],1, input_.shape[2] + 10,
61+
input_.shape[3], input_.shape[4]).size()
62+
output = self.upsample_conv(output, output_size=output_size)
63+
64+
output = self.conv_final(output)
65+
return output

0 commit comments

Comments
 (0)