1
+ """
2
+ InceptionModule(ni::Int, nf::Int, ks::Int = 40, bottleneck::Bool = true)
3
+
4
+ TBW
5
+ """
1
6
function InceptionModule (ni:: Int , nf:: Int , ks:: Int = 40 , bottleneck:: Bool = true )
2
7
ks = [ks ÷ (2 ^ i) for i in range (0 , stop = 2 )]
3
8
ks = [ks[i] % 2 == 0 ? ks[i] - 1 : ks[i] for i in range (1 , stop = 3 )] # ensure odd ks
@@ -15,13 +20,6 @@ function InceptionModule(ni::Int, nf::Int, ks::Int = 40, bottleneck::Bool = true
15
20
return Chain (Parallel (hcat, convs, maxconvpool), BatchNorm (nf * 4 , relu))
16
21
end
17
22
18
- struct InceptionBlock
19
- residual:: Bool
20
- depth:: Int
21
- inception:: Any
22
- shortcut:: Any
23
- end
24
-
25
23
"""
26
24
InceptionBlock(ni::Int, nf::Int = 32, residual::Bool = true, depth::Int = 6)
27
25
30
28
function InceptionBlock (ni:: Int , nf:: Int = 32 , residual:: Bool = true , depth:: Int = 6 )
31
29
inception = []
32
30
shortcut = []
33
- for d in range (0 , stop = depth - 1 )
34
- push! (inception, InceptionModule (d == 0 ? ni : nf * 4 , nf))
35
- if (residual && d % 3 == 2 )
36
- n_in = d == 2 ? ni : nf * 4
31
+
32
+ for d in range (1 , stop = depth)
33
+ push! (inception, InceptionModule (d == 1 ? ni : nf * 4 , nf))
34
+ if residual && d % 3 == 0
35
+ n_in = d == 3 ? ni : nf * 4
37
36
n_out = nf * 4
38
- block = n_in == n_out ? BatchNorm (n_out) : Chain (Conv1d (n_in, n_out, 1 ), BatchNorm (n_out))
39
- push! (shortcut, block)
37
+ skip =
38
+ n_in == n_out ? BatchNorm (n_out) :
39
+ Chain (Conv1d (n_in, n_out, 1 ), BatchNorm (n_out))
40
+ push! (shortcut, skip)
40
41
end
41
42
end
42
- return InceptionBlock (residual, depth, inception, shortcut)
43
- end
44
- Flux. @functor InceptionBlock
45
- Flux. trainable (m:: InceptionBlock ) = (m. inception, m. shortcut)
46
43
47
- # Model Output
48
- function (m:: InceptionBlock )(x)
49
- res = x
50
- for d in range (1 , stop = m. depth)
51
- x = m. inception[d](x)
52
- if m. residual && d % 3 == 0
53
- x = Flux. relu (x + m. shortcut[d÷ 3 ](res))
54
- res = x
44
+ blocks = []
45
+ d = 1
46
+
47
+ while d <= depth
48
+ blk = []
49
+ while d <= depth
50
+ push! (blk, inception[d])
51
+ if d % 3 == 0
52
+ d += 1
53
+ break
54
+ end
55
+ d += 1
56
+ end
57
+ if residual && d ÷ 3 <= length (shortcut)
58
+ skp = shortcut[d÷ 3 ]
59
+ push! (blocks, Parallel (+ , Chain (blk... ), skp))
60
+ else
61
+ push! (blocks, Chain (blk... ))
55
62
end
56
63
end
57
- return x
64
+ return Chain (blocks... )
65
+
58
66
end
59
67
60
68
function changedims (X)
0 commit comments