Skip to content

Commit f859e29

Browse files
committed
Fix construction of NoLogAbsDetJacobian
1 parent 974119a commit f859e29

File tree

2 files changed

+23
-6
lines changed

2 files changed

+23
-6
lines changed

src/with_ladj.jl

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,24 +80,34 @@ export with_logabsdet_jacobian
8080
struct NoLogAbsDetJacobian{F,T}
8181
8282
An instance `NoLogAbsDetJacobian{F,T}()` signifies that `with_logabsdet_jacobian(::F, ::T)` is not defined.
83+
84+
Constructors:
85+
```julia
86+
NoLogAbsDetJacobian(f, x)
87+
NoLogAbsDetJacobian{F,T}()
88+
```
8389
"""
8490
struct NoLogAbsDetJacobian{F,T} end
8591
export NoLogAbsDetJacobian
8692

87-
with_logabsdet_jacobian(::F, ::T) where {F,T} = NoLogAbsDetJacobian{F,T}()
93+
@inline NoLogAbsDetJacobian(::F, ::T) where {F,T} = NoLogAbsDetJacobian{F,T}()
94+
@inline NoLogAbsDetJacobian(::Type{F}, ::T) where {F,T} = NoLogAbsDetJacobian{Type{F},T}()
95+
@inline NoLogAbsDetJacobian(::F, ::Type{T}) where {F,T} = NoLogAbsDetJacobian{F,Type{T}}()
96+
@inline NoLogAbsDetJacobian(::Type{F}, ::Type{T}) where {F,T} = NoLogAbsDetJacobian{Type{F},Type{T}}()
8897

98+
with_logabsdet_jacobian(f, x) = NoLogAbsDetJacobian(f, x)
8999

90100

91101
@static if VERSION >= v"1.6"
92102
function with_logabsdet_jacobian(f::Base.ComposedFunction, x)
93103
y_ladj_inner = with_logabsdet_jacobian(f.inner, x)
94104
if y_ladj_inner isa NoLogAbsDetJacobian
95-
NoLogAbsDetJacobian{typeof(f),typeof(x)}()
105+
NoLogAbsDetJacobian(f, x)
96106
else
97107
y_inner, ladj_inner = y_ladj_inner
98108
y_ladj_outer = with_logabsdet_jacobian(f.outer, y_inner)
99109
if y_ladj_outer isa NoLogAbsDetJacobian
100-
NoLogAbsDetJacobian{typeof(f),typeof(x)}()
110+
NoLogAbsDetJacobian(f, x)
101111
else
102112
y, ladj_outer = y_ladj_outer
103113
(y, ladj_inner + ladj_outer)

test/test_with_ladj.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,18 @@ include("getjacobian.jl")
1818
_bc_func(f) = Base.Fix1(broadcast, f)
1919
end
2020

21-
@test with_logabsdet_jacobian(sum, rand(5)) == NoLogAbsDetJacobian{typeof(sum),Vector{Float64}}()
22-
@test with_logabsdet_jacobian(sum log, 5.0f0) == NoLogAbsDetJacobian{typeof(sum ∘ log),Float32}()
23-
@test with_logabsdet_jacobian(log sum, 5.0f0) == NoLogAbsDetJacobian{typeof(log ∘ sum),Float32}()
21+
@test with_logabsdet_jacobian(sum, rand(5)) === NoLogAbsDetJacobian(sum, rand(5))
22+
@test with_logabsdet_jacobian(log sum, 5.0f0) === NoLogAbsDetJacobian(log sum, 5.0f0)
2423
@test_throws MethodError _, _ = with_logabsdet_jacobian(sum, rand(5))
2524

25+
@test with_logabsdet_jacobian(sin, 4.9) === NoLogAbsDetJacobian{typeof(sin), Float64}()
26+
@test with_logabsdet_jacobian(String, 4.9) === NoLogAbsDetJacobian{Type{String}, Float64}()
27+
@test with_logabsdet_jacobian(String, Float64) === NoLogAbsDetJacobian{Type{String}, Type{Float64}}()
28+
@test with_logabsdet_jacobian(sin, Float64) === NoLogAbsDetJacobian{typeof(sin), Type{Float64}}()
29+
30+
@test with_logabsdet_jacobian(sin log, 4.9) === NoLogAbsDetJacobian{typeof(sin ∘ log), Float64}()
31+
@test with_logabsdet_jacobian(log sin, 4.9) === NoLogAbsDetJacobian{typeof(log ∘ sin), Float64}()
32+
2633
function ChangesOfVariables.with_logabsdet_jacobian(::typeof(foo), x)
2734
y = foo(x)
2835
ladj = -x + 2 * log(y)

0 commit comments

Comments
 (0)