|
| 1 | +using Distributions |
1 | 2 | using ChainRulesTestUtils
|
2 | 3 | using ChainRulesTestUtils: FiniteDifferences
|
3 | 4 |
|
| 5 | +using Random |
| 6 | +using Test |
| 7 | + |
4 | 8 | # Without this, `to_vec` will also include the `axes` field of `EachVariate`.
|
5 | 9 | function FiniteDifferences.to_vec(xs::Distributions.EachVariate{V}) where {V}
|
6 | 10 | vals, vals_from_vec = FiniteDifferences.to_vec(xs.parent)
|
7 | 11 | return vals, x -> Distributions.EachVariate{V}(vals_from_vec(x))
|
8 | 12 | end
|
9 | 13 |
|
| 14 | +# MWE in #1817 |
| 15 | +struct FooEachvariate <: Sampleable{Multivariate, Continuous} end |
| 16 | +Base.length(::FooEachvariate) = 3 |
| 17 | +Base.eltype(::FooEachvariate) = Float64 |
| 18 | +function Distributions._rand!(rng::AbstractRNG, ::FooEachvariate, x::AbstractVector{<:Real}) |
| 19 | + return rand!(rng, x) |
| 20 | +end |
| 21 | + |
10 | 22 | @testset "eachvariate.jl" begin
|
11 | 23 | @testset "ChainRules" begin
|
12 | 24 | xs = randn(2, 3, 4, 5)
|
13 | 25 | test_rrule(Distributions.EachVariate{1}, xs)
|
14 | 26 | test_rrule(Distributions.EachVariate{2}, xs)
|
15 | 27 | test_rrule(Distributions.EachVariate{3}, xs)
|
16 | 28 | end
|
| 29 | + |
| 30 | + @testset "No variates (#1817)" begin |
| 31 | + @test size(Distributions.eachvariate(rand(0), Univariate)) == (0,) |
| 32 | + @test size(Distributions.eachvariate(rand(3, 0, 1), Univariate)) == (3, 0, 1) |
| 33 | + @test size(Distributions.eachvariate(rand(3, 2, 0), Univariate)) == (3, 2, 0) |
| 34 | + |
| 35 | + @test size(Distributions.eachvariate(rand(4, 0), Multivariate)) == (0,) |
| 36 | + @test size(Distributions.eachvariate(rand(4, 0, 2), Multivariate)) == (0, 2) |
| 37 | + @test size(Distributions.eachvariate(rand(4, 5, 0), Multivariate)) == (5, 0) |
| 38 | + @test size(Distributions.eachvariate(rand(4, 5, 0, 2), Multivariate)) == (5, 0, 2) |
| 39 | + |
| 40 | + draws = @inferred(rand(FooEachvariate(), 0)) |
| 41 | + @test draws isa Matrix{Float64} |
| 42 | + @test size(draws) == (3, 0) |
| 43 | + end |
17 | 44 | end
|
0 commit comments