Skip to content

Commit 04e24b0

Browse files
committed
Add mreshape
1 parent 0fb2009 commit 04e24b0

File tree

4 files changed

+58
-0
lines changed

4 files changed

+58
-0
lines changed

src/MeasureBase.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ include("primitives/trivial.jl")
166166

167167
include("combinators/bind.jl")
168168
include("combinators/transformedmeasure.jl")
169+
include("combinators/reshape.jl")
169170
include("combinators/weighted.jl")
170171
include("combinators/superpose.jl")
171172
include("combinators/product.jl")

src/combinators/reshape.jl

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# ToDo: Support static resizes for static arrays
2+
3+
"""
4+
struct MeasureBase.Reshape <: Function
5+
6+
Represents a function that reshapes an array.
7+
8+
Supports `InverseFunctions.inverse` and
9+
`ChangesOfVariables.with_logabsdet_jacobian`.
10+
11+
Constructor:
12+
13+
```julia
14+
Reshape(output_size::Dims, input_size::Dims)
15+
```
16+
"""
17+
struct Reshape{M,N} <: Function
18+
output_size::NTuple{M,Int}
19+
input_size::NTuple{N,Int}
20+
end
21+
22+
_throw_reshape_mismatch(sz, sz_x) = throw(DimensionMismatch("Reshape input size is $sz but got input of size $sz_x"))
23+
24+
function (f::Reshape)(x::AbstractArray)
25+
sz_x = size(x)
26+
f.input_size == sz_x || _throw_reshape_mismatch(f.input_size, sz_x)
27+
return reshape(x, f.output_size)
28+
end
29+
30+
InverseFunctions.inverse(f::Reshape) = Reshape(f.input_size, f.output_size)
31+
32+
ChangesOfVariables.with_logabsdet_jacobian(::Reshape, x::AbstractArray) = zero(real_numtype(typeof(x)))
33+
34+
35+
"""
36+
mreshape(m::AbstractMeasure, sz::Vararg{N,Integer}) where N
37+
mreshape(m::AbstractMeasure, sz::NTuple{N,Integer}) where N
38+
39+
Reshape a measure `m` over an array-valued space, returning a measure over
40+
a space of arrays with shape `sz`.
41+
"""
42+
function mreshape end
43+
44+
_elsize_for_reshape(m::AbstractMeasure) = _elsize_for_reshape(some_mspace_elsize(m), m)
45+
_elsize_for_reshape(sz::NTuple{<:Any,Integer}, ::AbstractMeasure) = sz
46+
_elsize_for_reshape(::NoMSpaceElementSize, m::AbstractMeasure) = size(testvalue(m))
47+
48+
mreshape(m::AbstractMeasure, sz::Vararg{<:Any,Integer}) = mreshape(m, sz)
49+
mreshape(m::AbstractMeasure, sz::NTuple{<:Any,Integer}) = pushfwd(Reshape(sz, _elsize_for_reshape(m)), m)

test/combinators/reshape.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
using Test
2+
3+
using MeasureBase
4+
5+
@testset "reshape" begin
6+
7+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,6 @@ include("smf.jl")
1919

2020
include("combinators/weighted.jl")
2121
include("combinators/transformedmeasure.jl")
22+
include("combinators/reshape.jl")
2223

2324
include("test_docs.jl")

0 commit comments

Comments
 (0)