Skip to content

Commit c10a007

Browse files
authored
add unwrap (#103)
* add unwrap * using InverseFunctions: FunctionWithInverse * formatting * add tests
1 parent 29d154f commit c10a007

File tree

3 files changed

+8
-2
lines changed

3 files changed

+8
-2
lines changed

src/combinators/transformedmeasure.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ end
5858
@inline function logdensity_def::PushforwardMeasure{F,I,M,<:WithVolCorr}, y) where {F,I,M}
5959
f = ν.f
6060
finv = ν.finv
61-
x_orig, inv_ladj = with_logabsdet_jacobian(finv, y)
61+
x_orig, inv_ladj = with_logabsdet_jacobian(unwrap(finv), y)
6262
logd_orig = logdensity_def.origin, x_orig)
6363
logd = float(logd_orig + inv_ladj)
6464
neginf = oftype(logd, -Inf)

src/utils.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,3 +159,8 @@ insupport(m::AbstractMeasure) = Base.Fix1(insupport, m)
159159

160160
unstatic(::Type{T}) where {T} = T
161161
unstatic(::Type{StaticFloat64{X}}) where {X} = Float64
162+
163+
using InverseFunctions: FunctionWithInverse
164+
165+
unwrap(f) = f
166+
unwrap(f::FunctionWithInverse) = f.f

test/combinators/transformedmeasure.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import Statistics: var
88
using DensityInterface: logdensityof
99
using LogExpFunctions
1010
using SpecialFunctions: erfc, erfcinv
11-
import InverseFunctions: inverse, FunctionWithInverse
11+
import InverseFunctions: inverse, FunctionWithInverse, setinverse
1212
using IrrationalConstants: invsqrt2, sqrt2
1313
import ChangesOfVariables: with_logabsdet_jacobian
1414
using MeasureBase.Interface: transport_to, test_transport
@@ -69,6 +69,7 @@ end
6969

7070
for (f, μ, ν_ref) in triples
7171
test_pushfwd(f, μ, ν_ref)
72+
test_pushfwd(setinverse(f, inverse(f)), μ, ν_ref)
7273
end
7374

7475
@testset "Pushforward-of-pushforward" begin

0 commit comments

Comments
 (0)