From 9695e95e62cd13a62cd886bb36fcea7dbac18303 Mon Sep 17 00:00:00 2001
From: Seth Axen
Date: Wed, 22 Oct 2025 18:48:49 +0200
Subject: [PATCH 01/11] Determine rand type using axes

---
 src/genericrand.jl | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/src/genericrand.jl b/src/genericrand.jl
index bcc1969d1..4027935ac 100644
--- a/src/genericrand.jl
+++ b/src/genericrand.jl
@@ -26,7 +26,9 @@ rand(rng::AbstractRNG, s::Sampleable, dim1::Int, moredims::Int...) =
 
 # default fallback (redefined for univariate distributions)
 function rand(rng::AbstractRNG, s::Sampleable{<:ArrayLikeVariate})
-    return rand!(rng, s, Array{eltype(s)}(undef, size(s)))
+    out = similar(Array{eltype(s)}, axes(s))
+    rand!(rng, s, out)
+    return out
 end
 
 # multiple samples
@@ -37,15 +39,17 @@ end
 function rand(
     rng::AbstractRNG, s::Sampleable{<:ArrayLikeVariate}, dims::Dims,
 )
-    sz = size(s)
+    sax = axes(s)
     ax = map(Base.OneTo, dims)
-    out = [Array{eltype(s)}(undef, sz) for _ in Iterators.product(ax...)]
+    out = [similar(Array{eltype(s)}, sax) for _ in Iterators.product(ax...)]
     return rand!(rng, sampler(s), out, false)
 end
 
 # these are workarounds for sampleables that incorrectly base `eltype` on the parameters
 function rand(rng::AbstractRNG, s::Sampleable{<:ArrayLikeVariate,Continuous})
-    return rand!(rng, sampler(s), Array{float(eltype(s))}(undef, size(s)))
+    out = similar(Array{float(eltype(s))}, axes(s))
+    rand!(rng, sampler(s), out)
+    return out
 end
 function rand(rng::AbstractRNG, s::Sampleable{Univariate,Continuous}, dims::Dims)
     out = Array{float(eltype(s))}(undef, dims)
@@ -54,9 +58,9 @@ end
 function rand(
     rng::AbstractRNG, s::Sampleable{<:ArrayLikeVariate,Continuous}, dims::Dims,
 )
-    sz = size(s)
+    sax = axes(s)
     ax = map(Base.OneTo, dims)
-    out = [Array{float(eltype(s))}(undef, sz) for _ in Iterators.product(ax...)]
+    out = [similar(Array{float(eltype(s))}, sax) for _ in Iterators.product(ax...)]
    return rand!(rng, sampler(s), out, false)
 end
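
A quick sketch of what the `similar(Array{T}, axes(s))` fallback above buys (illustration only, not part of the patch; assumes the user has OffsetArrays loaded): axes other than `Base.OneTo` produce an offset container, while ordinary axes still give a plain `Array`.

# Illustration, not from the diff: `similar` picks the output container from the axes.
using OffsetArrays

out = similar(Array{Float64}, (0:2,))       # 3-element OffsetVector with indices 0:2
axes(out, 1) == 0:2                         # true
similar(Array{Float64}, (Base.OneTo(3),))   # plain Vector{Float64} of length 3
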
+""" +Base.axes(s::Sampleable{<:ArrayLikeVariate}) + """ eltype(::Type{Sampleable}) From fd2d8304c31d06e69551a6f06875d7cee28a05a1 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 22 Oct 2025 21:34:30 +0200 Subject: [PATCH 03/11] Determine axes from param arrays where possible --- src/matrix/inversewishart.jl | 1 + src/matrix/matrixfdist.jl | 2 ++ src/matrix/matrixnormal.jl | 2 ++ src/matrix/matrixtdist.jl | 2 ++ src/matrix/wishart.jl | 2 ++ src/mixtures/mixturemodel.jl | 1 + src/multivariate/dirichlet.jl | 4 ++++ src/multivariate/dirichletmultinomial.jl | 1 + src/multivariate/jointorderstatistics.jl | 1 + src/multivariate/multinomial.jl | 1 + src/multivariate/mvlognormal.jl | 1 + src/multivariate/mvnormal.jl | 1 + src/multivariate/mvnormalcanon.jl | 1 + src/multivariate/mvtdist.jl | 1 + src/multivariate/product.jl | 1 + src/multivariate/vonmisesfisher.jl | 1 + src/product.jl | 1 + 17 files changed, 24 insertions(+) diff --git a/src/matrix/inversewishart.jl b/src/matrix/inversewishart.jl index 4c605caf2..9b5f1542f 100644 --- a/src/matrix/inversewishart.jl +++ b/src/matrix/inversewishart.jl @@ -75,6 +75,7 @@ insupport(::Type{InverseWishart}, X::Matrix) = isposdef(X) insupport(d::InverseWishart, X::Matrix) = size(X) == size(d) && isposdef(X) size(d::InverseWishart) = size(d.Ψ) +Base.axes(d::InverseWishart) = axes(d.Ψ) rank(d::InverseWishart) = rank(d.Ψ) params(d::InverseWishart) = (d.df, d.Ψ) diff --git a/src/matrix/matrixfdist.jl b/src/matrix/matrixfdist.jl index 5a7e18efb..d9273ab0c 100644 --- a/src/matrix/matrixfdist.jl +++ b/src/matrix/matrixfdist.jl @@ -80,6 +80,8 @@ end size(d::MatrixFDist) = size(d.W) +Base.axes(d::MatrixFDist) = axes(d.W) + rank(d::MatrixFDist) = size(d, 1) insupport(d::MatrixFDist, Σ::AbstractMatrix) = isreal(Σ) && size(Σ) == size(d) && isposdef(Σ) diff --git a/src/matrix/matrixnormal.jl b/src/matrix/matrixnormal.jl index 9d480bd69..952ad0aa9 100644 --- a/src/matrix/matrixnormal.jl +++ b/src/matrix/matrixnormal.jl @@ -82,6 +82,8 @@ end size(d::MatrixNormal) = size(d.M) +Base.axes(d::MatrixNormal) = axes(d.M) + rank(d::MatrixNormal) = minimum( size(d) ) insupport(d::MatrixNormal, X::AbstractMatrix) = isreal(X) && size(X) == size(d) diff --git a/src/matrix/matrixtdist.jl b/src/matrix/matrixtdist.jl index e37119603..29e52638a 100644 --- a/src/matrix/matrixtdist.jl +++ b/src/matrix/matrixtdist.jl @@ -100,6 +100,8 @@ end size(d::MatrixTDist) = size(d.M) +Base.axes(d::MatrixTDist) = axes(d.M) + rank(d::MatrixTDist) = minimum( size(d) ) insupport(d::MatrixTDist, X::Matrix) = isreal(X) && size(X) == size(d) diff --git a/src/matrix/wishart.jl b/src/matrix/wishart.jl index 701df1d5a..e60438dc4 100644 --- a/src/matrix/wishart.jl +++ b/src/matrix/wishart.jl @@ -102,6 +102,8 @@ end size(d::Wishart) = size(d.S) +Base.axes(d::Wishart) = axes(d.S) + rank(d::Wishart) = d.rank params(d::Wishart) = (d.df, d.S) @inline partype(d::Wishart{T}) where {T<:Real} = T diff --git a/src/mixtures/mixturemodel.jl b/src/mixtures/mixturemodel.jl index f0aeebce9..53a8422f8 100644 --- a/src/mixtures/mixturemodel.jl +++ b/src/mixtures/mixturemodel.jl @@ -159,6 +159,7 @@ end The length of each sample (only for `Multivariate`). 
""" length(d::MultivariateMixture) = length(d.components[1]) +Base.axes(d::MultivariateMixture) = axes(d.components[1]) size(d::MatrixvariateMixture) = size(d.components[1]) ncomponents(d::MixtureModel) = length(d.components) diff --git a/src/multivariate/dirichlet.jl b/src/multivariate/dirichlet.jl index 0717cdfe1..1716a7e99 100644 --- a/src/multivariate/dirichlet.jl +++ b/src/multivariate/dirichlet.jl @@ -44,12 +44,16 @@ function Dirichlet(d::Integer, alpha::Real; check_args::Bool=true) return Dirichlet{typeof(alpha)}(Fill(alpha, d); check_args=false) end +Base.axes(d::Dirichlet) = axes(d.alpha) + struct DirichletCanon{T<:Real,Ts<:AbstractVector{T}} alpha::Ts end length(d::DirichletCanon) = length(d.alpha) +Base.axes(d::DirichletCanon) = axes(d.alpha) + Base.eltype(::Type{<:Dirichlet{T}}) where {T} = T #### Conversions diff --git a/src/multivariate/dirichletmultinomial.jl b/src/multivariate/dirichletmultinomial.jl index 878467fc2..82403617e 100644 --- a/src/multivariate/dirichletmultinomial.jl +++ b/src/multivariate/dirichletmultinomial.jl @@ -51,6 +51,7 @@ Base.show(io::IO, d::DirichletMultinomial) = show(io, d, (:n, :α,)) # Parameters ncategories(d::DirichletMultinomial) = length(d.α) +Base.axes(d::DirichletMultinomial) = axes(d.α) length(d::DirichletMultinomial) = ncategories(d) ntrials(d::DirichletMultinomial) = d.n params(d::DirichletMultinomial) = (d.n, d.α) diff --git a/src/multivariate/jointorderstatistics.jl b/src/multivariate/jointorderstatistics.jl index 1fbed0d1b..ddb3865dc 100644 --- a/src/multivariate/jointorderstatistics.jl +++ b/src/multivariate/jointorderstatistics.jl @@ -71,6 +71,7 @@ function _are_ranks_valid(ranks::AbstractRange, n) end length(d::JointOrderStatistics) = length(d.ranks) +Base.axes(d::JointOrderStatistics) = axes(d.ranks) function insupport(d::JointOrderStatistics, x::AbstractVector) length(d) == length(x) || return false xi, state = iterate(x) # at least one element! diff --git a/src/multivariate/multinomial.jl b/src/multivariate/multinomial.jl index f365450ca..568fa4cbe 100644 --- a/src/multivariate/multinomial.jl +++ b/src/multivariate/multinomial.jl @@ -43,6 +43,7 @@ end ncategories(d::Multinomial) = length(d.p) length(d::Multinomial) = ncategories(d) +Base.axes(d::Multinomial) = axes(d.p) probs(d::Multinomial) = d.p ntrials(d::Multinomial) = d.n diff --git a/src/multivariate/mvlognormal.jl b/src/multivariate/mvlognormal.jl index 046f70949..0ab9b84ee 100644 --- a/src/multivariate/mvlognormal.jl +++ b/src/multivariate/mvlognormal.jl @@ -189,6 +189,7 @@ function convert(::Type{MvLogNormal{T}}, pars...) 
 end
 length(d::MvLogNormal) = length(d.normal)
+Base.axes(d::MvLogNormal) = axes(d.normal)
 params(d::MvLogNormal) = params(d.normal)
 @inline partype(d::MvLogNormal{T}) where {T<:Real} = T

diff --git a/src/multivariate/mvnormal.jl b/src/multivariate/mvnormal.jl
index 202f449f2..8203c8c83 100644
--- a/src/multivariate/mvnormal.jl
+++ b/src/multivariate/mvnormal.jl
@@ -247,6 +247,7 @@ Base.show(io::IO, d::MvNormal) =
 ### Basic statistics
 
 length(d::MvNormal) = length(d.μ)
+Base.axes(d::MvNormal) = axes(d.μ)
 mean(d::MvNormal) = d.μ
 params(d::MvNormal) = (d.μ, d.Σ)
 @inline partype(d::MvNormal{T}) where {T<:Real} = T

diff --git a/src/multivariate/mvnormalcanon.jl b/src/multivariate/mvnormalcanon.jl
index 79b43e9ba..dacb88fdd 100644
--- a/src/multivariate/mvnormalcanon.jl
+++ b/src/multivariate/mvnormalcanon.jl
@@ -151,6 +151,7 @@ canonform(d::MvNormal{T,C,Zeros{T}}) where {C, T<:Real} = MvNormalCanon(inv(d.Σ
 ### Basic statistics
 
 length(d::MvNormalCanon) = length(d.μ)
+Base.axes(d::MvNormalCanon) = axes(d.h)
 mean(d::MvNormalCanon) = convert(Vector{eltype(d.μ)}, d.μ)
 params(d::MvNormalCanon) = (d.μ, d.h, d.J)
 @inline partype(d::MvNormalCanon{T}) where {T<:Real} = T

diff --git a/src/multivariate/mvtdist.jl b/src/multivariate/mvtdist.jl
index 9076c364b..667750cad 100644
--- a/src/multivariate/mvtdist.jl
+++ b/src/multivariate/mvtdist.jl
@@ -87,6 +87,7 @@ mvtdist(df::Real, Σ::Matrix{<:Real}) = MvTDist(df, Σ)
 # Basic statistics
 
 length(d::GenericMvTDist) = d.dim
+Base.axes(d::GenericMvTDist) = axes(d.μ)
 mean(d::GenericMvTDist) = d.df>1 ? d.μ : NaN
 mode(d::GenericMvTDist) = d.μ

diff --git a/src/multivariate/product.jl b/src/multivariate/product.jl
index ada1c4e5f..1af2bbd31 100644
--- a/src/multivariate/product.jl
+++ b/src/multivariate/product.jl
@@ -35,6 +35,7 @@ function Product(v::V) where {S<:ValueSupport,T<:UnivariateDistribution{S},V<:Ab
 end
 
 length(d::Product) = length(d.v)
+Base.axes(d::Product) = axes(d.v)
 function Base.eltype(::Type{<:Product{S,T}}) where {S<:ValueSupport, T<:UnivariateDistribution{S}}
     return eltype(T)

diff --git a/src/multivariate/vonmisesfisher.jl b/src/multivariate/vonmisesfisher.jl
index e4fe981fe..85d1434e4 100644
--- a/src/multivariate/vonmisesfisher.jl
+++ b/src/multivariate/vonmisesfisher.jl
@@ -51,6 +51,7 @@ convert(::Type{VonMisesFisher{T}}, μ::Vector, κ, logCκ) where {T<:Real} = Vo
 ### Basic properties
 
 length(d::VonMisesFisher) = length(d.μ)
+Base.axes(d::VonMisesFisher) = axes(d.μ)
 meandir(d::VonMisesFisher) = d.μ
 concentration(d::VonMisesFisher) = d.κ

diff --git a/src/product.jl b/src/product.jl
index 71b02beaf..2e50a55e5 100644
--- a/src/product.jl
+++ b/src/product.jl
@@ -75,6 +75,7 @@ function Base.eltype(::Type{<:ProductDistribution{<:Any,<:Any,<:Any,<:ValueSuppo
 end
 
 size(d::ProductDistribution) = d.size
+Base.axes(d::ProductDistribution) = (axes(first(d.dists))..., axes(d.dists)...)
 
 mean(d::ProductDistribution) = reshape(mapreduce(vec ∘ mean, vcat, d.dists), size(d))
 var(d::ProductDistribution) = reshape(mapreduce(vec ∘ var, vcat, d.dists), size(d))
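
With these overloads a distribution's axes follow its parameter array instead of defaulting to `1:length(d)`. A sketch of the intent (assumes OffsetArrays is loaded and the earlier patches are applied; whether the full `rand`/`_rand!` pipeline preserves such indices depends on the remaining patches):

using Distributions, LinearAlgebra, OffsetArrays

μ = OffsetVector([0.0, 1.0, 2.0], 0:2)
d = MvNormal(μ, 2.0 * I)
axes(d) == axes(μ)   # true; sample containers are now built from these axes
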
From 2f79bbcf2f5326ecb9e6a3e3a4d0f1f5b0f3a1c4 Mon Sep 17 00:00:00 2001
From: Seth Axen
Date: Wed, 22 Oct 2025 23:04:40 +0200
Subject: [PATCH 04/11] Use axes in mixture expectations

---
 src/mixtures/mixturemodel.jl | 53 +++++++++++++++++++++++-------------
 1 file changed, 34 insertions(+), 19 deletions(-)

diff --git a/src/mixtures/mixturemodel.jl b/src/mixtures/mixturemodel.jl
index 53a8422f8..649e3432b 100644
--- a/src/mixtures/mixturemodel.jl
+++ b/src/mixtures/mixturemodel.jl
@@ -180,13 +180,11 @@ function mean(d::UnivariateMixture)
 end
 
 function mean(d::MultivariateMixture)
-    K = ncomponents(d)
     p = probs(d)
-    m = zeros(length(d))
-    for i = 1:K
-        pi = p[i]
+    m = similar(p, float(eltype(d)), axes(d))
+    fill!(m, 0.0)
+    for (c, pi) in zip(components(d), p)
         if pi > 0.0
-            c = component(d, i)
             axpy!(pi, mean(c), m)
         end
     end
@@ -223,33 +221,50 @@ function var(d::UnivariateMixture)
 end
 
 function var(d::MultivariateMixture)
-    return diag(cov(d))
+    p = probs(d)
+    ax = axes(d)
+    m = similar(p, ax)
+    md = similar(p, ax)
+    v = similar(p, ax)
+    fill!(m, 0.0)
+    fill!(v, 0.0)
+
+    for (c, pi) in zip(components(d), p)
+        if pi > 0.0
+            axpy!(pi, mean(c), m)
+            axpy!(pi, var(c), v)
+        end
+    end
+    for (c, pi) in zip(components(d), p)
+        if pi > 0.0
+            md .= mean(c) .- m
+            @. v += pi * abs2(md)
+        end
+    end
+    return v
 end
 
 function cov(d::MultivariateMixture)
-    K = ncomponents(d)
     p = probs(d)
-    m = zeros(length(d))
-    md = zeros(length(d))
-    V = zeros(length(d),length(d))
-
-    for i = 1:K
-        pi = p[i]
+    ax = axes(d)
+    m = similar(p, ax)
+    md = similar(p, ax)
+    V = similar(p, (ax[1], ax[1]))
+    fill!(m, 0.0)
+    fill!(V, 0.0)
+
+    for (c, pi) in zip(components(d), p)
         if pi > 0.0
-            c = component(d, i)
             axpy!(pi, mean(c), m)
             axpy!(pi, cov(c), V)
         end
     end
-    for i = 1:K
-        pi = p[i]
+    for (c, pi) in zip(components(d), p)
        if pi > 0.0
-            c = component(d, i)
             md .= mean(c) .- m
-            BLAS.syr!('U', Float64(pi), md, V)
+            V .+= pi .* md .* md'
         end
     end
-    LinearAlgebra.copytri!(V, 'U')
     return V
 end
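
The rewritten `var` and `cov` accumulate the usual mixture decomposition, cov(d) = sum_i p_i * (cov(c_i) + (mean(c_i) - m) * (mean(c_i) - m)') with m the mixture mean. Note that the outer-product update is written with explicit dots rather than `@.`, since `@.` would also broadcast the adjoint and drop the outer product. A quick numerical check (components and weights made up for illustration):

using Distributions, LinearAlgebra

mix = MixtureModel([MvNormal([0.0, 0.0], 1.0 * I), MvNormal([1.0, 2.0], 4.0 * I)], [0.3, 0.7])
m = sum(probs(mix) .* mean.(components(mix)))          # mixture mean
V = sum(p * (cov(c) + (mean(c) - m) * (mean(c) - m)')
        for (c, p) in zip(components(mix), probs(mix)))
V ≈ cov(mix)   # true
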
From 859d49a85ba71c6d33485e7cae0d0a91bcc90040 Mon Sep 17 00:00:00 2001
From: Seth Axen
Date: Wed, 22 Oct 2025 23:06:02 +0200
Subject: [PATCH 05/11] Use axes in computing Dirichlet cov

---
 src/multivariate/dirichlet.jl | 17 ++++-------------
 1 file changed, 4 insertions(+), 13 deletions(-)

diff --git a/src/multivariate/dirichlet.jl b/src/multivariate/dirichlet.jl
index 1716a7e99..53cdac945 100644
--- a/src/multivariate/dirichlet.jl
+++ b/src/multivariate/dirichlet.jl
@@ -89,23 +89,14 @@ end
 
 function cov(d::Dirichlet)
     α = d.alpha
+    ax = axes(α, 1)
     α0 = d.alpha0
     c = inv(α0^2 * (α0 + 1))
 
     T = typeof(zero(eltype(α))^2 * c)
-    k = length(α)
-    C = Matrix{T}(undef, k, k)
-    for j = 1:k
-        αj = α[j]
-        αjc = αj * c
-        for i in 1:(j-1)
-            C[i,j] = C[j,i]
-        end
-        C[j,j] = (α0 - αj) * αjc
-        for i in (j+1):k
-            C[i,j] = - α[i] * αjc
-        end
-    end
+    C = similar(α, T, (ax, ax))
+    C .= -α .* α' .* c
+    C[diagind(C)] .+= (α0 * c) .* α
 
     return C
 end

From 86edc74a4e528ab510ef6aa2884a309692decfbe Mon Sep 17 00:00:00 2001
From: Seth Axen
Date: Wed, 22 Oct 2025 23:08:03 +0200
Subject: [PATCH 06/11] Broadcast DirichletMultinomial var

---
 src/multivariate/dirichletmultinomial.jl | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/src/multivariate/dirichletmultinomial.jl b/src/multivariate/dirichletmultinomial.jl
index 82403617e..0a99b27cd 100644
--- a/src/multivariate/dirichletmultinomial.jl
+++ b/src/multivariate/dirichletmultinomial.jl
@@ -60,11 +60,9 @@ params(d::DirichletMultinomial) = (d.n, d.α)
 # Statistics
 mean(d::DirichletMultinomial) = d.α .* (d.n / d.α0)
 function var(d::DirichletMultinomial{T}) where T <: Real
-    v = fill(d.n * (d.n + d.α0) / (1 + d.α0), length(d))
+    v0 = d.n * (d.n + d.α0) / (1 + d.α0)
     p = d.α / d.α0
-    for i in eachindex(v)
-        v[i] *= p[i] * (1 - p[i])
-    end
+    v = @. v0 * p * (1 - p)
     v
 end
 function cov(d::DirichletMultinomial{<:Real})

From 9eccd49bf173f9c9aca9e6cd93cdcea968f0d1b1 Mon Sep 17 00:00:00 2001
From: Seth Axen
Date: Wed, 22 Oct 2025 23:08:32 +0200
Subject: [PATCH 07/11] Make minimum/maximum use axes

---
 src/multivariate/jointorderstatistics.jl | 4 ++--
 src/multivariate/mvnormal.jl             | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/multivariate/jointorderstatistics.jl b/src/multivariate/jointorderstatistics.jl
index ddb3865dc..d5e6401ac 100644
--- a/src/multivariate/jointorderstatistics.jl
+++ b/src/multivariate/jointorderstatistics.jl
@@ -84,8 +84,8 @@ function insupport(d::JointOrderStatistics, x::AbstractVector)
     end
     return true
 end
-minimum(d::JointOrderStatistics) = Fill(minimum(d.dist), length(d))
-maximum(d::JointOrderStatistics) = Fill(maximum(d.dist), length(d))
+minimum(d::JointOrderStatistics) = Fill(minimum(d.dist), axes(d))
+maximum(d::JointOrderStatistics) = Fill(maximum(d.dist), axes(d))
 params(d::JointOrderStatistics) = tuple(params(d.dist)..., d.n, d.ranks)
 partype(d::JointOrderStatistics) = partype(d.dist)

diff --git a/src/multivariate/mvnormal.jl b/src/multivariate/mvnormal.jl
index 8203c8c83..cbde45285 100644
--- a/src/multivariate/mvnormal.jl
+++ b/src/multivariate/mvnormal.jl
@@ -80,8 +80,8 @@ abstract type AbstractMvNormal <: ContinuousMultivariateDistribution end
 
 insupport(d::AbstractMvNormal, x::AbstractVector) = length(d) == length(x) && all(isfinite, x)
 
-minimum(d::AbstractMvNormal) = fill(eltype(d)(-Inf), length(d))
-maximum(d::AbstractMvNormal) = fill(eltype(d)(Inf), length(d))
+minimum(d::AbstractMvNormal) = fill(eltype(d)(-Inf), axes(d))
+maximum(d::AbstractMvNormal) = fill(eltype(d)(Inf), axes(d))
 mode(d::AbstractMvNormal) = mean(d)
 modes(d::AbstractMvNormal) = [mean(d)]
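
Patch 05 above builds the Dirichlet covariance from two broadcast statements; the next patch applies the same pattern to the multinomial. A sanity check of the Dirichlet case against the textbook formula cov[i, j] = (δ_ij * α0 * α[i] - α[i] * α[j]) / (α0^2 * (α0 + 1)), with made-up parameters:

using Distributions, LinearAlgebra

α = [2.0, 3.0, 5.0]
α0 = sum(α)
c = inv(α0^2 * (α0 + 1))
C = -α .* α' .* c
C[diagind(C)] .+= (α0 * c) .* α
C ≈ cov(Dirichlet(α))   # true
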
From 9ecf6781496b3d41ba9d589b4f494584ad7d8fe3 Mon Sep 17 00:00:00 2001
From: Seth Axen
Date: Wed, 22 Oct 2025 23:11:17 +0200
Subject: [PATCH 08/11] Broadcast multinomial var and cov to use axes

---
 src/multivariate/multinomial.jl | 27 +++++----------------------
 1 file changed, 5 insertions(+), 22 deletions(-)

diff --git a/src/multivariate/multinomial.jl b/src/multivariate/multinomial.jl
index 568fa4cbe..87c746906 100644
--- a/src/multivariate/multinomial.jl
+++ b/src/multivariate/multinomial.jl
@@ -64,37 +64,20 @@ mean(d::Multinomial) = d.n .* d.p
 
 function var(d::Multinomial{T}) where T<:Real
     p = probs(d)
-    k = length(p)
     n = ntrials(d)
-    v = Vector{T}(undef, k)
-    for i = 1:k
-        p_i = p[i]
-        v[i] = n * p_i * (1 - p_i)
-    end
+    v = @. n * p * (1 - p)
     v
 end
 
 function cov(d::Multinomial{T}) where T<:Real
     p = probs(d)
-    k = length(p)
+    ax = axes(p, 1)
     n = ntrials(d)
-    C = Matrix{T}(undef, k, k)
-    for j = 1:k
-        pj = p[j]
-        for i = 1:j-1
-            C[i,j] = - n * p[i] * pj
-        end
-
-        C[j,j] = n * pj * (1-pj)
-    end
-
-    for j = 1:k-1
-        for i = j+1:k
-            C[i,j] = C[j,i]
-        end
-    end
+    C = similar(p, T, (ax, ax))
+    C .= -n .* p .* p'
+    C[diagind(C)] .+= n .* p
     C
 end

From 3c93af9c609d48cbffe1dc71027c6fd553d8d420 Mon Sep 17 00:00:00 2001
From: Seth Axen
Date: Wed, 22 Oct 2025 23:12:41 +0200
Subject: [PATCH 09/11] Make sure probs is a Vector

MultinomialSampler expects a Vector
---
 src/multivariate/multinomial.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/multivariate/multinomial.jl b/src/multivariate/multinomial.jl
index 87c746906..870b6cbd4 100644
--- a/src/multivariate/multinomial.jl
+++ b/src/multivariate/multinomial.jl
@@ -152,7 +152,7 @@ end
 _rand!(rng::AbstractRNG, d::Multinomial, x::AbstractVector{<:Real}) =
     multinom_rand!(rng, ntrials(d), probs(d), x)
 
-sampler(d::Multinomial) = MultinomialSampler(ntrials(d), probs(d))
+sampler(d::Multinomial) = MultinomialSampler(ntrials(d), convert(Vector, probs(d)))
 
 ## Fit model

From 6f5b20b26d242617828b85ceeef043f309a436c6 Mon Sep 17 00:00:00 2001
From: Seth Axen
Date: Wed, 22 Oct 2025 23:13:56 +0200
Subject: [PATCH 10/11] Choose axes of zero-mean MvNormal based on scale

---
 src/multivariate/mvnormal.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/multivariate/mvnormal.jl b/src/multivariate/mvnormal.jl
index cbde45285..6d437faa2 100644
--- a/src/multivariate/mvnormal.jl
+++ b/src/multivariate/mvnormal.jl
@@ -211,7 +211,7 @@ end
 
 Construct a multivariate normal distribution with zero mean and covariance matrix `Σ`.
 """
-MvNormal(Σ::AbstractMatrix{<:Real}) = MvNormal(Zeros{eltype(Σ)}(size(Σ, 1)), Σ)
+MvNormal(Σ::AbstractMatrix{<:Real}) = MvNormal(Zeros{eltype(Σ)}(axes(Σ, 1)), Σ)
 
 # deprecated constructors with standard deviations
 Base.@deprecate MvNormal(μ::AbstractVector{<:Real}, σ::AbstractVector{<:Real}) MvNormal(μ, LinearAlgebra.Diagonal(map(abs2, σ)))

From 46654f311cc7dbbc7ae48ab947489f5ddfd8a4a6 Mon Sep 17 00:00:00 2001
From: Seth Axen
Date: Wed, 22 Oct 2025 23:14:21 +0200
Subject: [PATCH 11/11] Choose cor array type from cov

---
 src/multivariates.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/multivariates.jl b/src/multivariates.jl
index 962b90ed5..1aa6bb0b3 100644
--- a/src/multivariates.jl
+++ b/src/multivariates.jl
@@ -99,7 +99,7 @@ function cor(d::MultivariateDistribution)
     C = cov(d)
     n = size(C, 1)
     @assert size(C, 2) == n
-    R = Matrix{eltype(C)}(undef, n, n)
+    R = similar(C)
     for j = 1:n
         for i = 1:j-1