Skip to content

Commit 3d70b50

Browse files
authored
Merge pull request #253 from codeboy5/ts-models
Added Model for Time Series Classification
2 parents 6b9e1f1 + 54e8ebf commit 3d70b50

File tree

14 files changed

+952
-78
lines changed

14 files changed

+952
-78
lines changed

FastTimeSeries/Project.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,24 @@ authors = ["FluxML Community"]
44
version = "0.1.0"
55

66
[deps]
7+
DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
78
FastAI = "5d0beca9-ade8-49ae-ad0b-a3cf890e669f"
89
FilePathsBase = "48062228-2e41-5def-b9a4-89aafe57970f"
10+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
911
InlineTest = "bd334432-b1e7-49c7-a2dc-dd9149e4ebd6"
1012
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
1113
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1214
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1315
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
16+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1417

1518
[compat]
19+
DataDeps = "0.7"
1620
FastAI = "0.5"
1721
FilePathsBase = "0.9"
22+
Flux = "0.12, 0.13"
1823
InlineTest = "0.2"
1924
MLUtils = "0.2"
2025
UnicodePlots = "2, 3"
2126
julia = "1.6"
27+
Zygote = "0.6"

FastTimeSeries/src/FastTimeSeries.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,16 @@ include("blocks/timeseriesrow.jl")
3434
# Encodings
3535
include("encodings/tspreprocessing.jl")
3636

37+
# Models
38+
include("models/Models.jl")
39+
include("models.jl")
40+
3741
include("container.jl")
3842
include("recipes.jl")
3943

4044
const _tasks = Dict{String, Any}()
4145
include("tasks/classification.jl")
46+
include("tasks/regression.jl")
4247

4348
function __init__()
4449
_registerrecipes()
@@ -50,5 +55,5 @@ function __init__()
5055
end
5156

5257
export
53-
TimeSeriesRow, TSClassificationSingle, TSPreprocessing
58+
TimeSeriesRow, TSClassificationSingle, TSPreprocessing, TSRegression
5459
end

FastTimeSeries/src/container.jl

Lines changed: 117 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ function _ts2df(
5454

5555
timestamps = false
5656
class_labels = false
57+
target_labels = true
5758

5859
open(full_file_path_and_name, "r") do file
5960
for ln in eachline(file)
@@ -109,6 +110,21 @@ function _ts2df(
109110

110111
series_length = parse(Int, tokens[2])
111112

113+
elseif startswith(ln, "@dimension")
114+
# Check that the associated value is valid
115+
tokens = split(ln, " ")
116+
117+
num_dimensions = parse(Int, tokens[2])
118+
119+
elseif startswith(ln, "@targetlabel")
120+
tokens = split(ln, " ")
121+
122+
if tokens[2] == "true"
123+
target_labels = true
124+
else
125+
target_labels = false
126+
end
127+
112128
elseif startswith(ln, "@classlabel")
113129
# Check that the associated value is valid
114130
tokens = split(ln, " ")
@@ -150,7 +166,106 @@ function _ts2df(
150166
# Check if we dealing with data that has timestamps
151167

152168
if timestamps
153-
#! Need To Add Code.
169+
170+
has_another_value = false
171+
has_another_dimension = false
172+
173+
timestamps_for_dimension = []
174+
values_for_dimension = []
175+
176+
line_len = length(ln)
177+
char_num = 1
178+
num_this_dimension = 1
179+
arr = Array{Float32, 2}(undef, num_dimensions, series_length)
180+
181+
while char_num <= line_len
182+
183+
# Move through any spaces.
184+
while char_num <= line_len && isspace(ln[char_num])
185+
char_num += 1
186+
end
187+
188+
if char_num <= line_len
189+
190+
# Check if we have reached a class label
191+
if ln[char_num] != '(' && target_labels
192+
193+
class_val = strip(ln[char_num:end], ' ')
194+
195+
push!(class_val_list, parse(Float32, class_val))
196+
push!(instance_list, arr)
197+
198+
char_num = line_len
199+
200+
has_another_value = false
201+
has_another_dimension = false
202+
203+
timestamps_for_dimension = []
204+
values_for_dimension = []
205+
206+
char_num += 1
207+
num_this_dimension = 1
208+
arr = Array{Float32, 2}(undef, num_dimensions, series_length)
209+
210+
else
211+
212+
char_num += 1
213+
tuple_data = ""
214+
215+
while (char_num <= line_len && ln[char_num] != ')')
216+
tuple_data *= ln[char_num]
217+
char_num += 1
218+
end
219+
220+
char_num += 1
221+
222+
while char_num <= line_len && isspace(ln[char_num])
223+
char_num += 1
224+
end
225+
226+
# Check if there is another value or dimension to process after this tuple.
227+
if char_num > line_len
228+
has_another_value = false
229+
has_another_dimension = false
230+
elseif ln[char_num] == ','
231+
has_another_value = true
232+
has_another_dimension = false
233+
elseif ln[char_num] == ':'
234+
has_another_value = false
235+
has_another_dimension = true
236+
end
237+
238+
char_num += 1
239+
240+
last_comma_index = findlast(",", tuple_data)
241+
242+
if !isnothing(last_comma_index)
243+
last_comma_index = last_comma_index[1]
244+
end
245+
246+
value = tuple_data[last_comma_index+1:end]
247+
value = parse(Float32, value)
248+
249+
timestamp = tuple_data[1:last_comma_index-1]
250+
251+
push!(values_for_dimension, value)
252+
253+
if !has_another_value
254+
255+
arr[num_this_dimension, 1:end] = values_for_dimension
256+
257+
values_for_dimension = []
258+
259+
num_this_dimension += 1
260+
end
261+
262+
end
263+
264+
end
265+
266+
end
267+
268+
154269
else
155270
dimensions = split(ln, ':')
156271

@@ -196,13 +311,9 @@ function _ts2df(
196311
data_series = split(dimension, ',')
197312
data_series = [parse(Float32, i) for i in data_series]
198313
arr[dim, 1:end] = data_series
199-
# println(data_series)
200-
# data_series = [parse(Float32, i) for i in data_series]
201-
# push!(instance_list[dim], data_series)
202314
else
203315
tmp = Array{Float32, 1}(undef, 100)
204316
arr[dim, 1:end] = tmp
205-
# push!(instance_list[dim], [])
206317
end
207318
end
208319

@@ -229,7 +340,7 @@ function _ts2df(
229340
end
230341

231342
# Check if we should return any associated class labels separately
232-
if class_labels
343+
if class_labels || target_labels
233344
return data, class_val_list
234345
else
235346
return data

FastTimeSeries/src/models.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
2+
"""
    blockmodel(inblock::TimeSeriesRow, outblock::OneHotTensor{0}, backbone)

Construct a model for time-series classification from `backbone`.

A dummy batch shaped like `inblock` is pushed through `backbone` once to
discover the size of the features it produces; that size is then used to
build the final classifier layer of the returned `Models.RNNModel`.
"""
function blockmodel(inblock::TimeSeriesRow,
                    outblock::OneHotTensor{0},
                    backbone)
    # Probe the backbone with a dummy (features, batch, time) array to infer
    # its output feature size. 32 is an arbitrary probe batch size — assumes
    # the backbone's feature dimension is batch-independent (TODO confirm).
    data = rand(Float32, inblock.nfeatures, 32, inblock.obslength)
    output = backbone(data)
    return Models.RNNModel(backbone, outsize = length(outblock.classes), recout = size(output, 1))
end
15+
16+
"""
    blockbackbone(inblock::TimeSeriesRow; outsize = 16, hiddensize = 10, layers = 2)

Construct a recurrent backbone: a stacked LSTM reading `inblock.nfeatures`
input features.

## Keyword arguments

- `outsize`: output feature size of the LSTM stack (default 16).
- `hiddensize`: hidden state size of the intermediate layers (default 10).
- `layers`: number of LSTM layers (default 2).

Defaults reproduce the previously hard-coded `StackedLSTM(nfeatures, 16, 10, 2)`.
"""
function blockbackbone(inblock::TimeSeriesRow; outsize = 16, hiddensize = 10, layers = 2)
    return Models.StackedLSTM(inblock.nfeatures, outsize, hiddensize, layers)
end
24+
25+
# ## Tests

# Smoke test: building a backbone for a 1-feature, length-140 row must not warn.
@testset "blockbackbone" begin @test_nowarn FastAI.blockbackbone(TimeSeriesRow(1,140)) end

FastTimeSeries/src/models/Models.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
"""
Submodule collecting the time-series model architectures used by
FastTimeSeries: `StackedLSTM` (a stacked recurrent backbone) and
`RNNModel` (backbone + classifier head).
"""
module Models

using ..FastAI

using Flux
using Zygote
using DataDeps
using InlineTest

# Order matters only for readability: RNN.jl's RNNModel is typically built
# around a StackedLSTM backbone.
include("StackedLSTM.jl")
include("RNN.jl")

export StackedLSTM, RNNModel

end

FastTimeSeries/src/models/RNN.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
"""
    tabular2rnn(X)

Swap the second and third dimensions of the 3-D array `X`, i.e. an array of
size `(d1, d2, d3)` becomes one of size `(d1, d3, d2)`. Used to reorder a
batch before it is fed to the recurrent backbone.
"""
tabular2rnn(X::AbstractArray{Float32, 3}) = permutedims(X, (1, 3, 2))
5+
6+
# Container pairing a recurrent feature extractor with an output head.
# The forward pass (defined on instances below) resets the backbone, runs the
# sequence through it, keeps the last step, and applies the classifier.
struct RNNModel{A, B}
    recbackbone::A       # recurrent backbone, e.g. a Chain of LSTM layers
    finalclassifier::B   # output head mapping backbone features to outputs
end
10+
11+
"""
    RNNModel(recbackbone; outsize, recout[, dropout_rate])

Creates an RNN model from the recurrent `recbackbone` architecture. The output from this backbone
is passed through a dropout layer before a `finalclassifier` block.

## Keyword arguments.

- `outsize`: The output size of the final classifier block. For single classification tasks,
    this would be the number of classes.
- `recout`: The output size of the `recbackbone` architecture.
- `dropout_rate`: Dropout probability for the dropout layer. Defaults to `0.0` (no dropout).
"""
function RNNModel(recbackbone;
                  outsize,
                  recout,
                  dropout_rate = 0.0)
    # The docstring promised a dropout layer, but the original implementation
    # accepted (and silently discarded) arbitrary keyword arguments and never
    # created one. Wire the dropout in explicitly; the 0.0 default keeps the
    # previous behavior, and unknown keywords now raise instead of being ignored.
    return RNNModel(recbackbone, Chain(Dropout(dropout_rate), Dense(recout, outsize)))
end
31+
32+
"""
    (m::RNNModel)(X)

Forward pass: reorder the batch with `tabular2rnn`, reset the recurrent
backbone's hidden state, run the full sequence through it, keep only the
final step's output, and apply the classifier head to it.
"""
function (m::RNNModel)(X)
    seq = tabular2rnn(X)
    # Clear any hidden state left over from a previous batch.
    Flux.reset!(m.recbackbone)
    features = m.recbackbone(seq)[:, :, end]
    return m.finalclassifier(features)
end
38+
39+
# Register RNNModel with Flux so `Flux.params`, `gpu`/`cpu`, etc. recurse
# into `recbackbone` and `finalclassifier`.
Flux.@functor RNNModel
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
"""
    StackedLSTM(in, out, hiddensize, layers)

Stacked LSTM network. Feeds the data through a chain of LSTM layers, where the hidden state
of the previous layer gets fed to the next one. The first layer corresponds to
`LSTM(in, hiddensize)`, the hidden layers to `LSTM(hiddensize, hiddensize)`, and the final
layer to `LSTM(hiddensize, out)`. Takes the keyword argument `init` for the initialization
of the layers.
"""
function StackedLSTM(in::Integer, out::Integer, hiddensize::Integer, layers::Integer;
                     init = Flux.glorot_uniform)
    # Degenerate case: a single layer maps straight from `in` to `out` and
    # `hiddensize` is unused.
    if layers == 1
        return Chain(LSTM(in, out; init = init))
    end
    # General case (covers layers == 2 as well): first layer in → hiddensize,
    # layers-2 hidden layers, final layer hiddensize → out.
    lstms = [LSTM(in, hiddensize; init = init)]
    for _ in 1:(layers - 2)
        push!(lstms, LSTM(hiddensize, hiddensize; init = init))
    end
    push!(lstms, LSTM(hiddensize, out; init = init))
    return Chain(lstms...)
end

FastTimeSeries/src/models/layers.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
"""
    GAP1d(output_size)

Create a Global Adaptive Pooling + Flatten layer: adaptively mean-pools the
temporal dimension down to `output_size`, then flattens the result.
"""
function GAP1d(output_size::Int)
    pool = AdaptiveMeanPool((output_size,))
    return Chain(pool, Flux.flatten)
end

0 commit comments

Comments
 (0)