-
Notifications
You must be signed in to change notification settings - Fork 433
Generalize JointOrderStatistics to discrete distributions #2038
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
sethaxen
wants to merge
9
commits into
JuliaStats:master
Choose a base branch
from
sethaxen:joint_order_stats_discrete
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+465
−34
Open
Changes from 7 commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
008f8fc
feat!: Allow discrete JointOrderStatistics
sethaxen 1af67b5
feat: Generalize rand to work with discrete distributions
sethaxen 7cea817
reformat: Wrap long signatures
sethaxen a06e135
test: Restrict rand tests to continuous
sethaxen 4158ea5
test: Test rand for discrete joint order statistics
sethaxen 8f127fa
feat: Add logpdf for discrete joint order stats
sethaxen 2906996
test: Test logpdf for discrete joint order stats
sethaxen 9678f81
feat: Use fused vec-mat-mat product to speed up logpdf eval
sethaxen 6d0c78f
Merge branch 'master' into joint_order_stats_discrete
sethaxen File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,13 +3,13 @@ | |
| # A first course in order statistics. Society for Industrial and Applied Mathematics, 2008. | ||
|
|
||
| """ | ||
| JointOrderStatistics <: ContinuousMultivariateDistribution | ||
| JointOrderStatistics <: MultivariateDistribution | ||
|
|
||
| The joint distribution of a subset of order statistics from a sample from a continuous | ||
| The joint distribution of a subset of order statistics from a sample from a | ||
| univariate distribution. | ||
|
|
||
| JointOrderStatistics( | ||
| dist::ContinuousUnivariateDistribution, | ||
| dist::UnivariateDistribution, | ||
| n::Int, | ||
| ranks=Base.OneTo(n); | ||
| check_args::Bool=true, | ||
|
|
@@ -35,13 +35,15 @@ JointOrderStatistics(Cauchy(), 10, (1, 10)) # joint distribution of only the ex | |
| ``` | ||
| """ | ||
| struct JointOrderStatistics{ | ||
| D<:ContinuousUnivariateDistribution,R<:Union{AbstractVector{Int},Tuple{Int,Vararg{Int}}} | ||
| } <: ContinuousMultivariateDistribution | ||
| D<:UnivariateDistribution, | ||
| R<:Union{AbstractVector{Int},Tuple{Int,Vararg{Int}}}, | ||
| S<:ValueSupport, | ||
| } <: MultivariateDistribution{S} | ||
| dist::D | ||
| n::Int | ||
| ranks::R | ||
| function JointOrderStatistics( | ||
| dist::ContinuousUnivariateDistribution, | ||
| dist::UnivariateDistribution, | ||
| n::Int, | ||
| ranks::Union{AbstractVector{Int},Tuple{Int,Vararg{Int}}}=Base.OneTo(n); | ||
| check_args::Bool=true, | ||
|
|
@@ -55,7 +57,7 @@ struct JointOrderStatistics{ | |
| "`ranks` must be a sorted vector or tuple of unique integers between 1 and `n`.", | ||
| ), | ||
| ) | ||
| return new{typeof(dist),typeof(ranks)}(dist, n, ranks) | ||
| return new{typeof(dist),typeof(ranks),value_support(typeof(dist))}(dist, n, ranks) | ||
| end | ||
| end | ||
|
|
||
|
|
@@ -91,7 +93,9 @@ partype(d::JointOrderStatistics) = partype(d.dist) | |
| Base.eltype(::Type{<:JointOrderStatistics{D}}) where {D} = Base.eltype(D) | ||
| Base.eltype(d::JointOrderStatistics) = eltype(d.dist) | ||
|
|
||
| function logpdf(d::JointOrderStatistics, x::AbstractVector{<:Real}) | ||
| function logpdf( | ||
| d::JointOrderStatistics{<:ContinuousUnivariateDistribution}, x::AbstractVector{<:Real} | ||
| ) | ||
| n = d.n | ||
| ranks = d.ranks | ||
| lp = loglikelihood(d.dist, x) | ||
|
|
@@ -125,6 +129,264 @@ function _marginalize_range(dist, i, j, xᵢ, xⱼ, T) | |
| return k * T(logdiffcdf(dist, xⱼ, xᵢ)) - loggamma(T(k + 1)) | ||
| end | ||
|
|
||
| # discrete case | ||
| # for y=unique(x), with known counts c, m=length(y), and parameters θ, the PMF is | ||
| # P(y,c|θ) = \sum_{d \in D(n, c)} P(d|n,p), where (taking y_0 = -Inf and y_{m+1} = Inf) | ||
| # - d_{2k}: the number of entries equal to y_k | ||
| # - d_{2k-1}: the number of entries strictly between y_{k-1} and y_k, i.e. in (y_{k-1}, y_k) | ||
| # - p_{2k}: the probability of a draw equal to y_k (P(y_k|θ)) | ||
| # - p_{2k-1}: the probability of a draw falling in (y_{k-1}, y_k) (P(y_{k-1} < x < y_k|θ)) | ||
| # - D(n, c): the set of all weak (2m+1)-compositions d of n (i.e. sum(d)=n) constrained by d_{2k} >= c_k | ||
| # - P(d|n,p)=Multinomial(d|n,p) | ||
| # | ||
| # The sum marginalizes over all possible count vectors d that satisfy the constraints implied by y and c. | ||
| # It's here computed efficiently as a product of Hankel matrices; since a Hankel matrix-vector product is | ||
| # equivalent to a discrete cross-correlation, we instead construct the defining sequences of the | ||
| # Hankel matrices and compute the cross-correlations in log-space. | ||
| function logpdf( | ||
| d::JointOrderStatistics{<:DiscreteUnivariateDistribution}, x::AbstractVector{<:Real} | ||
| ) | ||
| (; n, ranks) = d | ||
| udist = d.dist | ||
|
|
||
| if length(ranks) == 1 | ||
| return logpdf(OrderStatistic(udist, n, first(ranks); check_args=false), first(x)) | ||
| end | ||
|
|
||
| y, rank_ranges = _rle_ranks(x, ranks) | ||
|
|
||
| if sum(length, rank_ranges) == n # no gaps => all values are either observed or fixed by rank constraints | ||
| # logpdf for Multinomial distribution over whole (potentially infinite) support | ||
| lp = _log_hankel_base(n, Iterators.map(Base.Fix1(logpdf, udist), y), rank_ranges) | ||
| issorted(x) && return lp | ||
| return oftype(lp, -Inf) | ||
| end | ||
|
|
||
| log_tie_probs = logpdf.(Ref(udist), y) | ||
| gap_lengths = _gap_lengths(n, rank_ranges) | ||
| lp = _log_hankel_base(n, log_tie_probs, rank_ranges) | ||
|
|
||
| # allocate workspaces | ||
| max_gap_length = maximum(gap_lengths) | ||
| max_total_gap_length = @views maximum(sum, zip(gap_lengths, gap_lengths[2:end])) | ||
| T = eltype(lp) | ||
| logh_work = similar(x, T, max_total_gap_length + 1) # defining sequence for Hankel matrices of log-multinomial factors | ||
| logv_work = similar(x, T, max_gap_length + 1) # logsumexp of log-multinomial factors from left | ||
| logc_work = similar(x, T, max_gap_length + 1) # intermediate vector for log-cross-correlation | ||
| init_state = (; logv_work, logc_work) | ||
|
|
||
| _log_hankel_product_init!(init_state, udist, y, gap_lengths, log_tie_probs) | ||
| (op!) = _make_log_hankel_product_op(logh_work, udist, y, rank_ranges, gap_lengths, log_tie_probs) | ||
| final_state = foldl(op!, eachindex(y, log_tie_probs, rank_ranges); init=init_state) | ||
| lp += first(final_state.logv_work) | ||
| return lp | ||
| end | ||
|
|
||
| function _log_hankel_base(n, log_probs, rank_ranges) | ||
| lp = sum(zip(log_probs, rank_ranges)) do (lp_i, range_i) | ||
| num_ties_i = length(range_i) | ||
| isone(num_ties_i) && return lp_i | ||
| num_ties_i * lp_i - loggamma(oftype(lp_i, num_ties_i + 1)) | ||
| end | ||
| return lp + loggamma(oftype(lp, n + 1)) | ||
| end | ||
|
|
||
| function _log_hankel_product_init!(state, udist, y, gap_lengths, log_tie_probs) | ||
| (; logv_work) = state | ||
| T = eltype(logv_work) | ||
| # initiate recurrence for left-flanking gap | ||
| gap_length_left = gap_lengths[1] | ||
| if gap_length_left == 0 | ||
| logv_work[begin] = 0 | ||
| else | ||
| log_gap_prob = logsubexp(T(logcdf(udist, y[1])), log_tie_probs[1]) | ||
| logv = _view_first(logv_work, gap_length_left + 1) | ||
| _log_gap_terms!(logv, log_gap_prob, gap_length_left) | ||
| end | ||
| return state | ||
| end | ||
|
|
||
| function _make_log_hankel_product_op(logh_work, udist, y, rank_ranges, gap_lengths, log_tie_probs) | ||
| T = eltype(logh_work) | ||
| ilast = lastindex(y) | ||
| function log_hankel_product_op(state, i) | ||
| (; logv_work, logc_work) = state | ||
| gap_length_left = gap_lengths[i] | ||
| gap_length_right = gap_lengths[i + 1] | ||
| gap_length_total = gap_length_left + gap_length_right | ||
| min_num_ties = length(rank_ranges[i]) | ||
|
|
||
| log_tie_prob = log_tie_probs[i] | ||
|
|
||
| logv = _view_first(logv_work, gap_length_left + 1) | ||
| logc = _view_first(logc_work, gap_length_right + 1) | ||
| if gap_length_left == 0 | ||
| _log_tie_terms!(logc, log_tie_prob, min_num_ties, gap_length_right) | ||
| logc .+= first(logv) | ||
| else | ||
| logh_ties = _view_first(logh_work, gap_length_total + 1) | ||
| _log_tie_terms!(logh_ties, log_tie_prob, min_num_ties, gap_length_total) | ||
| _log_xcorr_exp!(logc, logh_ties, logv) | ||
| end | ||
|
|
||
| if gap_length_right == 0 | ||
| logv_work, logc_work = logc_work, logv_work | ||
| return (; logv_work, logc_work) | ||
| end | ||
|
|
||
| logh_gap = _view_first(logh_work, gap_length_right + 1) | ||
| if i == ilast | ||
| log_gap_prob = T(logccdf(udist, y[i])) | ||
| # for right-flanking gap, logc is a row vector, and logh is a column vector, so | ||
| # we only need an inner product (i.e. first term of a cross-correlation). | ||
| logv = _view_first(logv_work, 1) | ||
| else | ||
| log_gap_prob = logsubexp( | ||
| T(logdiffcdf(udist, y[i + 1], y[i])), log_tie_probs[i + 1] | ||
| ) | ||
| logv = _view_first(logv_work, gap_length_right + 1) | ||
| end | ||
| _log_gap_terms!(logh_gap, log_gap_prob, gap_length_right) | ||
| _log_xcorr_exp!(logv, logh_gap, logc) | ||
| return (; logv_work, logc_work) | ||
| end | ||
| return log_hankel_product_op | ||
| end | ||
|
|
||
|
|
||
| _view_first(x, n) = @views x[begin:(begin - 1 + n)] | ||
|
|
||
| """ | ||
| _rle_ranks(values, ranks) -> Tuple{Vector,Vector} | ||
|
|
||
| Return the run-length encoding of the order statistics at the specified ranks. | ||
|
|
||
| If we observe `xj == xi` for ranks `rj > ri`, then the order statistics at all ranks | ||
| between `ri` and `rj` must also equal `xi`, so those ranks are included in the returned | ||
| range even if they are not present in `ranks`. | ||
|
|
||
| # Arguments | ||
| - `values`: Sorted vector of observed values | ||
| - `ranks`: Sorted vector of corresponding ranks (integer-valued) | ||
|
|
||
| # Returns | ||
| - `distinct_vals`: Vector of distinct values (sorted) | ||
| - `rank_ranges`: Vector of ranges of ranks for each distinct value | ||
| """ | ||
| function _rle_ranks(values, ranks) | ||
| (val_last, rank_last), iter = Iterators.peel(zip(values, ranks)) | ||
| distinct_vals = eltype(values)[val_last] | ||
| rank_ranges = UnitRange{eltype(ranks)}[] | ||
| rank_first = rank_last | ||
| for (val, rank) in iter | ||
| if val != val_last | ||
| push!(rank_ranges, rank_first:rank_last) | ||
| push!(distinct_vals, val) | ||
| rank_first = rank | ||
| end | ||
| val_last = val | ||
| rank_last = rank | ||
| end | ||
| push!(rank_ranges, rank_first:rank_last) | ||
|
|
||
| return distinct_vals, rank_ranges | ||
| end | ||
|
|
||
| """ | ||
| _gap_lengths(n, rank_ranges) -> Vector{Int} | ||
|
|
||
| Compute the lengths of gaps between ranges of known ranks, including the left and right tail gaps. | ||
| """ | ||
| function _gap_lengths(n::Integer, rank_ranges::Vector) | ||
| gap_lengths = Vector{Int}(undef, length(rank_ranges) + 1) | ||
| gap_lengths[1] = first(rank_ranges[1]) - 1 | ||
| for i in 2:length(rank_ranges) | ||
| gap_lengths[i] = first(rank_ranges[i]) - last(rank_ranges[i - 1]) - 1 | ||
| end | ||
| gap_lengths[end] = n - last(rank_ranges[end]) | ||
| return gap_lengths | ||
| end | ||
|
|
||
| """ | ||
| _log_gap_terms!(logh, log_gap_prob, gap_size) | ||
|
|
||
| Compute the log-multinomial term for a gap between observed ranks (or tail gaps). | ||
|
|
||
| For a gap between observed ranks ``r_i < r_j`` (with ``x_i < x_j``) of size ``k_i = r_j - r_i - 1`` `=gap_size`, | ||
| where the probability of a draw falling in the gap is ``p_i = P(x_i < x < x_j)`` `=exp(log_gap_prob)`, | ||
| computes the logarithm of the multinomial terms | ||
| ```math | ||
| h_{u+1} = p_i^{k_i - u} / (k_i - u)! | ||
| ``` | ||
| for ``u \\in [0, k_i]``. | ||
| """ | ||
| function _log_gap_terms!(logh, log_gap_prob, gap_size) | ||
| T = eltype(logh) | ||
| logh[end] = log_term = zero(T) | ||
| accumulate!(@view(logh[end-1:-1:begin]), 1:gap_size; init=log_term) do log_term, num_in_gap | ||
| return log_term + log_gap_prob - log(T(num_in_gap)) | ||
| end | ||
| return logh | ||
| end | ||
|
|
||
| """ | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This function is currently never used, but we need it if we use FFT acceleration.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same with the |
||
| _log_tie_terms!(logh, log_tie_prob, min_num_ties, gap_size_total) | ||
|
|
||
| Compute the log-multinomial term for the ties with observed ranks from adjacent gaps. | ||
|
|
||
| Let ``x_{r:n}`` be the rank ``r`` order statistic of the sample ``x_1, ..., x_n``. | ||
| For a block of known ranks ``r_{i}...r_{i+c_i-1}`` | ||
| (with ``x_{r_{i-1}:n} < x_{r_i:n} = ... = x_{r_{i+c_i-1}:n} < x_{r_{i+c_i}:n}``) | ||
| of size ``c_i`` `=min_num_ties`, flanked by gaps with sizes ``k_{i-1}`` and ``k_i`` and | ||
| total gap size ``g_i = k_{i-1} + k_i`` `=gap_size_total`, | ||
| computes the logarithm of the multinomial terms ``h`` where | ||
| ```math | ||
| h_{u+1} (f(x_i)^{c_i} / c_i!) = f(x_i)^{c_i + u} / (c_i + u)!, | ||
| ``` | ||
| and ``f(x_i)`` `=exp(log_tie_prob)`, for ``u \\in [0, g_i]``. | ||
| """ | ||
| function _log_tie_terms!(logh, log_tie_prob, min_num_ties, gap_size_total) | ||
| T = eltype(logh) | ||
| logh[begin] = log_term = zero(T) | ||
| accumulate!(@view(logh[begin+1:end]), 1:gap_size_total; init=log_term) do log_term, num_ties_gap | ||
| num_ties_total = num_ties_gap + min_num_ties | ||
| return log_term + log_tie_prob - log(T(num_ties_total)) | ||
| end | ||
| return logh | ||
| end | ||
|
|
||
|
|
||
| """ | ||
| _log_xcorr_exp!(log_c, log_a, log_b) | ||
|
|
||
| Compute in-place the logarithm of the cross-correlation of the exponential of `log_a` and `log_b`. | ||
|
|
||
| ```math | ||
| \\log(c_j) = \\log(\\sum_i \\exp(\\log(a_{i+j-1}) + \\log(b_i))) | ||
| ``` | ||
| with implicit `-Inf`-padding of `log_a` and `log_b` to the right as needed. | ||
|
|
||
| Only the requested entries of `log_c` are computed. | ||
|
|
||
| Note: this is equivalent to but more numerically stable than passing 0-indexed offset arrays for | ||
| `a` and `b` to `DSP.xcorr`, truncating the result `c` to `c[0:length(log_c)-1]`, and taking the | ||
| logarithm of the result. | ||
|
|
||
| # Arguments | ||
| - `log_c`: Vector to store the result | ||
| - `log_a`: Vector of logarithms of the first factor | ||
| - `log_b`: Vector of logarithms of the second factor | ||
| """ | ||
| function _log_xcorr_exp!(log_c, log_a, log_b) | ||
| n_r = length(log_b) | ||
| idx_last = lastindex(log_a) | ||
| map!(log_c, first(eachindex(log_a), length(log_c))) do j_idx | ||
| terms = Iterators.map(+, log_b, @views log_a[j_idx:min(j_idx + n_r - 1, idx_last)]) | ||
| return logsumexp(terms) | ||
| end | ||
| return log_c | ||
| end | ||
|
|
||
| function _rand!(rng::AbstractRNG, d::JointOrderStatistics, x::AbstractVector{<:Real}) | ||
| n = d.n | ||
| if n == length(d.ranks) # ranks == 1:n | ||
|
|
@@ -139,10 +401,13 @@ function _rand!(rng::AbstractRNG, d::JointOrderStatistics, x::AbstractVector{<:R | |
| # Carlo computations." The American Statistician 26.1 (1972): 26-27. | ||
| # this is slow if length(d.ranks) is close to n and quantile for d.dist is expensive, | ||
| # but this branch is probably taken when length(d.ranks) is small or much smaller than n. | ||
| T = typeof(one(eltype(x))) | ||
| s = zero(eltype(x)) | ||
|
|
||
| u = eltype(x) <: Integer ? similar(x, float(eltype(x))) : x | ||
|
|
||
| T = typeof(one(eltype(u))) | ||
| s = zero(eltype(u)) | ||
| i = 0 | ||
| for (m, j) in zip(eachindex(x), d.ranks) | ||
| for (m, j) in zip(eachindex(u), d.ranks) | ||
| k = j - i | ||
| if k > 1 | ||
| # specify GammaMTSampler directly to avoid unnecessarily checking the shape | ||
|
|
@@ -153,7 +418,7 @@ function _rand!(rng::AbstractRNG, d::JointOrderStatistics, x::AbstractVector{<:R | |
| s += randexp(rng, T) | ||
| end | ||
| i = j | ||
| x[m] = s | ||
| u[m] = s | ||
| end | ||
| j = n + 1 | ||
| k = j - i | ||
|
|
@@ -162,7 +427,7 @@ function _rand!(rng::AbstractRNG, d::JointOrderStatistics, x::AbstractVector{<:R | |
| else | ||
| s += randexp(rng, T) | ||
| end | ||
| x .= Base.Fix1(quantile, d.dist).(x ./ s) | ||
| x .= Base.Fix1(quantile, d.dist).(u ./ s) | ||
| end | ||
| return x | ||
| end | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Technically this change to the signature would be breaking. Should that block this PR?