diff --git a/src/lib/grad.jl b/src/lib/grad.jl index 92be71d34..6861272de 100644 --- a/src/lib/grad.jl +++ b/src/lib/grad.jl @@ -152,7 +152,7 @@ julia> jacobian((a,d) -> prod(a, dims=d), [1 2; 3 4; 5 6], 2) ``` !!! warning - For arguments of any type except `Number` & `AbstractArray`, the result is `nothing`. + For arguments of any type except `Number` & `AbstractArray{<:Number}`, the result is `nothing`. ``` julia> jacobian((a,s) -> a.^length(s), [1,2,3], "str") @@ -181,7 +181,7 @@ function withjacobian(f, args...) y, back = pullback(_jvec∘f, args...) out = map(args) do x T = promote_type(eltype(x), eltype(y)) - dx = x isa AbstractArray ? similar(x, T, length(y), length(x)) : + dx = x isa AbstractArray{<:Number} ? similar(x, T, length(y), length(x)) : x isa Number ? similar(y, T, length(y)) : nothing end @@ -196,7 +196,7 @@ function withjacobian(f, args...) (val=y, grad=out) end -_jvec(x::AbstractArray) = vec(x) +_jvec(x::AbstractArray{<:Number}) = vec(x) _jvec(x::Number) = _jvec(vcat(x)) _jvec(x) = throw(ArgumentError("jacobian expected a function which returns an array, or a scalar, got $(typeof(x))")) _jvec(x::AbstractArray{<:Complex}) = throw(ArgumentError("jacobian does not accept complex output")) diff --git a/test/gradcheck_p1_tests.jl b/test/gradcheck_p1_tests.jl index a967979b9..c78da0737 100644 --- a/test/gradcheck_p1_tests.jl +++ b/test/gradcheck_p1_tests.jl @@ -190,8 +190,8 @@ end @test gradient(p -> sum(collect(p*i for i in Iterators.take([1.0, 2.0, 3.0], 3))), 2.0) == (6.0,) @test gradient(p -> sum(collect(p*i for i in Iterators.take(p*[1.0, 2.0, 3.0], 2))), 2.0) == (12.0,) # generator 0-d behavior handled incorrectly - @test_broken gradient(p -> sum(collect(p*i for i in 1.0)), 2.0) - @test_broken gradient(p -> sum(collect(p*i for i in fill(1.0))), 2.0) + @test_broken gradient(p -> sum(collect(p*i for i in 1.0)), 2.0) == (1.0,) + @test gradient(p -> sum(collect(p*i for i in fill(1.0))), 2.0) == (1.0,) # adjoints for iterators @test gradient(x -> sum(collect(Iterators.take([x*i for i in 1:5], 4))), 1.0) == (10.0,)