Skip to content

Commit dee6399

Browse files
authored
Fix UNet for 3D convolutions (specify ndim to convxlayer and ResBlock) (#263)
* Fix UNet for 3D convolutions (specify ndim to convxlayer and ResBlock)
1 parent 767aa2b commit dee6399

File tree

1 file changed

+19
-15
lines changed

1 file changed

+19
-15
lines changed

FastVision/src/models/unet.jl

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,9 @@ function UNetDynamic(backbone,
3939
inputsize;
4040
m_middle = UNetMiddleBlock,
4141
skip_upscale = fdownscale,
42-
kwargs...)
42+
kwargs...)
4343
outsz = Flux.outputsize(unet, inputsize)
44-
return Chain(unet, final(outsz[end - 1], k_out))
44+
return Chain(unet, final(outsz[end - 1], k_out, length(outsz) - 2))
4545
end
4646

4747
function catchannels(x1, x2)
@@ -50,7 +50,7 @@ function catchannels(x1, x2)
5050
end
5151

5252
function unetlayers(layers,
53-
sz;
53+
sz;
5454
k_out = nothing,
5555
skip_upscale = 0,
5656
m_middle = _ -> (identity,))
@@ -81,7 +81,8 @@ function unetlayers(layers,
8181
return UNetBlock(Chain(layer, childunet),
8282
k_in, # Input channels to upsampling layer
8383
k_mid,
84-
k_out)
84+
k_out,
85+
length(outsz) - 2)
8586
end
8687
end
8788

@@ -95,28 +96,28 @@ Given convolutional module `m` that halves the spatial dimensions
9596
and outputs `k_in` filters, create a module that upsamples the
9697
spatial dimensions and then aggregates features via a skip connection.
9798
"""
98-
function UNetBlock(m_child, k_in, k_mid, k_out = 2k_in)
99+
function UNetBlock(m_child, k_in, k_mid, k_out = 2k_in, ndim = 2)
99100
return Chain(upsample = SkipConnection(Chain(child = m_child, # Downsampling and processing
100-
upsample = PixelShuffleICNR(k_mid, k_mid)),
101+
upsample = PixelShuffleICNR(k_mid, k_mid, ndim)),
101102
Parallel(catchannels, identity, BatchNorm(k_in))),
102103
act = xs -> relu.(xs),
103-
combine = UNetCombineLayer(k_in + k_mid, k_out))
104+
combine = UNetCombineLayer(k_in + k_mid, k_out, ndim))
104105
end
105106

106-
function PixelShuffleICNR(k_in, k_out; r = 2)
107-
return Chain(convxlayer(k_in, k_out * (r^2), ks = 1), Flux.PixelShuffle(r))
107+
function PixelShuffleICNR(k_in, k_out, ndim; r = 2)
108+
return Chain(convxlayer(k_in, k_out * (r^ndim), ks = 1, ndim = ndim), Flux.PixelShuffle(r))
108109
end
109110

110-
function UNetCombineLayer(k_in, k_out)
111-
return Chain(convxlayer(k_in, k_out), convxlayer(k_out, k_out))
111+
function UNetCombineLayer(k_in, k_out, ndim)
112+
return Chain(convxlayer(k_in, k_out, ndim = ndim), convxlayer(k_out, k_out, ndim = ndim))
112113
end
113114

114-
function UNetMiddleBlock(k)
115-
return Chain(convxlayer(k, 2k), convxlayer(2k, k))
115+
function UNetMiddleBlock(k, ndim)
116+
return Chain(convxlayer(k, 2k, ndim = ndim), convxlayer(2k, k, ndim = ndim))
116117
end
117118

118-
function UNetFinalBlock(k_in, k_out)
119-
return Chain(ResBlock(1, k_in, k_in), convxlayer(k_in, k_out, ks = 1))
119+
function UNetFinalBlock(k_in, k_out, ndim)
120+
return Chain(ResBlock(1, k_in, k_in, ndim = ndim), convxlayer(k_in, k_out, ks = 1, ndim = ndim))
120121
end
121122

122123
"""
@@ -139,4 +140,7 @@ end
139140

140141
model = UNetDynamic(Models.xresnet18(), (128, 128, 3, 1), 4, fdownscale = 1)
141142
@test Flux.outputsize(model, (128, 128, 3, 1)) == (64, 64, 4, 1)
143+
144+
model = UNetDynamic(Models.xresnet18(ndim = 3), (128, 128, 128, 3, 1), 4)
145+
@test Flux.outputsize(model, (128, 128, 128, 3, 1)) == (128, 128, 128, 4, 1)
142146
end end

0 commit comments

Comments
 (0)