Skip to content

Commit 52a1932

Browse files
committed
added RNN Model
1 parent 71b8813 commit 52a1932

File tree

9 files changed

+136
-0
lines changed

9 files changed

+136
-0
lines changed

FastTimeSeries/Project.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,24 @@ authors = ["FluxML Community"]
44
version = "0.1.0"
55

66
[deps]
7+
DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
78
FastAI = "5d0beca9-ade8-49ae-ad0b-a3cf890e669f"
89
FilePathsBase = "48062228-2e41-5def-b9a4-89aafe57970f"
10+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
911
InlineTest = "bd334432-b1e7-49c7-a2dc-dd9149e4ebd6"
1012
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
1113
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1214
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1315
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
16+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1417

1518
[compat]
19+
DataDeps = "0.7"
1620
FastAI = "0.5"
1721
FilePathsBase = "0.9"
22+
Flux = "0.12, 0.13"
1823
InlineTest = "0.2"
1924
MLUtils = "0.2"
2025
UnicodePlots = "2, 3"
2126
julia = "1.6"
27+
Zygote = "0.6"

FastTimeSeries/src/FastTimeSeries.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ include("blocks/timeseriesrow.jl")
3434
# Encodings
3535
include("encodings/tspreprocessing.jl")
3636

37+
# Models
38+
include("models/Models.jl")
39+
include("models.jl")
40+
3741
include("container.jl")
3842
include("recipes.jl")
3943

FastTimeSeries/src/models.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
2+
"""
3+
blockmodel(inblock::TimeSeriesRow, outblock::OneHotTensor{0}, backbone)
4+
5+
Construct a model for time-series classification.
6+
"""
7+
function blockmodel(inblock::TimeSeriesRow,
8+
outblock::OneHotTensor{0},
9+
backbone)
10+
data = rand(Float32, inblock.nfeatures, 32, inblock.obslength)
11+
# data = [rand(Float32, inblock.nfeatures, 32) for _ ∈ 1:inblock.obslength]
12+
output = backbone(data)
13+
outs = size(output)[1]
14+
return Models.RNNModel(backbone, outsize = length(outblock.classes), recout = outs)
15+
end
16+
17+
"""
18+
blockbackbone(inblock::TimeSeriesRow)
19+
20+
Construct a recurrent backbone
21+
"""
22+
function blockbackbone(inblock::TimeSeriesRow)
23+
Models.StackedLSTM(inblock.nfeatures, 16, 10, 2);
24+
end
25+
26+
# ## Tests
27+
28+
@testset "blockbackbone" begin @test_nowarn FastAI.blockbackbone(TimeSeriesRow(1,140)) end

FastTimeSeries/src/models/Models.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
module Models
2+
3+
using ..FastAI
4+
5+
using Flux
6+
using Zygote
7+
using DataDeps
8+
using InlineTest
9+
10+
include("StackedLSTM.jl")
11+
include("RNN.jl")
12+
13+
export StackedLSTM, RNNModel
14+
15+
end

FastTimeSeries/src/models/RNN.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
function tabular2rnn(X::AbstractArray{Float32, 3})
2+
X = permutedims(X, (1, 3, 2))
3+
return X
4+
end
5+
6+
struct RNNModel{A, B}
7+
recbackbone::A
8+
finalclassifier::B
9+
end
10+
11+
"""
12+
RNNModel(recbackbonem, outsize, recout[; kwargs...])
13+
14+
Creates a RNN model from the recurrent 'recbackbone' architecture. The output from this backbone
15+
is passed through a dropout layer before a 'finalclassifier' block.
16+
17+
## Keyword arguments.
18+
19+
- `outsize`: The output size of the final classifier block. For single classification tasks,
20+
this would be the number of classes.
21+
- `recout`: The output size of the `recbackbone` architecture.
22+
- `dropout_rate`: Dropout probability for the dropout layer.
23+
"""
24+
25+
function RNNModel(recbackbone;
26+
outsize,
27+
recout,
28+
kwargs...)
29+
return RNNModel{}(recbackbone, Dense(recout, outsize))
30+
end
31+
32+
function (m::RNNModel)(X)
33+
X = tabular2rnn(X)
34+
Flux.reset!(m.recbackbone)
35+
X = m.recbackbone(X)[:, :, end]
36+
return m.finalclassifier(X)
37+
end
38+
39+
Flux.@functor RNNModel
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
"""
2+
StackedLSTM(in, out, hiddensize, layers)
3+
4+
Stacked LSTM network. Feeds the data through a chain of LSTM layers, where the hidden state
5+
of the previous layer gets fed to the next one. The first layer corresponds to
6+
`LSTM(in, hiddensize)`, the hidden layers to `LSTM(hiddensize, hiddensize)`, and the final
7+
layer to `LSTM(hiddensize, out)`. Takes the keyword argument `init` for the initialization
8+
of the layers.
9+
10+
"""
11+
function StackedLSTM(in::Int, out::Integer, hiddensize::Integer, layers::Integer;
12+
init=Flux.glorot_uniform)
13+
if layers == 1
14+
chain = Chain(LSTM(in, out; init=init))
15+
elseif layers == 2
16+
chain = Chain(LSTM(in, hiddensize; init=init),
17+
LSTM(hiddensize, out; init=init))
18+
else
19+
chain_vec = [LSTM(in, hiddensize; init=init)]
20+
for i = 1:layers - 2
21+
push!(chain_vec, LSTM(hiddensize, hiddensize; init=init))
22+
end
23+
chain = Chain(chain_vec..., LSTM(hiddensize, out; init=init))
24+
end
25+
return chain
26+
end

FastTimeSeries/src/models/layers.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
"""
2+
GAP1d(output_size)
3+
4+
Create a Global Adaptive Pooling + Flatten layer.
5+
"""
6+
function GAP1d(output_size::Int)
7+
gap = AdaptiveMeanPool((output_size,))
8+
Chain(gap, Flux.flatten)
9+
end

FastTimeSeries/test/Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
[deps]
2+
FastAI = "5d0beca9-ade8-49ae-ad0b-a3cf890e669f"
3+
InlineTest = "bd334432-b1e7-49c7-a2dc-dd9149e4ebd6"
4+
ReTest = "e0db7c4e-2690-44b9-bad6-7687da720f89"

FastTimeSeries/test/runtests.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
using FastAI, FastTimeSeries, ReTest
2+
3+
ENV["DATADEPS_ALWAYS_ACCEPT"] = "true"
4+
5+
FastTimeSeries.runtests([ReTest.fail, ReTest.not(ReTest.pass)])

0 commit comments

Comments
 (0)