Skip to content

Commit ef979fc

Browse files
committed
working model with residual layers
1 parent a326e36 commit ef979fc

File tree

1 file changed

+34
-26
lines changed

1 file changed

+34
-26
lines changed

FastTimeSeries/src/models/InceptionTime.jl

Lines changed: 34 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
"""
2+
InceptionModule(ni::Int, nf::Int, ks::Int = 40, bottleneck::Bool = true)
3+
4+
TBW
5+
"""
16
function InceptionModule(ni::Int, nf::Int, ks::Int = 40, bottleneck::Bool = true)
27
ks = [ks ÷ (2^i) for i in range(0, stop = 2)]
38
ks = [ks[i] % 2 == 0 ? ks[i] - 1 : ks[i] for i in range(1, stop = 3)] # ensure odd ks
@@ -15,13 +20,6 @@ function InceptionModule(ni::Int, nf::Int, ks::Int = 40, bottleneck::Bool = true
1520
return Chain(Parallel(hcat, convs, maxconvpool), BatchNorm(nf * 4, relu))
1621
end
1722

18-
struct InceptionBlock
19-
residual::Bool
20-
depth::Int
21-
inception::Any
22-
shortcut::Any
23-
end
24-
2523
"""
2624
InceptionBlock(ni::Int, nf::Int = 32, residual::Bool = true, depth::Int = 6)
2725
@@ -30,31 +28,41 @@ TBW
3028
function InceptionBlock(ni::Int, nf::Int = 32, residual::Bool = true, depth::Int = 6)
3129
inception = []
3230
shortcut = []
33-
for d in range(0, stop = depth - 1)
34-
push!(inception, InceptionModule(d == 0 ? ni : nf * 4, nf))
35-
if (residual && d % 3 == 2)
36-
n_in = d == 2 ? ni : nf * 4
31+
32+
for d in range(1, stop = depth)
33+
push!(inception, InceptionModule(d == 1 ? ni : nf * 4, nf))
34+
if residual && d % 3 == 0
35+
n_in = d == 3 ? ni : nf * 4
3736
n_out = nf * 4
38-
block = n_in == n_out ? BatchNorm(n_out) : Chain(Conv1d(n_in, n_out, 1), BatchNorm(n_out))
39-
push!(shortcut, block)
37+
skip =
38+
n_in == n_out ? BatchNorm(n_out) :
39+
Chain(Conv1d(n_in, n_out, 1), BatchNorm(n_out))
40+
push!(shortcut, skip)
4041
end
4142
end
42-
return InceptionBlock(residual, depth, inception, shortcut)
43-
end
44-
Flux.@functor InceptionBlock
45-
Flux.trainable(m::InceptionBlock) = (m.inception, m.shortcut)
4643

47-
# Model Output
48-
function (m::InceptionBlock)(x)
49-
res = x
50-
for d in range(1, stop = m.depth)
51-
x = m.inception[d](x)
52-
if m.residual && d % 3 == 0
53-
x = Flux.relu(x + m.shortcut[d÷3](res))
54-
res = x
44+
blocks = []
45+
d = 1
46+
47+
while d <= depth
48+
blk = []
49+
while d <= depth
50+
push!(blk, inception[d])
51+
if d % 3 == 0
52+
d += 1
53+
break
54+
end
55+
d += 1
56+
end
57+
if residual && d ÷ 3 <= length(shortcut)
58+
skp = shortcut[d÷3]
59+
push!(blocks, Parallel(+, Chain(blk...), skp))
60+
else
61+
push!(blocks, Chain(blk...))
5562
end
5663
end
57-
return x
64+
return Chain(blocks...)
65+
5866
end
5967

6068
function changedims(X)

0 commit comments

Comments
 (0)