Skip to content

Commit e5a0be4

Browse files
committed
Update model.py
1 parent 1b12c4a commit e5a0be4

File tree

1 file changed

+19
-19
lines changed
  • napari_cellseg3d/code_models/models/wnet

1 file changed

+19
-19
lines changed

napari_cellseg3d/code_models/models/wnet/model.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -99,22 +99,22 @@ def __init__(
9999
self.max_pool = nn.MaxPool3d(2)
100100
self.in_b = InBlock(in_channels, self.channels[0], dropout=dropout)
101101
self.conv1 = Block(channels[0], self.channels[1], dropout=dropout)
102-
# self.conv2 = Block(channels[1], self.channels[2], dropout=dropout)
102+
self.conv2 = Block(channels[1], self.channels[2], dropout=dropout)
103103
# self.conv3 = Block(channels[2], self.channels[3], dropout=dropout)
104104
# self.bot = Block(channels[3], self.channels[4], dropout=dropout)
105-
# self.bot = Block(channels[2], self.channels[3], dropout=dropout)
106-
self.bot = Block(channels[1], self.channels[2], dropout=dropout)
105+
self.bot = Block(channels[2], self.channels[3], dropout=dropout)
106+
# self.bot = Block(channels[1], self.channels[2], dropout=dropout)
107107
# self.bot = Block(channels[0], self.channels[1], dropout=dropout)
108108
# self.deconv1 = Block(channels[4], self.channels[3], dropout=dropout)
109-
# self.deconv2 = Block(channels[3], self.channels[2], dropout=dropout)
109+
self.deconv2 = Block(channels[3], self.channels[2], dropout=dropout)
110110
self.deconv3 = Block(channels[2], self.channels[1], dropout=dropout)
111111
self.out_b = OutBlock(channels[1], out_channels, dropout=dropout)
112112
# self.conv_trans1 = nn.ConvTranspose3d(
113113
# self.channels[4], self.channels[3], 2, stride=2
114114
# )
115-
# self.conv_trans2 = nn.ConvTranspose3d(
116-
# self.channels[3], self.channels[2], 2, stride=2
117-
# )
115+
self.conv_trans2 = nn.ConvTranspose3d(
116+
self.channels[3], self.channels[2], 2, stride=2
117+
)
118118
self.conv_trans3 = nn.ConvTranspose3d(
119119
self.channels[2], self.channels[1], 2, stride=2
120120
)
@@ -129,11 +129,11 @@ def forward(self, x):
129129
"""Forward pass of the U-Net model."""
130130
in_b = self.in_b(x)
131131
c1 = self.conv1(self.max_pool(in_b))
132-
# c2 = self.conv2(self.max_pool(c1))
132+
c2 = self.conv2(self.max_pool(c1))
133133
# c3 = self.conv3(self.max_pool(c2))
134134
# x = self.bot(self.max_pool(c3))
135-
# x = self.bot(self.max_pool(c2))
136-
x = self.bot(self.max_pool(c1))
135+
x = self.bot(self.max_pool(c2))
136+
# x = self.bot(self.max_pool(c1))
137137
# x = self.bot(self.max_pool(in_b))
138138
# x = self.deconv1(
139139
# torch.cat(
@@ -144,15 +144,15 @@ def forward(self, x):
144144
# dim=1,
145145
# )
146146
# )
147-
# x = self.deconv2(
148-
# torch.cat(
149-
# [
150-
# c2,
151-
# self.conv_trans2(x),
152-
# ],
153-
# dim=1,
154-
# )
155-
# )
147+
x = self.deconv2(
148+
torch.cat(
149+
[
150+
c2,
151+
self.conv_trans2(x),
152+
],
153+
dim=1,
154+
)
155+
)
156156
x = self.deconv3(
157157
torch.cat(
158158
[

0 commit comments

Comments
 (0)