Skip to content

Commit de037ca

Browse files
oschulzdevmotion
andauthored
Improve _with_ladj_on_mapped pullback implementation
Co-authored-by: David Widmann <[email protected]>
1 parent 5ef880b commit de037ca

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

src/with_ladj.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -81,20 +81,20 @@ end
8181

8282
_with_ladj_on_mapped(map_or_bc::Function, y_with_ladj::Tuple{Any,Real}) = y_with_ladj
8383

84-
function _with_ladj_on_mapped(map_or_bc::Function, y_with_ladj)
85-
y = map_or_bc(_get_y, y_with_ladj)
86-
ladj = sum(map_or_bc(_get_ladj, y_with_ladj))
84+
function _with_ladj_on_mapped(map_or_bc::F, y_with_ladj) where {F<:Union{typeof(map),typeof(broadcast)}}
85+
y = map_or_bc(first, y_with_ladj)
86+
ladj = sum(map_or_bc(last, y_with_ladj))
8787
#ladj = sum(_get_ladj, y_with_ladj)
8888
(y, ladj)
8989
end
9090

9191
function _with_ladj_on_mapped_pullback(thunked_ΔΩ)
92-
ys, ladj = ChainRulesCore.unthunk(thunked_ΔΩ)
93-
NoTangent(), NoTangent(), broadcast(x -> (x, ladj), ys)
92+
ys, ladj = unthunk(thunked_ΔΩ)
93+
return NoTangent(), NoTangent(), tuple.(ys, ladj)
9494
end
9595

96-
function ChainRulesCore.rrule(::typeof(ChangesOfVariables._with_ladj_on_mapped), map_or_bc::Function, y_with_ladj)
97-
return ChangesOfVariables._with_ladj_on_mapped(map_or_bc, y_with_ladj), _with_ladj_on_mapped_pullback
96+
function ChainRulesCore.rrule(::typeof(_with_ladj_on_mapped), map_or_bc::Function, y_with_ladj)
97+
return _with_ladj_on_mapped(map_or_bc, y_with_ladj), _with_ladj_on_mapped_pullback
9898
end
9999

100100
function with_logabsdet_jacobian(mapped_f::Base.Fix1{<:Union{typeof(map),typeof(broadcast)}}, X)

0 commit comments

Comments
 (0)