Skip to content

Commit 4ecd6a3

Browse files
oschulzdevmotion
andauthored
Use map instead of broadcast(ed) in for LADJ aggregation
Co-authored-by: David Widmann <[email protected]>
1 parent d819f29 commit 4ecd6a3

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/with_ladj.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ end
8282

8383
function _with_ladj_on_mapped(map_or_bc::F, y_with_ladj) where {F<:Union{typeof(map),typeof(broadcast)}}
8484
y = map_or_bc(first, y_with_ladj)
85-
ladj = sum(Broadcast.instantiate(Broadcast.broadcasted(last, y_with_ladj)))
85+
ladj = sum(last, y_with_ladj)
8686
(y, ladj)
8787
end
8888

@@ -92,7 +92,7 @@ end
9292
struct WithLadjOnMappedPullback{YLT} <: Function end
9393
function (::WithLadjOnMappedPullback{YLT})(thunked_ΔΩ) where YLT
9494
ys, ladj = unthunk(thunked_ΔΩ)
95-
return NoTangent(), NoTangent(), broadcast((y, l) -> Tangent{YLT}(y, l), ys, ladj)
95+
return NoTangent(), NoTangent(), map(y -> Tangent{YLT}(y, ladj), ys)
9696
end
9797

9898
function ChainRulesCore.rrule(::typeof(_with_ladj_on_mapped), map_or_bc::F, y_with_ladj) where {F<:Union{typeof(map),typeof(broadcast)}}

0 commit comments

Comments
 (0)