@@ -39,9 +39,9 @@ function UNetDynamic(backbone,
39
39
inputsize;
40
40
m_middle = UNetMiddleBlock,
41
41
skip_upscale = fdownscale,
42
- kwargs... )
42
+ kwargs... )
43
43
outsz = Flux. outputsize (unet, inputsize)
44
- return Chain (unet, final (outsz[end - 1 ], k_out))
44
+ return Chain (unet, final (outsz[end - 1 ], k_out, length (outsz) - 2 ))
45
45
end
46
46
47
47
function catchannels (x1, x2)
@@ -50,7 +50,7 @@ function catchannels(x1, x2)
50
50
end
51
51
52
52
function unetlayers (layers,
53
- sz;
53
+ sz;
54
54
k_out = nothing ,
55
55
skip_upscale = 0 ,
56
56
m_middle = _ -> (identity,))
@@ -81,7 +81,8 @@ function unetlayers(layers,
81
81
return UNetBlock (Chain (layer, childunet),
82
82
k_in, # Input channels to upsampling layer
83
83
k_mid,
84
- k_out)
84
+ k_out,
85
+ length (outsz) - 2 )
85
86
end
86
87
end
87
88
@@ -95,28 +96,28 @@ Given convolutional module `m` that halves the spatial dimensions
95
96
and outputs `k_in` filters, create a module that upsamples the
96
97
spatial dimensions and then aggregates features via a skip connection.
97
98
"""
98
- function UNetBlock (m_child, k_in, k_mid, k_out = 2 k_in)
99
+ function UNetBlock (m_child, k_in, k_mid, k_out = 2 k_in, ndim = 2 )
99
100
return Chain (upsample = SkipConnection (Chain (child = m_child, # Downsampling and processing
100
- upsample = PixelShuffleICNR (k_mid, k_mid)),
101
+ upsample = PixelShuffleICNR (k_mid, k_mid, ndim )),
101
102
Parallel (catchannels, identity, BatchNorm (k_in))),
102
103
act = xs -> relu .(xs),
103
- combine = UNetCombineLayer (k_in + k_mid, k_out))
104
+ combine = UNetCombineLayer (k_in + k_mid, k_out, ndim ))
104
105
end
105
106
106
- function PixelShuffleICNR (k_in, k_out; r = 2 )
107
- return Chain (convxlayer (k_in, k_out * (r^ 2 ), ks = 1 ), Flux. PixelShuffle (r))
107
+ function PixelShuffleICNR (k_in, k_out, ndim ; r = 2 )
108
+ return Chain (convxlayer (k_in, k_out * (r^ ndim ), ks = 1 , ndim = ndim ), Flux. PixelShuffle (r))
108
109
end
109
110
110
- function UNetCombineLayer (k_in, k_out)
111
- return Chain (convxlayer (k_in, k_out), convxlayer (k_out, k_out))
111
+ function UNetCombineLayer (k_in, k_out, ndim )
112
+ return Chain (convxlayer (k_in, k_out, ndim = ndim ), convxlayer (k_out, k_out, ndim = ndim ))
112
113
end
113
114
114
- function UNetMiddleBlock (k)
115
- return Chain (convxlayer (k, 2 k), convxlayer (2 k, k))
115
+ function UNetMiddleBlock (k, ndim )
116
+ return Chain (convxlayer (k, 2 k, ndim = ndim ), convxlayer (2 k, k, ndim = ndim ))
116
117
end
117
118
118
- function UNetFinalBlock (k_in, k_out)
119
- return Chain (ResBlock (1 , k_in, k_in), convxlayer (k_in, k_out, ks = 1 ))
119
+ function UNetFinalBlock (k_in, k_out, ndim )
120
+ return Chain (ResBlock (1 , k_in, k_in, ndim = ndim ), convxlayer (k_in, k_out, ks = 1 , ndim = ndim ))
120
121
end
121
122
122
123
"""
139
140
140
141
model = UNetDynamic (Models. xresnet18 (), (128 , 128 , 3 , 1 ), 4 , fdownscale = 1 )
141
142
@test Flux. outputsize (model, (128 , 128 , 3 , 1 )) == (64 , 64 , 4 , 1 )
143
+
144
+ model = UNetDynamic (Models. xresnet18 (ndim = 3 ), (128 , 128 , 128 , 3 , 1 ), 4 )
145
+ @test Flux. outputsize (model, (128 , 128 , 128 , 3 , 1 )) == (128 , 128 , 128 , 4 , 1 )
142
146
end end
0 commit comments