Skip to content

Commit 58816e7

Browse files
author
LuizFCDuarte
committed
🎨 Format code according to space guidelines
1 parent fd9d270 commit 58816e7

File tree

9 files changed

+2074
-814
lines changed

9 files changed

+2074
-814
lines changed

src/datasets.jl

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
@enum Datasets begin
2-
AIR_PASSENGERS=1
3-
GDPC1=2
4-
NROU=3
2+
AIR_PASSENGERS = 1
3+
GDPC1 = 2
4+
NROU = 3
55
end
66

77
const AIR_PASSENGERS = instances(Datasets)[1]
@@ -10,9 +10,9 @@ const NROU = instances(Datasets)[3]
1010
export AIR_PASSENGERS, GDPC1, NROU
1111

1212
datasetsPaths = [
13-
joinpath(dirname(@__DIR__()), "datasets", "airpassengers.csv"),
13+
joinpath(dirname(@__DIR__()), "datasets", "airpassengers.csv"),
1414
joinpath(dirname(@__DIR__()), "datasets", "GDPC1.csv"),
15-
joinpath(dirname(@__DIR__()), "datasets", "NROU.csv")
15+
joinpath(dirname(@__DIR__()), "datasets", "NROU.csv"),
1616
]
1717

1818

@@ -71,12 +71,12 @@ julia> airPassengers = loadDataset(airPassengersDf)
7171
7272
```
7373
"""
74-
function loadDataset(df::DataFrame, showLogs::Bool=false)
74+
function loadDataset(df::DataFrame, showLogs::Bool = false)
7575
auxiliarDF = deepcopy(df)
7676
if !(:date in propertynames(auxiliarDF))
7777
showLogs && @info("The DataFrame does not have a column named 'date'.")
7878
showLogs && @info("Creating a date column with the index of the DataFrame")
79-
auxiliarDF[!,:date] = [Date(i) for i=1:size(auxiliarDF,1)]
79+
auxiliarDF[!, :date] = [Date(i) for i = 1:size(auxiliarDF, 1)]
8080
end
8181
y = TimeArray(auxiliarDF, timestamp = :date)
8282
return y
@@ -90,9 +90,18 @@ end
9090
9191
Splits the time series in training and testing sets.
9292
"""
93-
function splitTrainTest(data::TimeArray; trainPercentage::Fl=0.8) where Fl<:AbstractFloat
94-
trainingSetEndIndex = floor(Int, trainPercentage*length(data))
95-
trainingSet = TimeArray(timestamp(data)[1:trainingSetEndIndex], values(data)[1:trainingSetEndIndex])
96-
testingSet = TimeArray(timestamp(data)[trainingSetEndIndex+1:end], values(data)[trainingSetEndIndex+1:end])
93+
function splitTrainTest(
94+
data::TimeArray;
95+
trainPercentage::Fl = 0.8,
96+
) where {Fl<:AbstractFloat}
97+
trainingSetEndIndex = floor(Int, trainPercentage * length(data))
98+
trainingSet = TimeArray(
99+
timestamp(data)[1:trainingSetEndIndex],
100+
values(data)[1:trainingSetEndIndex],
101+
)
102+
testingSet = TimeArray(
103+
timestamp(data)[trainingSetEndIndex+1:end],
104+
values(data)[trainingSetEndIndex+1:end],
105+
)
97106
return trainingSet, testingSet
98-
end
107+
end

src/datetime_utils.jl

Lines changed: 29 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
import Base: copy, deepcopy
2+
13
function Base.copy(y::TimeArray)
2-
return TimeArray(copy(timestamp(y)),copy(values(y)))
4+
return TimeArray(copy(timestamp(y)), copy(values(y)))
35
end
46

57
function Base.deepcopy(y::TimeArray)
6-
return TimeArray(deepcopy(timestamp(y)),deepcopy(values(y)))
8+
return TimeArray(deepcopy(timestamp(y)), deepcopy(values(y)))
79
end
810

911
"""
@@ -22,11 +24,11 @@ An array of DateTime objects.
2224
2325
"""
2426
function buildDatetimes(
25-
startDatetime::T,
26-
granularity::P where P <: Dates.Period,
27-
weekDaysOnly::Bool,
27+
startDatetime::T,
28+
granularity::P where {P<:Dates.Period},
29+
weekDaysOnly::Bool,
2830
datetimesLength::Int,
29-
) where T
31+
) where {T}
3032
if datetimesLength == 0
3133
return DateTime[]
3234
end
@@ -37,7 +39,7 @@ function buildDatetimes(
3739
currentDatetime = startDatetime
3840

3941
# Loop to generate timestamps based on granularity
40-
for _ in 1:datetimesLength
42+
for _ = 1:datetimesLength
4143
if weekDaysOnly && dayofweek(currentDatetime) == 5
4244
currentDatetime += Dates.Day(3)
4345
else
@@ -69,15 +71,15 @@ A tuple `(granularity, frequency, weekdays)` where:
6971
Throws an error if the timestamps do not follow a consistent pattern.
7072
7173
"""
72-
function identifyGranularity(datetimes::Vector{T}) where T
74+
function identifyGranularity(datetimes::Vector{T}) where {T}
7375
# Define base units and periods
7476
baseUnits = [Second(1), Minute(1), Hour(1), Day(1), Week(1)]
7577
basePeriods = [:Second, :Minute, :Hour, :Day, :Week, :Month, :Year]
76-
78+
7779
unitPeriod = nothing
7880
diffBetweenTimestamps = nothing
7981
weekDaysSeries = false
80-
82+
8183
for (i, unit) in enumerate(baseUnits)
8284
differences = diff(datetimes) ./ unit
8385
min_difference = minimum(differences)
@@ -86,13 +88,13 @@ function identifyGranularity(datetimes::Vector{T}) where T
8688
if lessThanOne
8789
break
8890
end
89-
91+
9092
# Check if all elements are equal
9193
regularDistribution = all(differences .== differences[1])
9294
if regularDistribution
9395
unitPeriod = basePeriods[i]
9496
diffBetweenTimestamps = differences[1]
95-
97+
9698
if unit in [Minute(1), Second(1)]
9799
if diffBetweenTimestamps < 60
98100
break
@@ -128,7 +130,7 @@ function identifyGranularity(datetimes::Vector{T}) where T
128130
end
129131
end
130132
end
131-
133+
132134
amplitude = maximum(differences) - min_difference
133135
if amplitude < 1
134136
unitPeriod = basePeriods[i]
@@ -152,10 +154,14 @@ function identifyGranularity(datetimes::Vector{T}) where T
152154
elseif diffBetweenTimestamps % 4 == 0
153155
unitPeriod = :Month
154156
diffBetweenTimestamps = diffBetweenTimestamps / 4
155-
end
157+
end
156158
end
157-
158-
return (granularity=unitPeriod, frequency=diffBetweenTimestamps, weekdays=weekDaysSeries)
159+
160+
return (
161+
granularity = unitPeriod,
162+
frequency = diffBetweenTimestamps,
163+
weekdays = weekDaysSeries,
164+
)
159165
end
160166

161167
"""
@@ -171,7 +177,7 @@ Merge multiple `TimeArray` objects into a single `TimeArray`. The function align
171177
A `TimeArray` object representing the merged time series.
172178
173179
"""
174-
function merge(timeArrayVector::Vector{TimeArray},modelFl::DataType=Float64)
180+
function merge(timeArrayVector::Vector{TimeArray}, modelFl::DataType = Float64)
175181
if length(timeArrayVector) == 0
176182
return TimeArray(DateTime[], [])
177183
end
@@ -192,19 +198,18 @@ function merge(timeArrayVector::Vector{TimeArray},modelFl::DataType=Float64)
192198

193199
newTimeArrayVector = []
194200
for ta in timeArrayVector
195-
newTimeArray = from(ta,initialTimestamp)
196-
newTimeArray = to(newTimeArray,finalTimestamp)
197-
push!(newTimeArrayVector,newTimeArray)
201+
newTimeArray = from(ta, initialTimestamp)
202+
newTimeArray = to(newTimeArray, finalTimestamp)
203+
push!(newTimeArrayVector, newTimeArray)
198204
end
199205

200-
auxiliarDf = DataFrame((:timestamp=>timestamp(newTimeArrayVector[1])))
206+
auxiliarDf = DataFrame((:timestamp => timestamp(newTimeArrayVector[1])))
201207
for ta in newTimeArrayVector
202208
# Add a column with ta colname and values
203209
colname = colnames(ta)[1]
204210
valuesTa::Vector{modelFl} = values(ta)
205-
auxiliarDf[!,colname] = valuesTa
211+
auxiliarDf[!, colname] = valuesTa
206212
end
207213

208-
return TimeArray(auxiliarDf, timestamp=:timestamp)
214+
return TimeArray(auxiliarDf, timestamp = :timestamp)
209215
end
210-

src/exceptions.jl

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,23 @@
11
mutable struct ModelNotFitted <: Exception end
2-
Base.showerror(io::IO, e::ModelNotFitted) = print(io, "The model has not been fitted yet. Please run fit!(model)")
2+
Base.showerror(io::IO, e::ModelNotFitted) =
3+
print(io, "The model has not been fitted yet. Please run fit!(model)")
34

4-
mutable struct MissingMethodImplementation <: Exception
5+
mutable struct MissingMethodImplementation <: Exception
56
method::String
67
end
7-
Base.showerror(io::IO, e::MissingMethodImplementation) = print(io, "The model does not implement the ", e.method, " method.")
8+
Base.showerror(io::IO, e::MissingMethodImplementation) =
9+
print(io, "The model does not implement the ", e.method, " method.")
810

911
mutable struct InconsistentDatePattern <: Exception end
10-
Base.showerror(io::IO, e::InconsistentDatePattern) = print(io, "The timestamps do not follow a consistent pattern.")
12+
Base.showerror(io::IO, e::InconsistentDatePattern) =
13+
print(io, "The timestamps do not follow a consistent pattern.")
1114

1215
mutable struct MissingExogenousData <: Exception end
13-
Base.showerror(io::IO, e::MissingExogenousData) = print(io, "There is no exogenous data to forecast the horizon requested")
16+
Base.showerror(io::IO, e::MissingExogenousData) =
17+
print(io, "There is no exogenous data to forecast the horizon requested")
1418

1519
mutable struct InvalidParametersCombination <: Exception
1620
msg::String
1721
end
18-
Base.showerror(io::IO, e::InvalidParametersCombination) = print(io, "The parameters provided are invalid for the model \n", e.msg)
22+
Base.showerror(io::IO, e::InvalidParametersCombination) =
23+
print(io, "The parameters provided are invalid for the model \n", e.msg)

src/fit.jl

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ Calculate the Akaike Information Criterion (AIC) for a given number of parameter
4545
The AIC value calculated using the formula: AIC = 2*K - 2*loglikeVal.
4646
4747
"""
48-
function aic(K::Int, loglikeVal::Fl) where Fl<:AbstractFloat
49-
return 2*K - 2*loglikeVal
48+
function aic(K::Int, loglikeVal::Fl) where {Fl<:AbstractFloat}
49+
return 2 * K - 2 * loglikeVal
5050
end
5151

5252
"""
@@ -63,8 +63,8 @@ Calculate the corrected Akaike Information Criterion (AICc) for a given number o
6363
The AICc value calculated using the formula: AICc = AIC(K, loglikeVal) + ((2*K*K + 2*K) / (T - K - 1)).
6464
6565
"""
66-
function aicc(T::Int, K::Int, loglikeVal::Fl) where Fl<:AbstractFloat
67-
return aic(K, loglikeVal) + ((2*K*K + 2*K) / (T - K - 1))
66+
function aicc(T::Int, K::Int, loglikeVal::Fl) where {Fl<:AbstractFloat}
67+
return aic(K, loglikeVal) + ((2 * K * K + 2 * K) / (T - K - 1))
6868
end
6969

7070
"""
@@ -81,8 +81,8 @@ Calculate the Bayesian Information Criterion (BIC) for a given number of observa
8181
The BIC value calculated using the formula: BIC = log(T) * K - 2 * loglikeVal.
8282
8383
"""
84-
function bic(T::Int, K::Int, loglikeVal::Fl) where Fl<:AbstractFloat
85-
return log(T)*K - 2*loglikeVal
84+
function bic(T::Int, K::Int, loglikeVal::Fl) where {Fl<:AbstractFloat}
85+
return log(T) * K - 2 * loglikeVal
8686
end
8787

8888
"""
@@ -101,15 +101,16 @@ The AIC value calculated using the number of parameters and log-likelihood value
101101
- Throws a `MissingMethodImplementation` if the `getHyperparametersNumber` method is not implemented for the given model type.
102102
103103
"""
104-
function aic(model::SarimaxModel; offset::Fl=0.0) where Fl<:AbstractFloat
105-
!hasHyperparametersMethods(typeof(model)) && throw(MissingMethodImplementation("getHyperparametersNumber"))
104+
function aic(model::SarimaxModel; offset::Fl = 0.0) where {Fl<:AbstractFloat}
105+
!hasHyperparametersMethods(typeof(model)) &&
106+
throw(MissingMethodImplementation("getHyperparametersNumber"))
106107
K = Sarimax.getHyperparametersNumber(model)
107108
# T = length(model.ϵ)
108109
# return aic(K, loglike(model))
109110
# offset = -2 * loglike(model) - length(model.y) * log(model.σ²)
110111
# return offset + T * log(model.σ²) + 2*K
111-
T = length(model.y) - model.d - model.D * model.seasonality
112-
return 2*K + T * log(model.σ²) + offset
112+
T = length(model.y) - model.d - model.D * model.seasonality
113+
return 2 * K + T * log(model.σ²) + offset
113114
end
114115

115116
"""
@@ -128,13 +129,14 @@ The AICc value calculated using the number of parameters, sample size, and log-l
128129
- Throws a `MissingMethodImplementation` if the `getHyperparametersNumber` method is not implemented for the given model type.
129130
130131
"""
131-
function aicc(model::SarimaxModel; offset::Fl=0.0) where Fl<:AbstractFloat
132-
!hasHyperparametersMethods(typeof(model)) && throw(MissingMethodImplementation("getHyperparametersNumber"))
132+
function aicc(model::SarimaxModel; offset::Fl = 0.0) where {Fl<:AbstractFloat}
133+
!hasHyperparametersMethods(typeof(model)) &&
134+
throw(MissingMethodImplementation("getHyperparametersNumber"))
133135
K = getHyperparametersNumber(model)
134136
# T = length(model.ϵ)
135137
# return aicc(T, K, loglike(model))
136-
T = length(model.y) - model.d - model.D * model.seasonality
137-
return aic(model; offset=offset) + ((2*K*K + 2*K) / (T - K - 1))
138+
T = length(model.y) - model.d - model.D * model.seasonality
139+
return aic(model; offset = offset) + ((2 * K * K + 2 * K) / (T - K - 1))
138140
end
139141

140142
"""
@@ -153,11 +155,12 @@ The BIC value calculated using the number of parameters, sample size, and log-li
153155
- Throws a `MissingMethodImplementation` if the `getHyperparametersNumber` method is not implemented for the given model type.
154156
155157
"""
156-
function bic(model::SarimaxModel;offset::Fl=0.0) where Fl<:AbstractFloat
157-
!hasHyperparametersMethods(typeof(model)) && throw(MissingMethodImplementation("getHyperparametersNumber"))
158+
function bic(model::SarimaxModel; offset::Fl = 0.0) where {Fl<:AbstractFloat}
159+
!hasHyperparametersMethods(typeof(model)) &&
160+
throw(MissingMethodImplementation("getHyperparametersNumber"))
158161
K = getHyperparametersNumber(model)
159162
# T = length(model.ϵ)
160163
# return bic(T, K, loglike(model))
161164
T = length(model.y) - model.d - model.D * model.seasonality
162-
return aic(model; offset=offset) + K *(log(T) - 2)
163-
end
165+
return aic(model; offset = offset) + K * (log(T) - 2)
166+
end

0 commit comments

Comments
 (0)