Skip to content

Commit 2856227

Browse files
guilhermebodinraphaelsaavedra
authored andcommitted
Fix initial parameters logic (#41)
* re-add distribution links * missed link for Weibull * cleaner solution
1 parent 5930d2a commit 2856227

File tree

11 files changed

+76
-20
lines changed

11 files changed

+76
-20
lines changed

src/distributions/beta.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,12 @@ function log_likelihood(::Type{Beta}, y::Vector{T}, param::Vector{Vector{T}}, n:
3434
end
3535

3636
# Links
37+
function link(::Type{Beta}, param::Vector{T}) where T
38+
return [
39+
link(LogLink, param[1], zero(T));
40+
link(LogLink, param[2], zero(T))
41+
]
42+
end
3743
function unlink(::Type{Beta}, param_tilde::Vector{T}) where T
3844
return [
3945
unlink(LogLink, param_tilde[1], zero(T));

src/distributions/common_interface.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ end
1717
function fisher_information(D::Type{<:Distribution}, param::Vector{T}) where T
1818
return error("fisher_information not implemented for $D distribution")
1919
end
20+
function link(D::Type{<:Distribution}, param::Vector{T}) where T
21+
return error("link not implemented for $D distribution")
22+
end
2023
function unlink(D::Type{<:Distribution}, param_tilde::Vector{T}) where T
2124
return error("unlink not implemented for $D distribution")
2225
end

src/distributions/gamma.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,12 @@ function log_likelihood(::Type{Gamma}, y::Vector{T}, param::Vector{Vector{T}}, n
3131
end
3232

3333
# Links
34+
function link(::Type{Gamma}, param::Vector{T}) where T
35+
return [
36+
link(LogLink, param[1], zero(T));
37+
link(LogLink, param[2], zero(T))
38+
]
39+
end
3440
function unlink(::Type{Gamma}, param_tilde::Vector{T}) where T
3541
return [
3642
unlink(LogLink, param_tilde[1], zero(T));

src/distributions/lognormal.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,12 @@ function log_likelihood(::Type{LogNormal}, y::Vector{T}, param::Vector{Vector{T}
2828
end
2929

3030
# Links
31+
function link(::Type{LogNormal}, param::Vector{T}) where T
32+
return [
33+
link(IdentityLink, param[1]);
34+
link(LogLink, param[2], zero(T))
35+
]
36+
end
3137
function unlink(::Type{LogNormal}, param_tilde::Vector{T}) where T
3238
return [
3339
unlink(IdentityLink, param_tilde[1]);

src/distributions/normal.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,12 @@ function log_likelihood(::Type{Normal}, y::Vector{T}, param::Vector{Vector{T}},
3131
end
3232

3333
# Links
34+
function link(::Type{Normal}, param::Vector{T}) where T
35+
return [
36+
link(IdentityLink, param[1]);
37+
link(LogLink, param[2], zero(T))
38+
]
39+
end
3440
function unlink(::Type{Normal}, param_tilde::Vector{T}) where T
3541
return [
3642
unlink(IdentityLink, param_tilde[1]);

src/distributions/poisson.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ function log_likelihood(::Type{Poisson}, y::Vector{Int}, param::Vector{Vector{T}
2424
end
2525

2626
# Links
27+
link(::Type{Poisson}, param::Vector{T}) where T = link.(LogLink, param, zero(T))
2728
unlink(::Type{Poisson}, param_tilde::Vector{T}) where T = unlink.(LogLink, param_tilde, zero(T))
2829
jacobian_link(::Type{Poisson}, param_tilde::Vector{T}) where T = Diagonal(jacobian_link.(LogLink, param_tilde, zero(T)))
2930

src/distributions/weibull.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,12 @@ function log_likelihood(::Type{Weibull}, y::Vector{T}, param::Vector{Vector{T}},
2828
end
2929

3030
# Links
31+
function link(::Type{Weibull}, param::Vector{T}) where T
32+
return [
33+
link(LogLink, param[1], zero(T));
34+
link(LogLink, param[2], zero(T))
35+
]
36+
end
3137
function unlink(::Type{Weibull}, param_tilde::Vector{T}) where T
3238
return [
3339
unlink(LogLink, param_tilde[1], zero(T));

src/gas/initial_params.jl

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,27 @@
1-
export stationary_initial_params, dynamic_initial_params
1+
export stationary_initial_params_tilde, stationary_initial_params, dynamic_initial_params
22

33
"""
4-
stationary_initial_params(gas::GAS)
4+
stationary_initial_param_tilde(gas::GAS{D, T}) where {D, T}
55
"""
6-
function stationary_initial_params(gas::GAS)
6+
function stationary_initial_params_tilde(gas::GAS{D, T}) where {D, T}
77
biggest_lag = number_of_lags(gas)
8-
initial_params = Vector{Vector{Float64}}(undef, biggest_lag)
8+
initial_params_tilde = Vector{Vector{T}}(undef, biggest_lag)
99
for i in 1:biggest_lag
10-
initial_params[i] = gas.ω./diag(I - gas.B[1])
10+
initial_params_tilde[i] = gas.ω./diag(I - gas.B[1])
11+
end
12+
return initial_params_tilde
13+
end
14+
15+
"""
16+
stationary_initial_params(gas::GAS{D, T}) where {D, T}
17+
"""
18+
function stationary_initial_params(gas::GAS{D, T}) where {D, T}
19+
biggest_lag = number_of_lags(gas)
20+
initial_params_tilde = Vector{Vector{T}}(undef, biggest_lag)
21+
initial_params = Vector{Vector{T}}(undef, biggest_lag)
22+
for i in 1:biggest_lag
23+
initial_params_tilde[i] = gas.ω./diag(I - gas.B[1])
24+
initial_params[i] = unlink(D, initial_params_tilde[i])
1125
end
1226
return initial_params
1327
end

src/gas/simulate.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@ function simulate(gas::GAS{D, T}, n::Int, s::Int) where {D, T}
1111
end
1212

1313
function simulate(gas::GAS{D, T}, n::Int) where {D, T}
14-
initial_param_tilde = stationary_initial_params(gas)
15-
return simulate(gas, n, initial_param_tilde)
14+
initial_params = stationary_initial_params(gas)
15+
return simulate(gas, n, initial_params)
1616
end
1717

18-
function simulate(gas::GAS{D, T}, n::Int, initial_param_tilde::Vector{Vector{T}}) where {D, T}
18+
function simulate(gas::GAS{D, T}, n::Int, initial_params::Vector{Vector{T}}) where {D, T}
1919
# Allocations
2020
serie = zeros(n)
2121
param = Vector{Vector{T}}(undef, n)
@@ -26,8 +26,8 @@ function simulate(gas::GAS{D, T}, n::Int, initial_param_tilde::Vector{Vector{T}}
2626

2727
# initial_values
2828
for i in 1:biggest_lag
29-
param_tilde[i] = initial_param_tilde[i]
30-
param[i] = unlink(D, initial_param_tilde[i])
29+
param[i] = initial_params[i]
30+
param_tilde[i] = link(D, param[i])
3131
# Sample
3232
updated_dist = update_dist(D, unlink(D, param_tilde[i]))
3333
serie[i] = sample_observation(updated_dist)

src/gas/univariate_score_driven_recursion.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,25 +5,25 @@ score_driven_recursion(sd_model::SDM, observations::Vector{T}) where T
55
66
start with the stationary params for a
77
"""
8-
function score_driven_recursion(gas::GAS, observations::Vector{T}) where T
9-
initial_param_tilde = stationary_initial_params(gas)
10-
return score_driven_recursion(gas, observations, initial_param_tilde)
8+
function score_driven_recursion(gas::GAS{D, T}, observations::Vector{T}) where {D, T}
9+
initial_params = stationary_initial_params(gas)
10+
return score_driven_recursion(gas, observations, initial_params)
1111
end
1212

13-
function score_driven_recursion(gas::GAS{D, T}, observations::Vector{T}, initial_param_tilde::Vector{Vector{T}}) where {D, T}
13+
function score_driven_recursion(gas::GAS{D, T}, observations::Vector{T}, initial_param::Vector{Vector{T}}) where {D, T}
1414
# Allocations
1515
n = length(observations)
1616
param = Vector{Vector{T}}(undef, n + 1)
1717
param_tilde = Vector{Vector{T}}(undef, n + 1)
1818
scores_tilde = Vector{Vector{T}}(undef, n)
1919

2020
# Query the biggest lag
21-
biggest_lag = length(initial_param_tilde)
21+
biggest_lag = number_of_lags(gas)
2222

2323
# initial_values
2424
for i in 1:biggest_lag
25-
param_tilde[i] = initial_param_tilde[i]
26-
param[i] = unlink(D, initial_param_tilde[i])
25+
param[i] = initial_param[i]
26+
param_tilde[i] = link(D, param[i])
2727
scores_tilde[i] = score_tilde(observations[i], D, param[i], param_tilde[i], gas.scaling)
2828
end
2929

0 commit comments

Comments
 (0)