Skip to content

Commit 8cddb2e

Browse files
committed
Fix _with_ladj_on_mapped
Was causing trouble with broadcast over scalars (that returns a scalar).
1 parent 8f78076 commit 8cddb2e

File tree

2 files changed

+6
-0
lines changed

2 files changed

+6
-0
lines changed

src/with_ladj.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,10 @@ with_logabsdet_jacobian(f, x) = NoLogAbsDetJacobian(f, x)
117117
end
118118

119119

120+
function _with_ladj_on_mapped(@nospecialize(map_or_bc::F), y_with_ladj::NoLogAbsDetJacobian) where {F<:Union{typeof(map),typeof(broadcast)}}
121+
return y_with_ladj
122+
end
123+
120124
function _with_ladj_on_mapped(map_or_bc::F, y_with_ladj::Tuple{Any,Real}) where {F<:Union{typeof(map),typeof(broadcast)}}
121125
return y_with_ladj
122126
end

test/test_with_ladj.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ include("getjacobian.jl")
3030
@test with_logabsdet_jacobian(sin log, 4.9) === NoLogAbsDetJacobian{typeof(sin ∘ log), Float64}()
3131
@test with_logabsdet_jacobian(log sin, 4.9) === NoLogAbsDetJacobian{typeof(log ∘ sin), Float64}()
3232

33+
@test with_logabsdet_jacobian(Base.Fix1(broadcast, sin), 4.9) === NoLogAbsDetJacobian{typeof(sin), Float64}()
34+
3335
function ChangesOfVariables.with_logabsdet_jacobian(::typeof(foo), x)
3436
y = foo(x)
3537
ladj = -x + 2 * log(y)

0 commit comments

Comments
 (0)