Skip to content

Commit 9c2c05e

Browse files
committed
Make sure that ȳ's output is a compatible float type
1 parent aa4e64a commit 9c2c05e

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

src/back.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,9 @@ Calculate the output jacobian `J = d/dx m(x)` such that each row `i` of `J` corr
173173
"""
174174
function jacobian(f, x::AbstractVector)
175175
y::AbstractVector, back = forward(f, x)
176+
z = float(zero(eltype(data(y))))
176177
# Using broadcasting so that output of `ȳ` is a GPU array if `y` is so:
177-
(i) = ((j, _) -> i == j).(1:length(y), y)
178+
(i) = ((j, _) -> i == j).(1:length(y), y) .+ z
178179
vcat([transpose(back((i))[1]) for i = 1:length(y)]...)
179180
end
180181

0 commit comments

Comments
 (0)