Skip to content

Commit 28bf738

Browse files
authored
Fix eachvariate with zero variates (#1819)
1 parent 9e72f1f commit 28bf738

File tree

3 files changed

+29
-2
lines changed

3 files changed

+29
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Distributions"
22
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
33
authors = ["JuliaStats"]
4-
version = "0.25.104"
4+
version = "0.25.105"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/eachvariate.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ end
77

88
function EachVariate{V}(x::AbstractArray{<:Real,M}) where {V,M}
99
ax = ntuple(i -> axes(x, i + V), Val(M - V))
10-
T = typeof(view(x, ntuple(i -> i <= V ? Colon() : firstindex(x, i), Val(M))...))
10+
T = Base.promote_op(view, typeof(x), ntuple(i -> i <= V ? Colon : eltype(axes(x, i)), Val(M))...)
1111
return EachVariate{V,typeof(x),typeof(ax),T,M-V}(x, ax)
1212
end
1313

test/eachvariate.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,44 @@
1+
using Distributions
12
using ChainRulesTestUtils
23
using ChainRulesTestUtils: FiniteDifferences
34

5+
using Random
6+
using Test
7+
48
# Without this, `to_vec` will also include the `axes` field of `EachVariate`.
59
function FiniteDifferences.to_vec(xs::Distributions.EachVariate{V}) where {V}
610
vals, vals_from_vec = FiniteDifferences.to_vec(xs.parent)
711
return vals, x -> Distributions.EachVariate{V}(vals_from_vec(x))
812
end
913

14+
# MWE in #1817
15+
struct FooEachvariate <: Sampleable{Multivariate, Continuous} end
16+
Base.length(::FooEachvariate) = 3
17+
Base.eltype(::FooEachvariate) = Float64
18+
function Distributions._rand!(rng::AbstractRNG, ::FooEachvariate, x::AbstractVector{<:Real})
19+
return rand!(rng, x)
20+
end
21+
1022
@testset "eachvariate.jl" begin
1123
@testset "ChainRules" begin
1224
xs = randn(2, 3, 4, 5)
1325
test_rrule(Distributions.EachVariate{1}, xs)
1426
test_rrule(Distributions.EachVariate{2}, xs)
1527
test_rrule(Distributions.EachVariate{3}, xs)
1628
end
29+
30+
@testset "No variates (#1817)" begin
31+
@test size(Distributions.eachvariate(rand(0), Univariate)) == (0,)
32+
@test size(Distributions.eachvariate(rand(3, 0, 1), Univariate)) == (3, 0, 1)
33+
@test size(Distributions.eachvariate(rand(3, 2, 0), Univariate)) == (3, 2, 0)
34+
35+
@test size(Distributions.eachvariate(rand(4, 0), Multivariate)) == (0,)
36+
@test size(Distributions.eachvariate(rand(4, 0, 2), Multivariate)) == (0, 2)
37+
@test size(Distributions.eachvariate(rand(4, 5, 0), Multivariate)) == (5, 0)
38+
@test size(Distributions.eachvariate(rand(4, 5, 0, 2), Multivariate)) == (5, 0, 2)
39+
40+
draws = @inferred(rand(FooEachvariate(), 0))
41+
@test draws isa Matrix{Float64}
42+
@test size(draws) == (3, 0)
43+
end
1744
end

0 commit comments

Comments
 (0)