@@ -99,22 +99,22 @@ def __init__(
9999 self .max_pool = nn .MaxPool3d (2 )
100100 self .in_b = InBlock (in_channels , self .channels [0 ], dropout = dropout )
101101 self .conv1 = Block (channels [0 ], self .channels [1 ], dropout = dropout )
102- # self.conv2 = Block(channels[1], self.channels[2], dropout=dropout)
102+ self .conv2 = Block (channels [1 ], self .channels [2 ], dropout = dropout )
103103 # self.conv3 = Block(channels[2], self.channels[3], dropout=dropout)
104104 # self.bot = Block(channels[3], self.channels[4], dropout=dropout)
105- # self.bot = Block(channels[2], self.channels[3], dropout=dropout)
106- self .bot = Block (channels [1 ], self .channels [2 ], dropout = dropout )
105+ self .bot = Block (channels [2 ], self .channels [3 ], dropout = dropout )
106+ # self.bot = Block(channels[1], self.channels[2], dropout=dropout)
107107 # self.bot = Block(channels[0], self.channels[1], dropout=dropout)
108108 # self.deconv1 = Block(channels[4], self.channels[3], dropout=dropout)
109- # self.deconv2 = Block(channels[3], self.channels[2], dropout=dropout)
109+ self .deconv2 = Block (channels [3 ], self .channels [2 ], dropout = dropout )
110110 self .deconv3 = Block (channels [2 ], self .channels [1 ], dropout = dropout )
111111 self .out_b = OutBlock (channels [1 ], out_channels , dropout = dropout )
112112 # self.conv_trans1 = nn.ConvTranspose3d(
113113 # self.channels[4], self.channels[3], 2, stride=2
114114 # )
115- # self.conv_trans2 = nn.ConvTranspose3d(
116- # self.channels[3], self.channels[2], 2, stride=2
117- # )
115+ self .conv_trans2 = nn .ConvTranspose3d (
116+ self .channels [3 ], self .channels [2 ], 2 , stride = 2
117+ )
118118 self .conv_trans3 = nn .ConvTranspose3d (
119119 self .channels [2 ], self .channels [1 ], 2 , stride = 2
120120 )
@@ -129,11 +129,11 @@ def forward(self, x):
129129 """Forward pass of the U-Net model."""
130130 in_b = self .in_b (x )
131131 c1 = self .conv1 (self .max_pool (in_b ))
132- # c2 = self.conv2(self.max_pool(c1))
132+ c2 = self .conv2 (self .max_pool (c1 ))
133133 # c3 = self.conv3(self.max_pool(c2))
134134 # x = self.bot(self.max_pool(c3))
135- # x = self.bot(self.max_pool(c2))
136- x = self .bot (self .max_pool (c1 ))
135+ x = self .bot (self .max_pool (c2 ))
136+ # x = self.bot(self.max_pool(c1))
137137 # x = self.bot(self.max_pool(in_b))
138138 # x = self.deconv1(
139139 # torch.cat(
@@ -144,15 +144,15 @@ def forward(self, x):
144144 # dim=1,
145145 # )
146146 # )
147- # x = self.deconv2(
148- # torch.cat(
149- # [
150- # c2,
151- # self.conv_trans2(x),
152- # ],
153- # dim=1,
154- # )
155- # )
147+ x = self .deconv2 (
148+ torch .cat (
149+ [
150+ c2 ,
151+ self .conv_trans2 (x ),
152+ ],
153+ dim = 1 ,
154+ )
155+ )
156156 x = self .deconv3 (
157157 torch .cat (
158158 [
0 commit comments