Skip to content

Commit 5179a6f

Browse files
Merge pull request #23 from longemen3000/patch-2
Suppor more array shapes in `threaded_gradient!`
2 parents 5ca9194 + 0af7651 commit 5179a6f

File tree

3 files changed

+10
-2
lines changed

3 files changed

+10
-2
lines changed

README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://JuliaSIMD.github.io/Polyester.jl/stable)
44
[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://JuliaSIMD.github.io/Polyester.jl/dev)
55
[![CI](https://github.com/JuliaDiff/PolyesterForwardDiff.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/JuliaDiff/PolyesterForwardDiff.jl/actions/workflows/CI.yml)
6-
[![CI-Nightly](https://github.com/JuliaDiff/PolyesterForwardDiff.jl/actions/workflows/CI-julia-nightly.yml/badge.svg)](https://github.com/JuliaDiff/PolyesterForwardDiff.jl/actions/workflows/CI-julia-nightly.yml)
76
[![Coverage](https://codecov.io/gh/JuliaDiff/PolyesterForwardDiff.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/JuliaDiff/PolyesterForwardDiff.jl)
87

98

src/PolyesterForwardDiff.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ function evaluate_chunks!(f::F, (r,Δx,x), start, stop, ::ForwardDiff.Chunk{C},
4949
end
5050
end
5151

52-
function threaded_gradient!(f::F, Δx::AbstractVector, x::AbstractVector, ::ForwardDiff.Chunk{C}, check = Val{false}()) where {F,C}
52+
function threaded_gradient!(f::F, Δx::AbstractArray, x::AbstractArray, ::ForwardDiff.Chunk{C}, check = Val{false}()) where {F,C}
5353
N = length(x)
5454
d = cld_fast(N, C)
5555
r = Ref{eltype(Δx)}()

test/runtests.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,12 @@ ForwardDiff.jacobian!(dxref, g!, yref, x, ForwardDiff.JacobianConfig(g!, yref, x
3737
PolyesterForwardDiff.threaded_jacobian!(g!, y, dx, x, ForwardDiff.Chunk(8),Val{true}());
3838
@test dx dxref
3939
@test y yref
40+
41+
42+
X = randn(10,80);
43+
dXref = similar(X);
44+
dX = similar(X);
45+
ForwardDiff.gradient!(dXref, f, X, ForwardDiff.GradientConfig(f, X, ForwardDiff.Chunk(8), nothing));
46+
PolyesterForwardDiff.threaded_gradient!(f, dX, X, ForwardDiff.Chunk(8));
47+
48+
@test dX dXref

0 commit comments

Comments
 (0)