Skip to content

Commit c7a5639

Browse files
committed
working model for regression
1 parent 07b16dc commit c7a5639

File tree

7 files changed

+24
-6
lines changed

7 files changed

+24
-6
lines changed

FastTimeSeries/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ authors = ["FluxML Community"]
44
version = "0.1.0"
55

66
[deps]
7+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
78
DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
89
FastAI = "5d0beca9-ade8-49ae-ad0b-a3cf890e669f"
910
FilePathsBase = "48062228-2e41-5def-b9a4-89aafe57970f"

FastTimeSeries/src/FastTimeSeries.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ include("blocks/timeseriesrow.jl")
3333

3434
# Encodings
3535
include("encodings/tspreprocessing.jl")
36+
include("encodings/continuouspreprocessing.jl")
3637

3738
# Models
3839
include("models/Models.jl")
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
struct ContinuousPreprocessing <: Encoding
2+
numlabels::Int
3+
end
4+
5+
ContinuousPreprocessing() = ContinuousPreprocessing(1)
6+
7+
decodedblock(c::ContinuousPreprocessing, block::AbstractArray) = Continuous(c.numlabels)
8+
9+
function encode(::ContinuousPreprocessing, _, block::Continuous, obs)
10+
return [obs]
11+
end
12+
13+
function decode(::ContinuousPreprocessing, _, block::AbstractArray, obs)
14+
return obs[1]
15+
end

FastTimeSeries/src/models.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ Construct a model for time-series classification.
77
function blockmodel(inblock::TimeSeriesRow,
88
outblock::OneHotTensor{0},
99
backbone)
10-
#TODO: Use Flux.outputsize here.
11-
data = rand(Float32, inblock.nfeatures, 32, inblock.obslength)
10+
data = zeros(Float32, inblock.nfeatures, 1, 1)
1211
output = backbone(data)
12+
Flux.reset!(backbone)
1313
return Models.RNNModel(backbone, outsize = length(outblock.classes), recout = size(output, 1))
1414
end
1515

FastTimeSeries/src/models/Models.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using Flux
66
using Zygote
77
using DataDeps
88
using InlineTest
9+
using ChainRulesCore
910

1011
# include("StackedLSTM.jl")
1112
include("layers.jl")

FastTimeSeries/src/models/RNN.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,9 @@ end
2828

2929
function (m::RNNModel)(X)
3030
X = tabular2rnn(X)
31-
Flux.reset!(m.recbackbone)
32-
# ChainRulesCore.ignore_derivatives() do
33-
# Flux.reset!(m.recbackbone)
34-
# end
31+
ChainRulesCore.ignore_derivatives() do
32+
Flux.reset!(m.recbackbone)
33+
end
3534
X = m.recbackbone(X)[:, :, end]
3635
return m.finalclassifier(X)
3736
end

FastTimeSeries/src/tasks/regression.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ function TSRegression(blocks::Tuple{<:TimeSeriesRow, <:Continuous}, data)
88
return SupervisedTask(
99
blocks,
1010
(
11+
ContinuousPreprocessing(),
1112
setup(TSPreprocessing, blocks[1], data[1].table),
1213
),
1314
ŷblock = blocks[2]

0 commit comments

Comments
 (0)