diff --git a/DifferentiationInterface/src/first_order/pullback.jl b/DifferentiationInterface/src/first_order/pullback.jl index 6c6e6cc10..212b4af67 100644 --- a/DifferentiationInterface/src/first_order/pullback.jl +++ b/DifferentiationInterface/src/first_order/pullback.jl @@ -347,7 +347,9 @@ function _pullback_via_pushforward( dy, contexts::Vararg{Context,C}, ) where {F,C} - dx = map(CartesianIndices(x)) do j + ind = CartesianIndices(x) + T = typeof(similar(x, eltype(ind))) + dx = map(x, T(ind)) do xj, j a = only(pushforward(f, pushforward_prep, backend, x, (basis(x, j),), contexts...)) dot(a, dy) end @@ -362,7 +364,9 @@ function _pullback_via_pushforward( dy, contexts::Vararg{Context,C}, ) where {F,C} - dx = map(CartesianIndices(x)) do j + ind = CartesianIndices(x) + T = typeof(similar(x, eltype(ind))) + dx = map(x, T(ind)) do xj, j a = only(pushforward(f, pushforward_prep, backend, x, (basis(x, j),), contexts...)) b = only( pushforward(f, pushforward_prep, backend, x, (im * basis(x, j),), contexts...), @@ -472,7 +476,9 @@ function _pullback_via_pushforward( dy, contexts::Vararg{Context,C}, ) where {F,C} - dx = map(CartesianIndices(x)) do j # preserve shape + ind = CartesianIndices(x) + T = typeof(similar(x, eltype(ind))) + dx = map(x, T(ind)) do xj, j # preserve shape a = only(pushforward(f!, y, pushforward_prep, backend, x, (basis(x, j),), contexts...)) dot(a, dy) end @@ -488,7 +494,9 @@ function _pullback_via_pushforward( dy, contexts::Vararg{Context,C}, ) where {F,C} - dx = map(CartesianIndices(x)) do j # preserve shape + ind = CartesianIndices(x) + T = typeof(similar(x, eltype(ind))) + dx = map(x, T(ind)) do xj, j # preserve shape a = only(pushforward(f!, y, pushforward_prep, backend, x, (basis(x, j),), contexts...)) b = only( pushforward( diff --git a/DifferentiationInterface/src/first_order/pushforward.jl b/DifferentiationInterface/src/first_order/pushforward.jl index 46d249d67..09b4b1b45 100644 --- a/DifferentiationInterface/src/first_order/pushforward.jl +++ b/DifferentiationInterface/src/first_order/pushforward.jl @@ -351,7 +351,9 @@ function _pushforward_via_pullback( dx, contexts::Vararg{Context,C}, ) where {F,C} - dy = map(CartesianIndices(y)) do i + ind = CartesianIndices(y) + T = typeof(similar(y, eltype(ind))) + dy = map(y, T(ind)) do yi, i a = only(pullback(f, pullback_prep, backend, x, (basis(y, i),), contexts...)) dot(a, dx) end @@ -367,7 +369,9 @@ function _pushforward_via_pullback( dx, contexts::Vararg{Context,C}, ) where {F,C} - dy = map(CartesianIndices(y)) do i + ind = CartesianIndices(y) + T = typeof(similar(y, eltype(ind))) + dy = map(y, T(ind)) do yi, i a = only(pullback(f, pullback_prep, backend, x, (basis(y, i),), contexts...)) b = only(pullback(f, pullback_prep, backend, x, (im * basis(y, i),), contexts...)) real(dot(a, dx)) + im * real(dot(b, dx)) @@ -444,7 +448,9 @@ function _pushforward_via_pullback( dx, contexts::Vararg{Context,C}, ) where {F,C} - dy = map(CartesianIndices(y)) do i # preserve shape + ind = CartesianIndices(y) + T = typeof(similar(y, eltype(ind))) + dy = map(y, T(ind)) do yi, i # preserve shape a = only(pullback(f!, y, pullback_prep, backend, x, (basis(y, i),), contexts...)) dot(a, dx) end @@ -460,7 +466,9 @@ function _pushforward_via_pullback( dx, contexts::Vararg{Context,C}, ) where {F,C} - dy = map(CartesianIndices(y)) do i # preserve shape + ind = CartesianIndices(y) + T = typeof(similar(y, eltype(ind))) + dy = map(y, T(ind)) do yi, i # preserve shape a = only(pullback(f!, y, pullback_prep, backend, x, (basis(y, i),), contexts...)) b = only( pullback(f!, y, pullback_prep, backend, x, (im * basis(y, i),), contexts...)