Skip to content

Commit ef35820

Browse files
author
Carlos Parada
authored
Remove LoopVectorization.jl (#47)
* Remove LoopVectorization.jl * Remove FFTW dep * Accidentally added PSIS.jl * Remove FFTW * Fix method overwrite * fix method overwrite * Remove StatsFuns dependency
1 parent c59fe08 commit ef35820

File tree

12 files changed

+91
-106
lines changed

12 files changed

+91
-106
lines changed

Project.toml

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,37 +5,26 @@ version = "0.6.6"
55

66
[deps]
77
AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5"
8-
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
9-
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
10-
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
118
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
129
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
13-
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
10+
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
1411
MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0"
1512
NamedDims = "356022a1-0364-5f58-8944-0da4b18d706f"
16-
Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
1713
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
1814
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1915
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2016
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
2117
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2218
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
23-
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
2419
Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
2520

2621
[compat]
2722
AxisKeys = "0.1.18"
28-
Distributions = "0.25.10"
29-
DocStringExtensions = "0.8"
30-
FFTW = "1.4.3"
31-
LoopVectorization = "0.12.37"
3223
MCMCDiagnosticTools = "0.1.0"
3324
NamedDims = "0.2.35"
34-
Polyester = "0.3.4, 0.4, 0.5"
3525
PrettyTables = "1.1.0"
3626
Requires = "1.1.3"
3727
StatsBase = "0.33.10"
38-
StatsFuns = "0.9.9"
3928
Tullio = "0.3.0"
4029
julia = "1.6"
4130

src/ESS.jl

Lines changed: 16 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,30 @@
1-
using FFTW
2-
using LoopVectorization
31
using MCMCDiagnosticTools
4-
52
using Tullio
63

74
export relative_eff, psis_ess, sup_ess
85

96
"""
10-
relative_eff(
11-
sample::AbstractArray{Real, 3};
12-
method=MCMCDiagnosticTools.FFTESSMethod()
13-
)
7+
relative_eff(sample::AbstractArray{<:Real, 3}; [method])
148
159
Calculate the relative efficiency of an MCMC chain, i.e. the effective sample size divided
16-
by the nominal sample size.
10+
by the nominal sample size. If no `method` is provided, use the default method from
11+
MCMCDiagnosticTools.
1712
"""
18-
function relative_eff(
19-
sample::AbstractArray{T, 3}; method=MCMCDiagnosticTools.FFTESSMethod()
20-
) where {T <: Union{Real, Missing}}
13+
function relative_eff(sample::AbstractArray{<:Real, 3}; maxlag=size(sample, 2), kwargs...)
2114
dims = size(sample)
2215
post_sample_size = dims[2] * dims[3]
2316
ess_sample = inv.(permutedims(sample, [2, 1, 3]))
24-
ess, = MCMCDiagnosticTools.ess_rhat(ess_sample; method=method, maxlag=dims[2])
17+
ess, = MCMCDiagnosticTools.ess_rhat(ess_sample; maxlag=maxlag, kwargs...)
2518
r_eff = ess / post_sample_size
2619
return r_eff
2720
end
2821

2922

3023
"""
3124
function psis_ess(
32-
weights::AbstractVector{T},
33-
r_eff::AbstractVector{T}
34-
) -> AbstractVector{T}
25+
weights::AbstractVector{<:Real},
26+
r_eff::AbstractVector{<:Real}
27+
) -> AbstractVector{<:Real}
3528
3629
Calculate the (approximate) effective sample size of a PSIS sample, using the correction in
3730
Vehtari et al. 2019. This uses the variance-based definition of ESS, and measures the L2
@@ -44,23 +37,15 @@ distance of the proposal and target distributions.
4437
4538
See `?relative_eff` to calculate `r_eff`.
4639
"""
47-
function psis_ess(
48-
weights::AbstractVector{T}, r_eff::AbstractVector{T}
49-
) where {T <: Union{Real, Missing}}
50-
@tullio sum_of_squares := weights[x]^2
51-
return @turbo r_eff ./ sum_of_squares
52-
end
53-
54-
5540
function psis_ess(
5641
weights::AbstractMatrix{T}, r_eff::AbstractVector{T}
5742
) where {T <: Union{Real, Missing}}
5843
@tullio sum_of_squares[x] := weights[x, y]^2
59-
return @turbo r_eff ./ sum_of_squares
44+
return r_eff ./ sum_of_squares
6045
end
6146

6247

63-
function psis_ess(weights::AbstractMatrix{<:Union{Real, Missing}})
48+
function psis_ess(weights::AbstractMatrix{<:Real})
6449
@warn "PSIS ESS not adjusted based on MCMC ESS. MCSE and ESS estimates " *
6550
"will be overoptimistic if samples are autocorrelated."
6651
return psis_ess(weights, ones(size(weights)))
@@ -69,8 +54,8 @@ end
6954

7055
"""
7156
function sup_ess(
72-
weights::AbstractVector{T},
73-
r_eff::AbstractVector{T}
57+
weights::AbstractVector{<:Real},
58+
r_eff::AbstractVector{<:Real}
7459
) -> AbstractVector
7560
7661
Calculate the supremum-based effective sample size of a PSIS sample, i.e. the inverse of the
@@ -79,10 +64,11 @@ L-∞ norm.
7964
8065
# Arguments
8166
- `weights`: A set of importance sampling weights derived from PSIS.
82-
- `r_eff`: The relative efficiency of the MCMC chains from which PSIS samples were derived.
67+
- `r_eff`: The relative efficiency of the MCMC chains from which PSIS samples were
68+
derived.
8369
"""
8470
function sup_ess(
8571
weights::AbstractMatrix{T}, r_eff::V
86-
) where {T<:Union{Real, Missing}, V<:AbstractVector{T}}
87-
return @turbo inv.(dropdims(maximum(weights; dims=2); dims=2)) .* r_eff
72+
) where {T<:Real, V<:AbstractVector{T}}
73+
return inv.(dropdims(maximum(weights; dims=2); dims=2)) .* r_eff
8874
end

src/GPD.jl

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using LinearAlgebra
2-
using LoopVectorization
2+
using LogExpFunctions
33
using Statistics
44
using Tullio
55

@@ -45,22 +45,23 @@ function gpdfit(
4545

4646
grid_size = min_grid_pts + isqrt(len) # isqrt = floor sqrt
4747
n_0 = 10 # determines how strongly to nudge ξ towards .5
48-
x_star::T = inv(3 * sample[(len + 2) ÷ 4]) # magic number. ¯\_(ツ)_/¯
49-
48+
x_star = inv(3 * sample[(len + 2) ÷ 4]) # magic number. ¯\_(ツ)_/¯
49+
invmax = inv(sample[len])
5050

5151
# build pointwise estimates of ξ and θ at each grid point
5252
θ_hats = similar(sample, grid_size)
53-
ξ_hats = similar(sample, grid_size)
54-
invmax = inv(sample[len])
55-
@tullio threads=false θ_hats[i] = invmax + (1 - sqrt((grid_size + 1) / i)) * x_star
56-
@tullio threads=false ξ_hats[i] = log1p(-θ_hats[i] * sample[j]) |> _ / len
57-
log_like = similar(ξ_hats)
53+
@fastmath @. θ_hats = invmax + (1 - sqrt((grid_size + 1) / $(1:grid_size))) * x_star
54+
@tullio threads=false ξ_hats[i] := log1p(-θ_hats[i] * sample[j])
55+
ξ_hats /= len
56+
57+
log_like = ξ_hats # Reuse preallocated array
5858
# Calculate profile log-likelihood at each estimate:
59-
@tullio threads=false log_like[i] =
59+
@tullio threads=false ξ_hats[i] =
6060
len * (log(-θ_hats[i] / ξ_hats[i]) - ξ_hats[i] - 1)
6161
# Calculate weights from log-likelihood:
62-
weights = ξ_hats # Reuse preallocated array
63-
@tullio threads=false weights[y] = exp(log_like[x] - log_like[y]) |> inv
62+
weights = log_like # Reuse preallocated array
63+
log_norm = logsumexp(log_like)
64+
@tullio threads=false log_like[x] = exp(log_like[x] - log_norm)
6465
# Take weighted mean:
6566
@tullio threads=false θ_hat := weights[x] * θ_hats[x]
6667
@tullio threads=false ξ := log1p(-θ_hat * sample[i])
@@ -72,7 +73,7 @@ function gpdfit(
7273
@fastmath ξ = (ξ * len + 0.5 * n_0) / (len + n_0)
7374
end
7475

75-
return ξ::T, σ::T
76+
return ξ, σ
7677

7778
end
7879

src/ImportanceSampling.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
using LoopVectorization
2-
using StatsBase
31
using Tullio
42

53
const LIKELY_ERROR_CAUSES = """
@@ -118,8 +116,9 @@ function psis(
118116
r_eff = _generate_r_eff(log_ratios, dims, r_eff, source)
119117
weights = similar(log_ratios)
120118
# Shift ratios by maximum to prevent overflow
121-
@tturbo @. weights = exp(log_ratios - $maximum(log_ratios; dims=2))
119+
@. weights = exp(log_ratios - $maximum(log_ratios; dims=2))
122120

121+
r_eff = _generate_r_eff(weights, dims, r_eff, source)
123122
_check_input_validity_psis(reshape(log_ratios, dims), r_eff)
124123

125124
tail_length = Vector{Int}(undef, data_size)
@@ -130,7 +129,7 @@ function psis(
130129
end
131130

132131
@tullio norm_const[i] := weights[i, j]
133-
@tturbo weights .= weights ./ norm_const
132+
@. weights = weights / norm_const
134133
ess = psis_ess(weights, r_eff)
135134
inf_ess = sup_ess(weights, r_eff)
136135

@@ -151,9 +150,10 @@ end
151150

152151
function psis(
153152
log_ratios::AbstractMatrix{<:Real};
154-
chain_index::AbstractVector{<:Integer}=_assume_one_chain(log_ratios),
153+
chain_index::AbstractVector=_assume_one_chain(log_ratios),
155154
kwargs...,
156155
)
156+
chain_index = Vector(Int.(chain_index))
157157
new_log_ratios = _convert_to_array(log_ratios, chain_index)
158158
return psis(new_log_ratios; kwargs...)
159159
end
@@ -166,7 +166,7 @@ Do PSIS on a single vector, smoothing its tail values.
166166
167167
# Arguments
168168
169-
- `is_ratios::AbstractVector{Real}`: A vector of importance sampling ratios,
169+
- `is_ratios::AbstractVector{<:Real}`: A vector of importance sampling ratios,
170170
scaled to have a maximum of 1.
171171
172172
# Returns
@@ -219,11 +219,11 @@ function _psis_smooth_tail!(tail::AbstractVector{T}, cutoff::T) where {T <: Real
219219
if any(isinf.(tail))
220220
return ξ = Inf
221221
else
222-
@turbo @. tail = tail - cutoff
222+
@. tail = tail - cutoff
223223

224224
# save time not sorting since tail is already sorted
225225
ξ, σ = gpdfit(tail)
226-
@turbo @. tail = gpd_quantile(($(1:len) - 0.5) / len, ξ, σ) + cutoff
226+
@. tail = gpd_quantile(($(1:len) - 0.5) / len, ξ, σ) + cutoff
227227
end
228228
return ξ
229229
end

src/InternalHelpers.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
const CHAIN_INDEX_DOC = """
2-
`chain_index::Vector`: An optional vector of integers specifying which chain each step
2+
`chain_index::Vector{Int}`: An optional vector of integers specifying which chain each step
33
belongs to. For instance, `chain_index[step]` should return `2` if `log_likelihood[:, step]`
44
belongs to the second chain.
55
"""
@@ -14,9 +14,10 @@ of that point. This function must take the form `f(θ[1], ..., θ[n], data)`, wh
1414
parameter vector. See also the `splat` keyword argument.
1515
"""
1616

17-
const LIKELIHOOD_ARRAY_ARG = """
17+
const LOG_LIK_ARR = """
1818
`log_likelihood::Array`: A matrix or 3d array of log-likelihood values indexed as
19-
`[data, step, chain]`. See the `chain_index` argument if leaving the `chain` index off.
19+
`[data, step, chain]`. The chain argument can be left off if `chain_index` is provided
20+
or if all posterior samples were drawn from a single chain.
2021
"""
2122

2223
const R_EFF_DOC = """

src/LeaveOneOut.jl

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
using AxisKeys
2-
using Distributions
32
using InteractiveUtils
4-
using LoopVectorization
53
using NamedDims
64
using Statistics
75
using Printf
@@ -90,16 +88,16 @@ end
9088

9189
"""
9290
function psis_loo(
93-
log_likelihood::Array{Real} [, args...];
94-
[, chain_index::Vector{Integer}, kwargs...]
91+
log_likelihood::AbstractArray{<:Real} [, args...];
92+
[, chain_index::Vector{Int}, kwargs...]
9593
) -> PsisLoo
9694
9795
Use Pareto-Smoothed Importance Sampling to calculate the leave-one-out cross validation
9896
score.
9997
10098
# Arguments
10199
102-
- $LIKELIHOOD_ARRAY_ARG
100+
- $LOG_LIK_ARR
103101
- $ARGS [`psis`](@ref).
104102
- $CHAIN_INDEX_DOC
105103
- $KWARGS [`psis`](@ref).
@@ -110,30 +108,31 @@ function psis_loo(log_likelihood::AbstractArray{<:Real, 3}, args...; kwargs...)
110108
psis_object = psis(-log_likelihood, args...; kwargs...)
111109
return loo_from_psis(log_likelihood, psis_object)
112110
end
113-
psis_loo
111+
114112

115113
function psis_loo(
116114
log_likelihood::AbstractMatrix{<:Real},
117115
args...;
118116
chain_index::AbstractVector=_assume_one_chain(log_likelihood),
119117
kwargs...,
120118
)
119+
chain_index = Int.(chain_index)
121120
new_log_ratios = _convert_to_array(log_likelihood, chain_index)
122121
return psis_loo(new_log_ratios, args...; kwargs...)
123122
end
124123

125124

126125
"""
127126
loo_from_psis(
128-
log_likelihood::AbstractArray, psis_object::Psis;
129-
chain_index::AbstractVector{Integer}
127+
log_likelihood::AbstractArray{<:Real}, psis_object::Psis;
128+
chain_index::Vector{<:Integer}
130129
)
131130
132131
Use a precalculated `Psis` object to estimate the leave-one-out cross validation score.
133132
134133
# Arguments
135134
136-
- $LIKELIHOOD_ARRAY_ARG
135+
- $LOG_LIK_ARR
137136
- `psis_object`: A precomputed `Psis` object used to estimate the LOO-CV score.
138137
- $CHAIN_INDEX_DOC
139138
@@ -171,9 +170,8 @@ function loo_from_psis(log_likelihood::AbstractArray{<:Real, 3}, psis_object::Ps
171170
table = _generate_loo_table(pointwise)
172171

173172
gmpd = exp.(table(column=:mean, statistic=:cv_elpd))
174-
@tullio mcse := pointwise_mcse[i]^2
175-
mcse = sqrt(mcse)
176173

174+
mcse = sum(abs2, pointwise_mcse) |> sqrt
177175
return PsisLoo(table, pointwise, psis_object, gmpd, mcse)
178176
end
179177

@@ -182,6 +180,7 @@ function loo_from_psis(
182180
log_likelihood::AbstractMatrix{<:Real}, psis_object::Psis, args...;
183181
chain_index::AbstractVector=_assume_one_chain(log_likelihood), kwargs...
184182
)
183+
chain_index = Int.(chain_index)
185184
new_log_ratios = _convert_to_array(log_likelihood, chain_index)
186185
return loo_from_psis(new_log_ratios, psis_object, args...; kwargs...)
187186
end
@@ -227,9 +226,9 @@ function _calc_mcse(weights, log_likelihood, pointwise_loo, r_eff)
227226
pointwise_gmpd = exp.(pointwise_loo)
228227
@tullio pointwise_var[i] :=
229228
(weights[i, j, k] * (exp(log_likelihood[i, j, k]) - pointwise_gmpd[i]))^2
230-
# If MCMC draws follow a log-normal distribution, we can use method of moments to est
231-
# the standard deviation of their log:
232-
@turbo @. pointwise_var = log1p(pointwise_var / pointwise_gmpd^2)
229+
# If MCMC draws follow a log-normal distribution, then their log has this std. error:
230+
@. pointwise_var = log1p(pointwise_var / pointwise_gmpd^2)
231+
# (google "log-normal method of moments" for a proof)
233232
# apply MCMC correlation correction:
234-
return @turbo @. sqrt(pointwise_var / r_eff)
233+
return @. sqrt(pointwise_var / r_eff)
235234
end

src/ModelComparison.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
using StatsFuns
2-
using LoopVectorization
31
import Base.show
42

53
export loo_compare, ModelComparison
@@ -47,7 +45,7 @@ end
4745

4846
"""
4947
function loo_compare(
50-
cv_results::PsisLoo...;
48+
cv_results...;
5149
sort_models::Bool=true,
5250
best_to_worst::Bool=true,
5351
[, model_names::Tuple{Symbol}]
@@ -111,13 +109,13 @@ function loo_compare(
111109
se_total = NamedTuple{name_tuple}(se_total)
112110

113111
log_norm = logsumexp(cv_elpd)
114-
weights = @turbo warn_check_args=false @. exp(cv_elpd - log_norm)
112+
weights = @. exp(cv_elpd - log_norm)
115113

116-
gmpd = @turbo @. exp(cv_elpd / data_size)
114+
gmpd = @. exp(cv_elpd / data_size)
117115
gmpd = NamedTuple{name_tuple}(gmpd)
118116

119-
@turbo warn_check_args=false @. cv_elpd = cv_elpd - cv_elpd[1]
120-
@turbo warn_check_args=false avg_elpd = cv_elpd ./ data_size
117+
@. cv_elpd = cv_elpd - cv_elpd[1]
118+
avg_elpd = cv_elpd ./ data_size
121119
total_diffs = KeyedArray(
122120
hcat(cv_elpd, avg_elpd, weights);
123121
model=model_names,

0 commit comments

Comments
 (0)